async_rusqlite.rs 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727
  1. //! Async, pipelined rusqlite client
  2. use std::marker::PhantomData;
  3. use std::path::PathBuf;
  4. use std::sync::atomic::{AtomicUsize, Ordering};
  5. use std::sync::{mpsc as std_mpsc, Arc, Mutex};
  6. use std::thread::spawn;
  7. use std::time::Instant;
  8. use cdk_common::database::Error;
  9. use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, DatabaseTransaction};
  10. use cdk_sql_common::pool::{self, Pool, PooledResource};
  11. use cdk_sql_common::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement};
  12. use cdk_sql_common::ConversionError;
  13. use rusqlite::{ffi, Connection, ErrorCode, TransactionBehavior};
  14. use tokio::sync::{mpsc, oneshot};
  15. use crate::common::{create_sqlite_pool, from_sqlite, to_sqlite, SqliteConnectionManager};
  16. /// The number of queued SQL statements before it start failing
  17. const SQL_QUEUE_SIZE: usize = 10_000;
  18. /// How many ms is considered a slow query, and it'd be logged for further debugging
  19. const SLOW_QUERY_THRESHOLD_MS: u128 = 20;
  20. /// How many SQLite parallel connections can be used to read things in parallel
  21. const WORKING_THREAD_POOL_SIZE: usize = 5;
  22. #[derive(Debug, Clone)]
  23. pub struct AsyncRusqlite {
  24. sender: mpsc::Sender<DbRequest>,
  25. inflight_requests: Arc<AtomicUsize>,
  26. }
  27. impl From<PathBuf> for AsyncRusqlite {
  28. fn from(value: PathBuf) -> Self {
  29. AsyncRusqlite::new(create_sqlite_pool(value.to_str().unwrap_or_default(), None))
  30. }
  31. }
  32. impl From<&str> for AsyncRusqlite {
  33. fn from(value: &str) -> Self {
  34. AsyncRusqlite::new(create_sqlite_pool(value, None))
  35. }
  36. }
  37. impl From<(&str, &str)> for AsyncRusqlite {
  38. fn from((value, pass): (&str, &str)) -> Self {
  39. AsyncRusqlite::new(create_sqlite_pool(value, Some(pass.to_owned())))
  40. }
  41. }
  42. impl From<(PathBuf, &str)> for AsyncRusqlite {
  43. fn from((value, pass): (PathBuf, &str)) -> Self {
  44. AsyncRusqlite::new(create_sqlite_pool(
  45. value.to_str().unwrap_or_default(),
  46. Some(pass.to_owned()),
  47. ))
  48. }
  49. }
  50. impl From<(&str, String)> for AsyncRusqlite {
  51. fn from((value, pass): (&str, String)) -> Self {
  52. AsyncRusqlite::new(create_sqlite_pool(value, Some(pass)))
  53. }
  54. }
  55. impl From<(PathBuf, String)> for AsyncRusqlite {
  56. fn from((value, pass): (PathBuf, String)) -> Self {
  57. AsyncRusqlite::new(create_sqlite_pool(
  58. value.to_str().unwrap_or_default(),
  59. Some(pass),
  60. ))
  61. }
  62. }
  63. impl From<&PathBuf> for AsyncRusqlite {
  64. fn from(value: &PathBuf) -> Self {
  65. AsyncRusqlite::new(create_sqlite_pool(value.to_str().unwrap_or_default(), None))
  66. }
  67. }
  68. /// Internal request for the database thread
  69. #[derive(Debug)]
  70. enum DbRequest {
  71. Sql(InnerStatement, oneshot::Sender<DbResponse>),
  72. Begin(oneshot::Sender<DbResponse>),
  73. Commit(oneshot::Sender<DbResponse>),
  74. Rollback(oneshot::Sender<DbResponse>),
  75. }
  76. #[derive(Debug)]
  77. enum DbResponse {
  78. Transaction(mpsc::Sender<DbRequest>),
  79. AffectedRows(usize),
  80. Pluck(Option<Column>),
  81. Row(Option<Vec<Column>>),
  82. Rows(Vec<Vec<Column>>),
  83. Error(SqliteError),
  84. Unexpected,
  85. Ok,
  86. }
  87. #[derive(thiserror::Error, Debug)]
  88. enum SqliteError {
  89. #[error(transparent)]
  90. Sqlite(#[from] rusqlite::Error),
  91. #[error(transparent)]
  92. Inner(#[from] Error),
  93. #[error(transparent)]
  94. Pool(#[from] pool::Error<rusqlite::Error>),
  95. /// Duplicate entry
  96. #[error("Duplicate")]
  97. Duplicate,
  98. #[error(transparent)]
  99. Conversion(#[from] ConversionError),
  100. }
  101. impl From<SqliteError> for Error {
  102. fn from(val: SqliteError) -> Self {
  103. match val {
  104. SqliteError::Duplicate => Error::Duplicate,
  105. SqliteError::Conversion(e) => e.into(),
  106. o => Error::Internal(o.to_string()),
  107. }
  108. }
  109. }
  110. /// Process a query
  111. #[inline(always)]
  112. fn process_query(conn: &Connection, statement: InnerStatement) -> Result<DbResponse, SqliteError> {
  113. let start = Instant::now();
  114. let expected_response = statement.expected_response;
  115. let (sql, placeholder_values) = statement.to_sql()?;
  116. let sql = sql.trim_end_matches("FOR UPDATE");
  117. let mut stmt = conn.prepare_cached(sql)?;
  118. for (i, value) in placeholder_values.into_iter().enumerate() {
  119. stmt.raw_bind_parameter(i + 1, to_sqlite(value))?;
  120. }
  121. let columns = stmt.column_count();
  122. let to_return = match expected_response {
  123. ExpectedSqlResponse::AffectedRows => DbResponse::AffectedRows(stmt.raw_execute()?),
  124. ExpectedSqlResponse::Batch => {
  125. conn.execute_batch(sql)?;
  126. DbResponse::Ok
  127. }
  128. ExpectedSqlResponse::ManyRows => {
  129. let mut rows = stmt.raw_query();
  130. let mut results = vec![];
  131. while let Some(row) = rows.next()? {
  132. results.push(
  133. (0..columns)
  134. .map(|i| row.get(i).map(from_sqlite))
  135. .collect::<Result<Vec<_>, _>>()?,
  136. )
  137. }
  138. DbResponse::Rows(results)
  139. }
  140. ExpectedSqlResponse::Pluck => {
  141. let mut rows = stmt.raw_query();
  142. DbResponse::Pluck(
  143. rows.next()?
  144. .map(|row| row.get(0usize).map(from_sqlite))
  145. .transpose()?,
  146. )
  147. }
  148. ExpectedSqlResponse::SingleRow => {
  149. let mut rows = stmt.raw_query();
  150. let row = rows
  151. .next()?
  152. .map(|row| {
  153. (0..columns)
  154. .map(|i| row.get(i).map(from_sqlite))
  155. .collect::<Result<Vec<_>, _>>()
  156. })
  157. .transpose()?;
  158. DbResponse::Row(row)
  159. }
  160. };
  161. let duration = start.elapsed();
  162. if duration.as_millis() > SLOW_QUERY_THRESHOLD_MS {
  163. tracing::warn!("[SLOW QUERY] Took {} ms: {}", duration.as_millis(), sql);
  164. }
  165. Ok(to_return)
  166. }
  167. /// Spawns N number of threads to execute SQL statements
  168. ///
  169. /// Enable parallelism with a pool of threads.
  170. ///
  171. /// There is a main thread, which receives SQL requests and routes them to a worker thread from a
  172. /// fixed-size pool.
  173. ///
  174. /// By doing so, SQLite does synchronization, and Rust will only intervene when a transaction is
  175. /// executed. Transactions are executed in the main thread.
  176. fn rusqlite_spawn_worker_threads(
  177. inflight_requests: Arc<AtomicUsize>,
  178. threads: usize,
  179. ) -> std_mpsc::Sender<(
  180. PooledResource<SqliteConnectionManager>,
  181. InnerStatement,
  182. oneshot::Sender<DbResponse>,
  183. )> {
  184. let (sender, receiver) = std_mpsc::channel::<(
  185. PooledResource<SqliteConnectionManager>,
  186. InnerStatement,
  187. oneshot::Sender<DbResponse>,
  188. )>();
  189. let receiver = Arc::new(Mutex::new(receiver));
  190. for _ in 0..threads {
  191. let rx = receiver.clone();
  192. let inflight_requests = inflight_requests.clone();
  193. spawn(move || loop {
  194. while let Ok((conn, sql, reply_to)) = rx.lock().expect("failed to acquire").recv() {
  195. let result = process_query(&conn, sql);
  196. let _ = match result {
  197. Ok(ok) => reply_to.send(ok),
  198. Err(err) => {
  199. tracing::error!("Failed query with error {:?}", err);
  200. let err = if let SqliteError::Sqlite(rusqlite::Error::SqliteFailure(
  201. ffi::Error {
  202. code,
  203. extended_code,
  204. },
  205. _,
  206. )) = &err
  207. {
  208. if *code == ErrorCode::ConstraintViolation
  209. && (*extended_code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY
  210. || *extended_code == ffi::SQLITE_CONSTRAINT_UNIQUE)
  211. {
  212. SqliteError::Duplicate
  213. } else {
  214. err
  215. }
  216. } else {
  217. err
  218. };
  219. reply_to.send(DbResponse::Error(err))
  220. }
  221. };
  222. drop(conn);
  223. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  224. }
  225. });
  226. }
  227. sender
  228. }
  229. /// # Rusqlite main worker
  230. ///
  231. /// This function takes ownership of a pool of connections to SQLite, executes SQL statements, and
  232. /// returns the results or number of affected rows to the caller. All communications are done
  233. /// through channels. This function is synchronous, but a thread pool exists to execute queries, and
  234. /// SQLite will coordinate data access. Transactions are executed in the main and it takes ownership
  235. /// of the main thread until it is finalized
  236. ///
  237. /// This is meant to be called in their thread, as it will not exit the loop until the communication
  238. /// channel is closed.
  239. fn rusqlite_worker_manager(
  240. mut receiver: mpsc::Receiver<DbRequest>,
  241. pool: Arc<Pool<SqliteConnectionManager>>,
  242. inflight_requests: Arc<AtomicUsize>,
  243. ) {
  244. let send_sql_to_thread =
  245. rusqlite_spawn_worker_threads(inflight_requests.clone(), WORKING_THREAD_POOL_SIZE);
  246. let mut tx_id: usize = 0;
  247. while let Some(request) = receiver.blocking_recv() {
  248. inflight_requests.fetch_add(1, Ordering::Relaxed);
  249. match request {
  250. DbRequest::Sql(statement, reply_to) => {
  251. let conn = match pool.get() {
  252. Ok(conn) => conn,
  253. Err(err) => {
  254. tracing::error!("Failed to acquire a pool connection: {:?}", err);
  255. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  256. let _ = reply_to.send(DbResponse::Error(err.into()));
  257. continue;
  258. }
  259. };
  260. let _ = send_sql_to_thread.send((conn, statement, reply_to));
  261. continue;
  262. }
  263. DbRequest::Begin(reply_to) => {
  264. let (sender, mut receiver) = mpsc::channel(SQL_QUEUE_SIZE);
  265. let mut conn = match pool.get() {
  266. Ok(conn) => conn,
  267. Err(err) => {
  268. tracing::error!("Failed to acquire a pool connection: {:?}", err);
  269. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  270. let _ = reply_to.send(DbResponse::Error(err.into()));
  271. continue;
  272. }
  273. };
  274. let tx = match conn.transaction_with_behavior(TransactionBehavior::Immediate) {
  275. Ok(tx) => tx,
  276. Err(err) => {
  277. tracing::error!("Failed to begin a transaction: {:?}", err);
  278. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  279. let _ = reply_to.send(DbResponse::Error(err.into()));
  280. continue;
  281. }
  282. };
  283. // Transaction has begun successfully, send the `sender` back to the caller
  284. // and wait for statements to execute. On `Drop` the wrapper transaction
  285. // should send a `rollback`.
  286. let _ = reply_to.send(DbResponse::Transaction(sender));
  287. tx_id += 1;
  288. // We intentionally handle the transaction hijacking the main loop, there is
  289. // no point is queueing more operations for SQLite, since transaction have
  290. // exclusive access. In other database implementation this block of code
  291. // should be sent to their own thread to allow concurrency
  292. loop {
  293. let request = if let Some(request) = receiver.blocking_recv() {
  294. request
  295. } else {
  296. // If the receiver loop is broken (i.e no more `senders` are active) and no
  297. // `Commit` statement has been sent, this will trigger a `Rollback`
  298. // automatically
  299. tracing::trace!("Tx {}: Transaction rollback on drop", tx_id);
  300. let _ = tx.rollback();
  301. break;
  302. };
  303. match request {
  304. DbRequest::Commit(reply_to) => {
  305. tracing::trace!("Tx {}: Commit", tx_id);
  306. let _ = reply_to.send(match tx.commit() {
  307. Ok(()) => DbResponse::Ok,
  308. Err(err) => {
  309. tracing::error!("Failed commit {:?}", err);
  310. DbResponse::Error(err.into())
  311. }
  312. });
  313. break;
  314. }
  315. DbRequest::Rollback(reply_to) => {
  316. tracing::trace!("Tx {}: Rollback", tx_id);
  317. let _ = reply_to.send(match tx.rollback() {
  318. Ok(()) => DbResponse::Ok,
  319. Err(err) => {
  320. tracing::error!("Failed rollback {:?}", err);
  321. DbResponse::Error(err.into())
  322. }
  323. });
  324. break;
  325. }
  326. DbRequest::Begin(reply_to) => {
  327. let _ = reply_to.send(DbResponse::Unexpected);
  328. }
  329. DbRequest::Sql(statement, reply_to) => {
  330. tracing::trace!("Tx {}: SQL {:?}", tx_id, statement);
  331. let _ = match process_query(&tx, statement) {
  332. Ok(ok) => reply_to.send(ok),
  333. Err(err) => {
  334. tracing::error!(
  335. "Tx {}: Failed query with error {:?}",
  336. tx_id,
  337. err
  338. );
  339. let err = if let SqliteError::Sqlite(
  340. rusqlite::Error::SqliteFailure(
  341. ffi::Error {
  342. code,
  343. extended_code,
  344. },
  345. _,
  346. ),
  347. ) = &err
  348. {
  349. if *code == ErrorCode::ConstraintViolation
  350. && (*extended_code == ffi::SQLITE_CONSTRAINT_PRIMARYKEY
  351. || *extended_code == ffi::SQLITE_CONSTRAINT_UNIQUE)
  352. {
  353. SqliteError::Duplicate
  354. } else {
  355. err
  356. }
  357. } else {
  358. err
  359. };
  360. reply_to.send(DbResponse::Error(err))
  361. }
  362. };
  363. }
  364. }
  365. }
  366. drop(conn);
  367. }
  368. DbRequest::Commit(reply_to) => {
  369. let _ = reply_to.send(DbResponse::Unexpected);
  370. }
  371. DbRequest::Rollback(reply_to) => {
  372. let _ = reply_to.send(DbResponse::Unexpected);
  373. }
  374. }
  375. // If wasn't a `continue` the transaction is done by reaching this code, and we should
  376. // decrease the inflight_request counter
  377. inflight_requests.fetch_sub(1, Ordering::Relaxed);
  378. }
  379. }
  380. impl AsyncRusqlite {
  381. /// Creates a new Async Rusqlite wrapper.
  382. pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
  383. let (sender, receiver) = mpsc::channel(SQL_QUEUE_SIZE);
  384. let inflight_requests = Arc::new(AtomicUsize::new(0));
  385. let inflight_requests_for_thread = inflight_requests.clone();
  386. spawn(move || {
  387. rusqlite_worker_manager(receiver, pool, inflight_requests_for_thread);
  388. });
  389. Self {
  390. sender,
  391. inflight_requests,
  392. }
  393. }
  394. fn get_queue_sender(&self) -> &mpsc::Sender<DbRequest> {
  395. &self.sender
  396. }
  397. /// Show how many inflight requests
  398. #[allow(dead_code)]
  399. pub fn inflight_requests(&self) -> usize {
  400. self.inflight_requests.load(Ordering::Relaxed)
  401. }
  402. }
  403. #[async_trait::async_trait]
  404. impl DatabaseConnector for AsyncRusqlite {
  405. type Transaction<'a> = Transaction<'a>;
  406. /// Begins a transaction
  407. ///
  408. /// If the transaction is Drop it will trigger a rollback operation
  409. async fn begin(&self) -> Result<Self::Transaction<'_>, Error> {
  410. let (sender, receiver) = oneshot::channel();
  411. self.sender
  412. .send(DbRequest::Begin(sender))
  413. .await
  414. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  415. match receiver
  416. .await
  417. .map_err(|_| Error::Internal("Communication".to_owned()))?
  418. {
  419. DbResponse::Transaction(db_sender) => Ok(Transaction {
  420. db_sender,
  421. _marker: PhantomData,
  422. }),
  423. DbResponse::Error(err) => Err(err.into()),
  424. _ => Err(Error::InvalidDbResponse),
  425. }
  426. }
  427. }
  428. #[async_trait::async_trait]
  429. impl DatabaseExecutor for AsyncRusqlite {
  430. fn name() -> &'static str {
  431. "sqlite"
  432. }
  433. async fn fetch_one(&self, mut statement: InnerStatement) -> Result<Option<Vec<Column>>, Error> {
  434. let (sender, receiver) = oneshot::channel();
  435. statement.expected_response = ExpectedSqlResponse::SingleRow;
  436. self.get_queue_sender()
  437. .send(DbRequest::Sql(statement, sender))
  438. .await
  439. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  440. match receiver
  441. .await
  442. .map_err(|_| Error::Internal("Communication".to_owned()))?
  443. {
  444. DbResponse::Row(row) => Ok(row),
  445. DbResponse::Error(err) => Err(err.into()),
  446. _ => Err(Error::InvalidDbResponse),
  447. }
  448. }
  449. async fn batch(&self, mut statement: InnerStatement) -> Result<(), Error> {
  450. let (sender, receiver) = oneshot::channel();
  451. statement.expected_response = ExpectedSqlResponse::Batch;
  452. self.get_queue_sender()
  453. .send(DbRequest::Sql(statement, sender))
  454. .await
  455. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  456. match receiver
  457. .await
  458. .map_err(|_| Error::Internal("Communication".to_owned()))?
  459. {
  460. DbResponse::Ok => Ok(()),
  461. DbResponse::Error(err) => Err(err.into()),
  462. _ => Err(Error::InvalidDbResponse),
  463. }
  464. }
  465. async fn fetch_all(&self, mut statement: InnerStatement) -> Result<Vec<Vec<Column>>, Error> {
  466. let (sender, receiver) = oneshot::channel();
  467. statement.expected_response = ExpectedSqlResponse::ManyRows;
  468. self.get_queue_sender()
  469. .send(DbRequest::Sql(statement, sender))
  470. .await
  471. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  472. match receiver
  473. .await
  474. .map_err(|_| Error::Internal("Communication".to_owned()))?
  475. {
  476. DbResponse::Rows(row) => Ok(row),
  477. DbResponse::Error(err) => Err(err.into()),
  478. _ => Err(Error::InvalidDbResponse),
  479. }
  480. }
  481. async fn execute(&self, mut statement: InnerStatement) -> Result<usize, Error> {
  482. let (sender, receiver) = oneshot::channel();
  483. statement.expected_response = ExpectedSqlResponse::AffectedRows;
  484. self.get_queue_sender()
  485. .send(DbRequest::Sql(statement, sender))
  486. .await
  487. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  488. match receiver
  489. .await
  490. .map_err(|_| Error::Internal("Communication".to_owned()))?
  491. {
  492. DbResponse::AffectedRows(total) => Ok(total),
  493. DbResponse::Error(err) => Err(err.into()),
  494. _ => Err(Error::InvalidDbResponse),
  495. }
  496. }
  497. async fn pluck(&self, mut statement: InnerStatement) -> Result<Option<Column>, Error> {
  498. let (sender, receiver) = oneshot::channel();
  499. statement.expected_response = ExpectedSqlResponse::Pluck;
  500. self.get_queue_sender()
  501. .send(DbRequest::Sql(statement, sender))
  502. .await
  503. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  504. match receiver
  505. .await
  506. .map_err(|_| Error::Internal("Communication".to_owned()))?
  507. {
  508. DbResponse::Pluck(value) => Ok(value),
  509. DbResponse::Error(err) => Err(err.into()),
  510. _ => Err(Error::InvalidDbResponse),
  511. }
  512. }
  513. }
  514. /// Database transaction
  515. #[derive(Debug)]
  516. pub struct Transaction<'conn> {
  517. db_sender: mpsc::Sender<DbRequest>,
  518. _marker: PhantomData<&'conn ()>,
  519. }
  520. impl Transaction<'_> {
  521. fn get_queue_sender(&self) -> &mpsc::Sender<DbRequest> {
  522. &self.db_sender
  523. }
  524. }
  525. impl Drop for Transaction<'_> {
  526. fn drop(&mut self) {
  527. let (sender, _) = oneshot::channel();
  528. let _ = self.db_sender.try_send(DbRequest::Rollback(sender));
  529. }
  530. }
  531. #[async_trait::async_trait]
  532. impl<'a> DatabaseTransaction<'a> for Transaction<'a> {
  533. async fn commit(self) -> Result<(), Error> {
  534. let (sender, receiver) = oneshot::channel();
  535. self.db_sender
  536. .send(DbRequest::Commit(sender))
  537. .await
  538. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  539. match receiver
  540. .await
  541. .map_err(|_| Error::Internal("Communication".to_owned()))?
  542. {
  543. DbResponse::Ok => Ok(()),
  544. DbResponse::Error(err) => Err(err.into()),
  545. _ => Err(Error::InvalidDbResponse),
  546. }
  547. }
  548. async fn rollback(self) -> Result<(), Error> {
  549. let (sender, receiver) = oneshot::channel();
  550. self.db_sender
  551. .send(DbRequest::Rollback(sender))
  552. .await
  553. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  554. match receiver
  555. .await
  556. .map_err(|_| Error::Internal("Communication".to_owned()))?
  557. {
  558. DbResponse::Ok => Ok(()),
  559. DbResponse::Error(err) => Err(err.into()),
  560. _ => Err(Error::InvalidDbResponse),
  561. }
  562. }
  563. }
  564. #[async_trait::async_trait]
  565. impl DatabaseExecutor for Transaction<'_> {
  566. fn name() -> &'static str {
  567. "sqlite"
  568. }
  569. async fn fetch_one(&self, mut statement: InnerStatement) -> Result<Option<Vec<Column>>, Error> {
  570. let (sender, receiver) = oneshot::channel();
  571. statement.expected_response = ExpectedSqlResponse::SingleRow;
  572. self.get_queue_sender()
  573. .send(DbRequest::Sql(statement, sender))
  574. .await
  575. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  576. match receiver
  577. .await
  578. .map_err(|_| Error::Internal("Communication".to_owned()))?
  579. {
  580. DbResponse::Row(row) => Ok(row),
  581. DbResponse::Error(err) => Err(err.into()),
  582. _ => Err(Error::InvalidDbResponse),
  583. }
  584. }
  585. async fn batch(&self, mut statement: InnerStatement) -> Result<(), Error> {
  586. let (sender, receiver) = oneshot::channel();
  587. statement.expected_response = ExpectedSqlResponse::Batch;
  588. self.get_queue_sender()
  589. .send(DbRequest::Sql(statement, sender))
  590. .await
  591. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  592. match receiver
  593. .await
  594. .map_err(|_| Error::Internal("Communication".to_owned()))?
  595. {
  596. DbResponse::Ok => Ok(()),
  597. DbResponse::Error(err) => Err(err.into()),
  598. _ => Err(Error::InvalidDbResponse),
  599. }
  600. }
  601. async fn fetch_all(&self, mut statement: InnerStatement) -> Result<Vec<Vec<Column>>, Error> {
  602. let (sender, receiver) = oneshot::channel();
  603. statement.expected_response = ExpectedSqlResponse::ManyRows;
  604. self.get_queue_sender()
  605. .send(DbRequest::Sql(statement, sender))
  606. .await
  607. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  608. match receiver
  609. .await
  610. .map_err(|_| Error::Internal("Communication".to_owned()))?
  611. {
  612. DbResponse::Rows(row) => Ok(row),
  613. DbResponse::Error(err) => Err(err.into()),
  614. _ => Err(Error::InvalidDbResponse),
  615. }
  616. }
  617. async fn execute(&self, mut statement: InnerStatement) -> Result<usize, Error> {
  618. let (sender, receiver) = oneshot::channel();
  619. statement.expected_response = ExpectedSqlResponse::AffectedRows;
  620. self.get_queue_sender()
  621. .send(DbRequest::Sql(statement, sender))
  622. .await
  623. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  624. match receiver
  625. .await
  626. .map_err(|_| Error::Internal("Communication".to_owned()))?
  627. {
  628. DbResponse::AffectedRows(total) => Ok(total),
  629. DbResponse::Error(err) => Err(err.into()),
  630. _ => Err(Error::InvalidDbResponse),
  631. }
  632. }
  633. async fn pluck(&self, mut statement: InnerStatement) -> Result<Option<Column>, Error> {
  634. let (sender, receiver) = oneshot::channel();
  635. statement.expected_response = ExpectedSqlResponse::Pluck;
  636. self.get_queue_sender()
  637. .send(DbRequest::Sql(statement, sender))
  638. .await
  639. .map_err(|_| Error::Internal("Communication".to_owned()))?;
  640. match receiver
  641. .await
  642. .map_err(|_| Error::Internal("Communication".to_owned()))?
  643. {
  644. DbResponse::Pluck(value) => Ok(value),
  645. DbResponse::Error(err) => Err(err.into()),
  646. _ => Err(Error::InvalidDbResponse),
  647. }
  648. }
  649. }