async_rusqlite.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. use std::marker::PhantomData;
  2. use std::sync::atomic::{AtomicUsize, Ordering};
  3. use std::sync::{mpsc as std_mpsc, Arc, Mutex};
  4. use std::thread::spawn;
  5. use rusqlite::Connection;
  6. use tokio::sync::{mpsc, oneshot};
  7. use crate::common::SqliteConnectionManager;
  8. use crate::mint::Error;
  9. use crate::pool::{Pool, PooledResource};
  10. use crate::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement, Value};
  /// Capacity of every bounded request queue: the main `AsyncRusqlite` queue and the
  /// per-transaction queue created when a `Begin` request succeeds.
  const SQL_QUEUE_SIZE: usize = 10_000;
  /// Number of worker threads that run non-transactional statements in parallel.
  const WORKING_THREAD_POOL_SIZE: usize = 5;
  13. #[derive(Debug, Clone)]
  14. pub struct AsyncRusqlite {
  15. sender: mpsc::Sender<DbRequest>,
  16. inflight_requests: Arc<AtomicUsize>,
  17. }
  18. /// Internal request for the database thread
  19. #[derive(Debug)]
  20. pub enum DbRequest {
  21. Sql(InnerStatement, oneshot::Sender<DbResponse>),
  22. Begin(oneshot::Sender<DbResponse>),
  23. Commit(oneshot::Sender<DbResponse>),
  24. Rollback(oneshot::Sender<DbResponse>),
  25. }
  26. #[derive(Debug)]
  27. pub enum DbResponse {
  28. Transaction(mpsc::Sender<DbRequest>),
  29. AffectedRows(usize),
  30. Pluck(Option<Column>),
  31. Row(Option<Vec<Column>>),
  32. Rows(Vec<Vec<Column>>),
  33. Error(Error),
  34. Unexpected,
  35. Ok,
  36. }
  37. /// Statement for the async_rusqlite wrapper
  38. pub struct Statement(InnerStatement);
  39. impl Statement {
  40. /// Bind a variable
  41. pub fn bind<C, V>(self, name: C, value: V) -> Self
  42. where
  43. C: ToString,
  44. V: Into<Value>,
  45. {
  46. Self(self.0.bind(name, value))
  47. }
  48. /// Bind vec
  49. pub fn bind_vec<C, V>(self, name: C, value: Vec<V>) -> Self
  50. where
  51. C: ToString,
  52. V: Into<Value>,
  53. {
  54. Self(self.0.bind_vec(name, value))
  55. }
  56. /// Executes a query and return the number of affected rows
  57. pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
  58. where
  59. C: DatabaseExecutor + Send + Sync,
  60. {
  61. conn.execute(self.0).await
  62. }
  63. /// Returns the first column of the first row of the query result
  64. pub async fn pluck<C>(self, conn: &C) -> Result<Option<Column>, Error>
  65. where
  66. C: DatabaseExecutor + Send + Sync,
  67. {
  68. conn.pluck(self.0).await
  69. }
  70. /// Returns the first row of the query result
  71. pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
  72. where
  73. C: DatabaseExecutor + Send + Sync,
  74. {
  75. conn.fetch_one(self.0).await
  76. }
  77. /// Returns all rows of the query result
  78. pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
  79. where
  80. C: DatabaseExecutor + Send + Sync,
  81. {
  82. conn.fetch_all(self.0).await
  83. }
  84. }
  85. /// Process a query
  86. #[inline(always)]
  87. fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, Error> {
  88. let mut stmt = conn.prepare_cached(&sql.sql)?;
  89. for (name, value) in sql.args {
  90. let index = stmt
  91. .parameter_index(&name)
  92. .map_err(|_| Error::MissingParameter(name.clone()))?
  93. .ok_or(Error::MissingParameter(name))?;
  94. stmt.raw_bind_parameter(index, value)?;
  95. }
  96. let columns = stmt.column_count();
  97. Ok(match sql.expected_response {
  98. ExpectedSqlResponse::AffectedRows => DbResponse::AffectedRows(stmt.raw_execute()?),
  99. ExpectedSqlResponse::ManyRows => {
  100. let mut rows = stmt.raw_query();
  101. let mut results = vec![];
  102. while let Some(row) = rows.next()? {
  103. results.push(
  104. (0..columns)
  105. .map(|i| row.get(i))
  106. .collect::<Result<Vec<_>, _>>()?,
  107. )
  108. }
  109. DbResponse::Rows(results)
  110. }
  111. ExpectedSqlResponse::Pluck => {
  112. let mut rows = stmt.raw_query();
  113. DbResponse::Pluck(rows.next()?.map(|row| row.get(0usize)).transpose()?)
  114. }
  115. ExpectedSqlResponse::SingleRow => {
  116. let mut rows = stmt.raw_query();
  117. let row = rows
  118. .next()?
  119. .map(|row| {
  120. (0..columns)
  121. .map(|i| row.get(i))
  122. .collect::<Result<Vec<_>, _>>()
  123. })
  124. .transpose()?;
  125. DbResponse::Row(row)
  126. }
  127. })
  128. }
  129. /// Spawns N number of threads to execute SQL statements
  130. ///
  131. /// Enable parallelism with a pool of threads.
  132. ///
  133. /// There is a main thread, which receives SQL requests and routes them to a worker thread from a
  134. /// fixed-size pool.
  135. ///
  136. /// By doing so, SQLite does synchronization, and Rust will only intervene when a transaction is
  137. /// executed. Transactions are executed in the main thread.
  138. fn rusqlite_spawn_worker_threads(
  139. inflight_requests: Arc<AtomicUsize>,
  140. threads: usize,
  141. ) -> std_mpsc::Sender<(
  142. PooledResource<SqliteConnectionManager>,
  143. InnerStatement,
  144. oneshot::Sender<DbResponse>,
  145. )> {
  146. let (sender, receiver) = std_mpsc::channel::<(
  147. PooledResource<SqliteConnectionManager>,
  148. InnerStatement,
  149. oneshot::Sender<DbResponse>,
  150. )>();
  151. let receiver = Arc::new(Mutex::new(receiver));
  152. for _ in 0..threads {
  153. let rx = receiver.clone();
  154. let inflight_requests = inflight_requests.clone();
  155. spawn(move || loop {
  156. while let Ok((conn, sql, reply_to)) = rx.lock().unwrap().recv() {
  157. tracing::info!("Execute query: {}", sql.sql);
  158. let result = process_query(&conn, sql);
  159. let _ = match result {
  160. Ok(ok) => reply_to.send(ok),
  161. Err(err) => {
  162. tracing::error!("Failed query with error {:?}", err);
  163. reply_to.send(DbResponse::Error(err))
  164. }
  165. };
  166. drop(conn);
  167. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  168. }
  169. });
  170. }
  171. sender
  172. }
  173. /// # Rusqlite main worker
  174. ///
  175. /// This function takes ownership of a pool of connections to SQLite, executes SQL statements, and
  176. /// returns the results or number of affected rows to the caller. All communications are done
  177. /// through channels. This function is synchronous, but a thread pool exists to execute queries, and
  178. /// SQLite will coordinate data access. Transactions are executed in the main and it takes ownership
  179. /// of the main thread until it is finalized
  180. ///
  181. /// This is meant to be called in their thread, as it will not exit the loop until the communication
  182. /// channel is closed.
  183. fn rusqlite_worker_manager(
  184. mut receiver: mpsc::Receiver<DbRequest>,
  185. pool: Arc<Pool<SqliteConnectionManager>>,
  186. inflight_requests: Arc<AtomicUsize>,
  187. ) {
  188. let send_sql_to_thread =
  189. rusqlite_spawn_worker_threads(inflight_requests.clone(), WORKING_THREAD_POOL_SIZE);
  190. let mut tx_id: usize = 0;
  191. while let Some(request) = receiver.blocking_recv() {
  192. inflight_requests.fetch_add(1, Ordering::Relaxed);
  193. match request {
  194. DbRequest::Sql(sql, reply_to) => {
  195. let conn = match pool.get() {
  196. Ok(conn) => conn,
  197. Err(err) => {
  198. tracing::error!("Failed to acquire a pool connection: {:?}", err);
  199. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  200. let _ = reply_to.send(DbResponse::Error(err.into()));
  201. continue;
  202. }
  203. };
  204. let _ = send_sql_to_thread.send((conn, sql, reply_to));
  205. continue;
  206. }
  207. DbRequest::Begin(reply_to) => {
  208. let (sender, mut receiver) = mpsc::channel(SQL_QUEUE_SIZE);
  209. let mut conn = match pool.get() {
  210. Ok(conn) => conn,
  211. Err(err) => {
  212. tracing::error!("Failed to acquire a pool connection: {:?}", err);
  213. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  214. let _ = reply_to.send(DbResponse::Error(err.into()));
  215. continue;
  216. }
  217. };
  218. let tx = match conn.transaction() {
  219. Ok(tx) => tx,
  220. Err(err) => {
  221. tracing::error!("Failed to begin a transaction: {:?}", err);
  222. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  223. let _ = reply_to.send(DbResponse::Error(err.into()));
  224. continue;
  225. }
  226. };
  227. // Transaction has begun successfully, send the `sender` back to the caller
  228. // and wait for statements to execute. On `Drop` the wrapper transaction
  229. // should send a `rollback`.
  230. let _ = reply_to.send(DbResponse::Transaction(sender));
  231. tx_id += 1;
  232. // We intentionally handle the transaction hijacking the main loop, there is
  233. // no point is queueing more operations for SQLite, since transaction have
  234. // exclusive access. In other database implementation this block of code
  235. // should be sent to their own thread to allow concurrency
  236. loop {
  237. let request = if let Some(request) = receiver.blocking_recv() {
  238. request
  239. } else {
  240. // If the receiver loop is broken (i.e no more `senders` are active) and no
  241. // `Commit` statement has been sent, this will trigger a `Rollback`
  242. // automatically
  243. tracing::info!("Tx {}: Transaction rollback on drop", tx_id);
  244. let _ = tx.rollback();
  245. break;
  246. };
  247. match request {
  248. DbRequest::Commit(reply_to) => {
  249. tracing::info!("Tx {}: Commit", tx_id);
  250. let _ = reply_to.send(match tx.commit() {
  251. Ok(()) => DbResponse::Ok,
  252. Err(err) => DbResponse::Error(err.into()),
  253. });
  254. break;
  255. }
  256. DbRequest::Rollback(reply_to) => {
  257. tracing::info!("Tx {}: Rollback", tx_id);
  258. let _ = reply_to.send(match tx.rollback() {
  259. Ok(()) => DbResponse::Ok,
  260. Err(err) => DbResponse::Error(err.into()),
  261. });
  262. break;
  263. }
  264. DbRequest::Begin(reply_to) => {
  265. let _ = reply_to.send(DbResponse::Unexpected);
  266. }
  267. DbRequest::Sql(sql, reply_to) => {
  268. tracing::info!("Tx {}: SQL {}", tx_id, sql.sql);
  269. let _ = match process_query(&tx, sql) {
  270. Ok(ok) => reply_to.send(ok),
  271. Err(err) => reply_to.send(DbResponse::Error(err)),
  272. };
  273. }
  274. }
  275. }
  276. drop(conn);
  277. }
  278. DbRequest::Commit(reply_to) => {
  279. let _ = reply_to.send(DbResponse::Unexpected);
  280. }
  281. DbRequest::Rollback(reply_to) => {
  282. let _ = reply_to.send(DbResponse::Unexpected);
  283. }
  284. }
  285. // If wasn't a `continue` the transaction is done by reaching this code, and we should
  286. // decrease the inflight_request counter
  287. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  288. }
  289. }
  290. #[async_trait::async_trait]
  291. pub trait DatabaseExecutor {
  292. /// Returns the connection to the database thread (or the on-going transaction)
  293. fn get_queue_sender(&self) -> mpsc::Sender<DbRequest>;
  294. /// Executes a query and returns the affected rows
  295. async fn execute(&self, mut statement: InnerStatement) -> Result<usize, Error> {
  296. let (sender, receiver) = oneshot::channel();
  297. statement.expected_response = ExpectedSqlResponse::AffectedRows;
  298. self.get_queue_sender()
  299. .send(DbRequest::Sql(statement, sender))
  300. .await
  301. .map_err(|_| Error::Communication)?;
  302. match receiver.await.map_err(|_| Error::Communication)? {
  303. DbResponse::AffectedRows(n) => Ok(n),
  304. DbResponse::Error(err) => Err(err),
  305. _ => Err(Error::InvalidDbResponse),
  306. }
  307. }
  308. /// Runs the query and returns the first row or None
  309. async fn fetch_one(&self, mut statement: InnerStatement) -> Result<Option<Vec<Column>>, Error> {
  310. let (sender, receiver) = oneshot::channel();
  311. statement.expected_response = ExpectedSqlResponse::SingleRow;
  312. self.get_queue_sender()
  313. .send(DbRequest::Sql(statement, sender))
  314. .await
  315. .map_err(|_| Error::Communication)?;
  316. match receiver.await.map_err(|_| Error::Communication)? {
  317. DbResponse::Row(row) => Ok(row),
  318. DbResponse::Error(err) => Err(err),
  319. _ => Err(Error::InvalidDbResponse),
  320. }
  321. }
  322. /// Runs the query and returns the first row or None
  323. async fn fetch_all(&self, mut statement: InnerStatement) -> Result<Vec<Vec<Column>>, Error> {
  324. let (sender, receiver) = oneshot::channel();
  325. statement.expected_response = ExpectedSqlResponse::ManyRows;
  326. self.get_queue_sender()
  327. .send(DbRequest::Sql(statement, sender))
  328. .await
  329. .map_err(|_| Error::Communication)?;
  330. match receiver.await.map_err(|_| Error::Communication)? {
  331. DbResponse::Rows(rows) => Ok(rows),
  332. DbResponse::Error(err) => Err(err),
  333. _ => Err(Error::InvalidDbResponse),
  334. }
  335. }
  336. async fn pluck(&self, mut statement: InnerStatement) -> Result<Option<Column>, Error> {
  337. let (sender, receiver) = oneshot::channel();
  338. statement.expected_response = ExpectedSqlResponse::Pluck;
  339. self.get_queue_sender()
  340. .send(DbRequest::Sql(statement, sender))
  341. .await
  342. .map_err(|_| Error::Communication)?;
  343. match receiver.await.map_err(|_| Error::Communication)? {
  344. DbResponse::Pluck(value) => Ok(value),
  345. DbResponse::Error(err) => Err(err),
  346. _ => Err(Error::InvalidDbResponse),
  347. }
  348. }
  349. }
  350. #[inline(always)]
  351. pub fn query<T>(sql: T) -> Statement
  352. where
  353. T: ToString,
  354. {
  355. Statement(crate::stmt::Statement::new(sql))
  356. }
  357. impl AsyncRusqlite {
  358. /// Creates a new Async Rusqlite wrapper.
  359. pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
  360. let (sender, receiver) = mpsc::channel(SQL_QUEUE_SIZE);
  361. let inflight_requests = Arc::new(AtomicUsize::new(0));
  362. let inflight_requests_for_thread = inflight_requests.clone();
  363. spawn(move || {
  364. rusqlite_worker_manager(receiver, pool, inflight_requests_for_thread);
  365. });
  366. Self {
  367. sender,
  368. inflight_requests,
  369. }
  370. }
  371. /// Show how many inflight requests
  372. #[allow(dead_code)]
  373. pub fn inflight_requests(&self) -> usize {
  374. self.inflight_requests.load(Ordering::Relaxed)
  375. }
  376. /// Begins a transaction
  377. ///
  378. /// If the transaction is Drop it will trigger a rollback operation
  379. pub async fn begin(&self) -> Result<Transaction<'_>, Error> {
  380. let (sender, receiver) = oneshot::channel();
  381. self.sender
  382. .send(DbRequest::Begin(sender))
  383. .await
  384. .map_err(|_| Error::Communication)?;
  385. match receiver.await.map_err(|_| Error::Communication)? {
  386. DbResponse::Transaction(db_sender) => Ok(Transaction {
  387. db_sender,
  388. _marker: PhantomData,
  389. }),
  390. DbResponse::Error(err) => Err(err),
  391. _ => Err(Error::InvalidDbResponse),
  392. }
  393. }
  394. }
  395. impl DatabaseExecutor for AsyncRusqlite {
  396. #[inline(always)]
  397. fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
  398. self.sender.clone()
  399. }
  400. }
  401. pub struct Transaction<'conn> {
  402. db_sender: mpsc::Sender<DbRequest>,
  403. _marker: PhantomData<&'conn ()>,
  404. }
  405. impl Drop for Transaction<'_> {
  406. fn drop(&mut self) {
  407. let (sender, _) = oneshot::channel();
  408. let _ = self.db_sender.try_send(DbRequest::Rollback(sender));
  409. }
  410. }
  411. impl Transaction<'_> {
  412. pub async fn commit(self) -> Result<(), Error> {
  413. let (sender, receiver) = oneshot::channel();
  414. self.db_sender
  415. .send(DbRequest::Commit(sender))
  416. .await
  417. .map_err(|_| Error::Communication)?;
  418. match receiver.await.map_err(|_| Error::Communication)? {
  419. DbResponse::Ok => Ok(()),
  420. DbResponse::Error(err) => Err(err),
  421. _ => Err(Error::InvalidDbResponse),
  422. }
  423. }
  424. pub async fn rollback(self) -> Result<(), Error> {
  425. let (sender, receiver) = oneshot::channel();
  426. self.db_sender
  427. .send(DbRequest::Rollback(sender))
  428. .await
  429. .map_err(|_| Error::Communication)?;
  430. match receiver.await.map_err(|_| Error::Communication)? {
  431. DbResponse::Ok => Ok(()),
  432. DbResponse::Error(err) => Err(err),
  433. _ => Err(Error::InvalidDbResponse),
  434. }
  435. }
  436. }
  437. impl DatabaseExecutor for Transaction<'_> {
  438. /// Get the internal sender to the SQL queue
  439. #[inline(always)]
  440. fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
  441. self.db_sender.clone()
  442. }
  443. }