async_rusqlite.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519
  1. use std::marker::PhantomData;
  2. use std::sync::atomic::{AtomicUsize, Ordering};
  3. use std::sync::{mpsc as std_mpsc, Arc, Mutex};
  4. use std::thread::spawn;
  5. use std::time::Instant;
  6. use rusqlite::Connection;
  7. use tokio::sync::{mpsc, oneshot};
  8. use crate::common::SqliteConnectionManager;
  9. use crate::mint::Error;
  10. use crate::pool::{Pool, PooledResource};
  11. use crate::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement, Value};
  /// Capacity of every request queue (the main queue and each transaction's queue).
  ///
  /// Once a queue holds this many pending requests, `send().await` waits until room is available
  /// and `try_send` fails.
  const SQL_QUEUE_SIZE: usize = 10_000;
  /// Queries that take longer than this many milliseconds are logged as slow queries, together
  /// with their SQL, to help debugging.
  const SLOW_QUERY_THRESHOLD_MS: u128 = 20;
  /// Number of worker threads that execute non-transactional statements in parallel, each on a
  /// connection taken from the pool.
  const WORKING_THREAD_POOL_SIZE: usize = 5;
  18. #[derive(Debug, Clone)]
  19. pub struct AsyncRusqlite {
  20. sender: mpsc::Sender<DbRequest>,
  21. inflight_requests: Arc<AtomicUsize>,
  22. }
  23. /// Internal request for the database thread
  24. #[derive(Debug)]
  25. pub enum DbRequest {
  26. Sql(InnerStatement, oneshot::Sender<DbResponse>),
  27. Begin(oneshot::Sender<DbResponse>),
  28. Commit(oneshot::Sender<DbResponse>),
  29. Rollback(oneshot::Sender<DbResponse>),
  30. }
  31. #[derive(Debug)]
  32. pub enum DbResponse {
  33. Transaction(mpsc::Sender<DbRequest>),
  34. AffectedRows(usize),
  35. Pluck(Option<Column>),
  36. Row(Option<Vec<Column>>),
  37. Rows(Vec<Vec<Column>>),
  38. Error(Error),
  39. Unexpected,
  40. Ok,
  41. }
  42. /// Statement for the async_rusqlite wrapper
  43. pub struct Statement(InnerStatement);
  44. impl Statement {
  45. /// Bind a variable
  46. pub fn bind<C, V>(self, name: C, value: V) -> Self
  47. where
  48. C: ToString,
  49. V: Into<Value>,
  50. {
  51. Self(self.0.bind(name, value))
  52. }
  53. /// Bind vec
  54. pub fn bind_vec<C, V>(self, name: C, value: Vec<V>) -> Self
  55. where
  56. C: ToString,
  57. V: Into<Value>,
  58. {
  59. Self(self.0.bind_vec(name, value))
  60. }
  61. /// Executes a query and return the number of affected rows
  62. pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
  63. where
  64. C: DatabaseExecutor + Send + Sync,
  65. {
  66. conn.execute(self.0).await
  67. }
  68. /// Returns the first column of the first row of the query result
  69. pub async fn pluck<C>(self, conn: &C) -> Result<Option<Column>, Error>
  70. where
  71. C: DatabaseExecutor + Send + Sync,
  72. {
  73. conn.pluck(self.0).await
  74. }
  75. /// Returns the first row of the query result
  76. pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
  77. where
  78. C: DatabaseExecutor + Send + Sync,
  79. {
  80. conn.fetch_one(self.0).await
  81. }
  82. /// Returns all rows of the query result
  83. pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
  84. where
  85. C: DatabaseExecutor + Send + Sync,
  86. {
  87. conn.fetch_all(self.0).await
  88. }
  89. }
  90. /// Process a query
  91. #[inline(always)]
  92. fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, Error> {
  93. let start = Instant::now();
  94. let mut args = sql.args;
  95. let mut stmt = conn.prepare_cached(&sql.sql)?;
  96. let total_parameters = stmt.parameter_count();
  97. for index in 1..=total_parameters {
  98. let value = if let Some(value) = stmt.parameter_name(index).map(|name| {
  99. args.remove(name)
  100. .ok_or(Error::MissingParameter(name.to_owned()))
  101. }) {
  102. value?
  103. } else {
  104. continue;
  105. };
  106. stmt.raw_bind_parameter(index, value)?;
  107. }
  108. let columns = stmt.column_count();
  109. let to_return = match sql.expected_response {
  110. ExpectedSqlResponse::AffectedRows => DbResponse::AffectedRows(stmt.raw_execute()?),
  111. ExpectedSqlResponse::ManyRows => {
  112. let mut rows = stmt.raw_query();
  113. let mut results = vec![];
  114. while let Some(row) = rows.next()? {
  115. results.push(
  116. (0..columns)
  117. .map(|i| row.get(i))
  118. .collect::<Result<Vec<_>, _>>()?,
  119. )
  120. }
  121. DbResponse::Rows(results)
  122. }
  123. ExpectedSqlResponse::Pluck => {
  124. let mut rows = stmt.raw_query();
  125. DbResponse::Pluck(rows.next()?.map(|row| row.get(0usize)).transpose()?)
  126. }
  127. ExpectedSqlResponse::SingleRow => {
  128. let mut rows = stmt.raw_query();
  129. let row = rows
  130. .next()?
  131. .map(|row| {
  132. (0..columns)
  133. .map(|i| row.get(i))
  134. .collect::<Result<Vec<_>, _>>()
  135. })
  136. .transpose()?;
  137. DbResponse::Row(row)
  138. }
  139. };
  140. let duration = start.elapsed();
  141. if duration.as_millis() > SLOW_QUERY_THRESHOLD_MS {
  142. tracing::error!("[SLOW QUERY] Took {} ms: {}", duration.as_millis(), sql.sql);
  143. }
  144. Ok(to_return)
  145. }
  146. /// Spawns N number of threads to execute SQL statements
  147. ///
  148. /// Enable parallelism with a pool of threads.
  149. ///
  150. /// There is a main thread, which receives SQL requests and routes them to a worker thread from a
  151. /// fixed-size pool.
  152. ///
  153. /// By doing so, SQLite does synchronization, and Rust will only intervene when a transaction is
  154. /// executed. Transactions are executed in the main thread.
  155. fn rusqlite_spawn_worker_threads(
  156. inflight_requests: Arc<AtomicUsize>,
  157. threads: usize,
  158. ) -> std_mpsc::Sender<(
  159. PooledResource<SqliteConnectionManager>,
  160. InnerStatement,
  161. oneshot::Sender<DbResponse>,
  162. )> {
  163. let (sender, receiver) = std_mpsc::channel::<(
  164. PooledResource<SqliteConnectionManager>,
  165. InnerStatement,
  166. oneshot::Sender<DbResponse>,
  167. )>();
  168. let receiver = Arc::new(Mutex::new(receiver));
  169. for _ in 0..threads {
  170. let rx = receiver.clone();
  171. let inflight_requests = inflight_requests.clone();
  172. spawn(move || loop {
  173. while let Ok((conn, sql, reply_to)) = rx.lock().unwrap().recv() {
  174. tracing::info!("Execute query: {}", sql.sql);
  175. let result = process_query(&conn, sql);
  176. let _ = match result {
  177. Ok(ok) => reply_to.send(ok),
  178. Err(err) => {
  179. tracing::error!("Failed query with error {:?}", err);
  180. reply_to.send(DbResponse::Error(err))
  181. }
  182. };
  183. drop(conn);
  184. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  185. }
  186. });
  187. }
  188. sender
  189. }
  190. /// # Rusqlite main worker
  191. ///
  192. /// This function takes ownership of a pool of connections to SQLite, executes SQL statements, and
  193. /// returns the results or number of affected rows to the caller. All communications are done
  194. /// through channels. This function is synchronous, but a thread pool exists to execute queries, and
  195. /// SQLite will coordinate data access. Transactions are executed in the main and it takes ownership
  196. /// of the main thread until it is finalized
  197. ///
  198. /// This is meant to be called in their thread, as it will not exit the loop until the communication
  199. /// channel is closed.
  200. fn rusqlite_worker_manager(
  201. mut receiver: mpsc::Receiver<DbRequest>,
  202. pool: Arc<Pool<SqliteConnectionManager>>,
  203. inflight_requests: Arc<AtomicUsize>,
  204. ) {
  205. let send_sql_to_thread =
  206. rusqlite_spawn_worker_threads(inflight_requests.clone(), WORKING_THREAD_POOL_SIZE);
  207. let mut tx_id: usize = 0;
  208. while let Some(request) = receiver.blocking_recv() {
  209. inflight_requests.fetch_add(1, Ordering::Relaxed);
  210. match request {
  211. DbRequest::Sql(sql, reply_to) => {
  212. let conn = match pool.get() {
  213. Ok(conn) => conn,
  214. Err(err) => {
  215. tracing::error!("Failed to acquire a pool connection: {:?}", err);
  216. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  217. let _ = reply_to.send(DbResponse::Error(err.into()));
  218. continue;
  219. }
  220. };
  221. let _ = send_sql_to_thread.send((conn, sql, reply_to));
  222. continue;
  223. }
  224. DbRequest::Begin(reply_to) => {
  225. let (sender, mut receiver) = mpsc::channel(SQL_QUEUE_SIZE);
  226. let mut conn = match pool.get() {
  227. Ok(conn) => conn,
  228. Err(err) => {
  229. tracing::error!("Failed to acquire a pool connection: {:?}", err);
  230. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  231. let _ = reply_to.send(DbResponse::Error(err.into()));
  232. continue;
  233. }
  234. };
  235. let tx = match conn.transaction() {
  236. Ok(tx) => tx,
  237. Err(err) => {
  238. tracing::error!("Failed to begin a transaction: {:?}", err);
  239. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  240. let _ = reply_to.send(DbResponse::Error(err.into()));
  241. continue;
  242. }
  243. };
  244. // Transaction has begun successfully, send the `sender` back to the caller
  245. // and wait for statements to execute. On `Drop` the wrapper transaction
  246. // should send a `rollback`.
  247. let _ = reply_to.send(DbResponse::Transaction(sender));
  248. tx_id += 1;
  249. // We intentionally handle the transaction hijacking the main loop, there is
  250. // no point is queueing more operations for SQLite, since transaction have
  251. // exclusive access. In other database implementation this block of code
  252. // should be sent to their own thread to allow concurrency
  253. loop {
  254. let request = if let Some(request) = receiver.blocking_recv() {
  255. request
  256. } else {
  257. // If the receiver loop is broken (i.e no more `senders` are active) and no
  258. // `Commit` statement has been sent, this will trigger a `Rollback`
  259. // automatically
  260. tracing::info!("Tx {}: Transaction rollback on drop", tx_id);
  261. let _ = tx.rollback();
  262. break;
  263. };
  264. match request {
  265. DbRequest::Commit(reply_to) => {
  266. tracing::info!("Tx {}: Commit", tx_id);
  267. let _ = reply_to.send(match tx.commit() {
  268. Ok(()) => DbResponse::Ok,
  269. Err(err) => DbResponse::Error(err.into()),
  270. });
  271. break;
  272. }
  273. DbRequest::Rollback(reply_to) => {
  274. tracing::info!("Tx {}: Rollback", tx_id);
  275. let _ = reply_to.send(match tx.rollback() {
  276. Ok(()) => DbResponse::Ok,
  277. Err(err) => DbResponse::Error(err.into()),
  278. });
  279. break;
  280. }
  281. DbRequest::Begin(reply_to) => {
  282. let _ = reply_to.send(DbResponse::Unexpected);
  283. }
  284. DbRequest::Sql(sql, reply_to) => {
  285. tracing::info!("Tx {}: SQL {}", tx_id, sql.sql);
  286. let _ = match process_query(&tx, sql) {
  287. Ok(ok) => reply_to.send(ok),
  288. Err(err) => reply_to.send(DbResponse::Error(err)),
  289. };
  290. }
  291. }
  292. }
  293. drop(conn);
  294. }
  295. DbRequest::Commit(reply_to) => {
  296. let _ = reply_to.send(DbResponse::Unexpected);
  297. }
  298. DbRequest::Rollback(reply_to) => {
  299. let _ = reply_to.send(DbResponse::Unexpected);
  300. }
  301. }
  302. // If wasn't a `continue` the transaction is done by reaching this code, and we should
  303. // decrease the inflight_request counter
  304. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  305. }
  306. }
  307. #[async_trait::async_trait]
  308. pub trait DatabaseExecutor {
  309. /// Returns the connection to the database thread (or the on-going transaction)
  310. fn get_queue_sender(&self) -> mpsc::Sender<DbRequest>;
  311. /// Executes a query and returns the affected rows
  312. async fn execute(&self, mut statement: InnerStatement) -> Result<usize, Error> {
  313. let (sender, receiver) = oneshot::channel();
  314. statement.expected_response = ExpectedSqlResponse::AffectedRows;
  315. self.get_queue_sender()
  316. .send(DbRequest::Sql(statement, sender))
  317. .await
  318. .map_err(|_| Error::Communication)?;
  319. match receiver.await.map_err(|_| Error::Communication)? {
  320. DbResponse::AffectedRows(n) => Ok(n),
  321. DbResponse::Error(err) => Err(err),
  322. _ => Err(Error::InvalidDbResponse),
  323. }
  324. }
  325. /// Runs the query and returns the first row or None
  326. async fn fetch_one(&self, mut statement: InnerStatement) -> Result<Option<Vec<Column>>, Error> {
  327. let (sender, receiver) = oneshot::channel();
  328. statement.expected_response = ExpectedSqlResponse::SingleRow;
  329. self.get_queue_sender()
  330. .send(DbRequest::Sql(statement, sender))
  331. .await
  332. .map_err(|_| Error::Communication)?;
  333. match receiver.await.map_err(|_| Error::Communication)? {
  334. DbResponse::Row(row) => Ok(row),
  335. DbResponse::Error(err) => Err(err),
  336. _ => Err(Error::InvalidDbResponse),
  337. }
  338. }
  339. /// Runs the query and returns the first row or None
  340. async fn fetch_all(&self, mut statement: InnerStatement) -> Result<Vec<Vec<Column>>, Error> {
  341. let (sender, receiver) = oneshot::channel();
  342. statement.expected_response = ExpectedSqlResponse::ManyRows;
  343. self.get_queue_sender()
  344. .send(DbRequest::Sql(statement, sender))
  345. .await
  346. .map_err(|_| Error::Communication)?;
  347. match receiver.await.map_err(|_| Error::Communication)? {
  348. DbResponse::Rows(rows) => Ok(rows),
  349. DbResponse::Error(err) => Err(err),
  350. _ => Err(Error::InvalidDbResponse),
  351. }
  352. }
  353. async fn pluck(&self, mut statement: InnerStatement) -> Result<Option<Column>, Error> {
  354. let (sender, receiver) = oneshot::channel();
  355. statement.expected_response = ExpectedSqlResponse::Pluck;
  356. self.get_queue_sender()
  357. .send(DbRequest::Sql(statement, sender))
  358. .await
  359. .map_err(|_| Error::Communication)?;
  360. match receiver.await.map_err(|_| Error::Communication)? {
  361. DbResponse::Pluck(value) => Ok(value),
  362. DbResponse::Error(err) => Err(err),
  363. _ => Err(Error::InvalidDbResponse),
  364. }
  365. }
  366. }
  367. #[inline(always)]
  368. pub fn query<T>(sql: T) -> Statement
  369. where
  370. T: ToString,
  371. {
  372. Statement(crate::stmt::Statement::new(sql))
  373. }
  374. impl AsyncRusqlite {
  375. /// Creates a new Async Rusqlite wrapper.
  376. pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
  377. let (sender, receiver) = mpsc::channel(SQL_QUEUE_SIZE);
  378. let inflight_requests = Arc::new(AtomicUsize::new(0));
  379. let inflight_requests_for_thread = inflight_requests.clone();
  380. spawn(move || {
  381. rusqlite_worker_manager(receiver, pool, inflight_requests_for_thread);
  382. });
  383. Self {
  384. sender,
  385. inflight_requests,
  386. }
  387. }
  388. /// Show how many inflight requests
  389. #[allow(dead_code)]
  390. pub fn inflight_requests(&self) -> usize {
  391. self.inflight_requests.load(Ordering::Relaxed)
  392. }
  393. /// Begins a transaction
  394. ///
  395. /// If the transaction is Drop it will trigger a rollback operation
  396. pub async fn begin(&self) -> Result<Transaction<'_>, Error> {
  397. let (sender, receiver) = oneshot::channel();
  398. self.sender
  399. .send(DbRequest::Begin(sender))
  400. .await
  401. .map_err(|_| Error::Communication)?;
  402. match receiver.await.map_err(|_| Error::Communication)? {
  403. DbResponse::Transaction(db_sender) => Ok(Transaction {
  404. db_sender,
  405. _marker: PhantomData,
  406. }),
  407. DbResponse::Error(err) => Err(err),
  408. _ => Err(Error::InvalidDbResponse),
  409. }
  410. }
  411. }
  412. impl DatabaseExecutor for AsyncRusqlite {
  413. #[inline(always)]
  414. fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
  415. self.sender.clone()
  416. }
  417. }
  418. pub struct Transaction<'conn> {
  419. db_sender: mpsc::Sender<DbRequest>,
  420. _marker: PhantomData<&'conn ()>,
  421. }
  422. impl Drop for Transaction<'_> {
  423. fn drop(&mut self) {
  424. let (sender, _) = oneshot::channel();
  425. let _ = self.db_sender.try_send(DbRequest::Rollback(sender));
  426. }
  427. }
  428. impl Transaction<'_> {
  429. pub async fn commit(self) -> Result<(), Error> {
  430. let (sender, receiver) = oneshot::channel();
  431. self.db_sender
  432. .send(DbRequest::Commit(sender))
  433. .await
  434. .map_err(|_| Error::Communication)?;
  435. match receiver.await.map_err(|_| Error::Communication)? {
  436. DbResponse::Ok => Ok(()),
  437. DbResponse::Error(err) => Err(err),
  438. _ => Err(Error::InvalidDbResponse),
  439. }
  440. }
  441. pub async fn rollback(self) -> Result<(), Error> {
  442. let (sender, receiver) = oneshot::channel();
  443. self.db_sender
  444. .send(DbRequest::Rollback(sender))
  445. .await
  446. .map_err(|_| Error::Communication)?;
  447. match receiver.await.map_err(|_| Error::Communication)? {
  448. DbResponse::Ok => Ok(()),
  449. DbResponse::Error(err) => Err(err),
  450. _ => Err(Error::InvalidDbResponse),
  451. }
  452. }
  453. }
  454. impl DatabaseExecutor for Transaction<'_> {
  455. /// Get the internal sender to the SQL queue
  456. #[inline(always)]
  457. fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
  458. self.db_sender.clone()
  459. }
  460. }