| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498 |
- use std::marker::PhantomData;
- use std::sync::atomic::{AtomicUsize, Ordering};
- use std::sync::{mpsc as std_mpsc, Arc, Mutex};
- use std::thread::spawn;
- use rusqlite::Connection;
- use tokio::sync::{mpsc, oneshot};
- use crate::common::SqliteConnectionManager;
- use crate::mint::Error;
- use crate::pool::{Pool, PooledResource};
- use crate::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement, Value};
/// Capacity of every bounded request queue: the main queue created by `AsyncRusqlite::new`
/// and the per-transaction queue created when a `Begin` request is served.
const SQL_QUEUE_SIZE: usize = 10_000;

/// Number of background threads that execute non-transactional statements in parallel.
const WORKING_THREAD_POOL_SIZE: usize = 5;
- #[derive(Debug, Clone)]
- pub struct AsyncRusqlite {
- sender: mpsc::Sender<DbRequest>,
- inflight_requests: Arc<AtomicUsize>,
- }
- /// Internal request for the database thread
- #[derive(Debug)]
- pub enum DbRequest {
- Sql(InnerStatement, oneshot::Sender<DbResponse>),
- Begin(oneshot::Sender<DbResponse>),
- Commit(oneshot::Sender<DbResponse>),
- Rollback(oneshot::Sender<DbResponse>),
- }
- #[derive(Debug)]
- pub enum DbResponse {
- Transaction(mpsc::Sender<DbRequest>),
- AffectedRows(usize),
- Pluck(Option<Column>),
- Row(Option<Vec<Column>>),
- Rows(Vec<Vec<Column>>),
- Error(Error),
- Unexpected,
- Ok,
- }
- /// Statement for the async_rusqlite wrapper
- pub struct Statement(InnerStatement);
- impl Statement {
- /// Bind a variable
- pub fn bind<C, V>(self, name: C, value: V) -> Self
- where
- C: ToString,
- V: Into<Value>,
- {
- Self(self.0.bind(name, value))
- }
- /// Bind vec
- pub fn bind_vec<C, V>(self, name: C, value: Vec<V>) -> Self
- where
- C: ToString,
- V: Into<Value>,
- {
- Self(self.0.bind_vec(name, value))
- }
- /// Executes a query and return the number of affected rows
- pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
- where
- C: DatabaseExecutor + Send + Sync,
- {
- conn.execute(self.0).await
- }
- /// Returns the first column of the first row of the query result
- pub async fn pluck<C>(self, conn: &C) -> Result<Option<Column>, Error>
- where
- C: DatabaseExecutor + Send + Sync,
- {
- conn.pluck(self.0).await
- }
- /// Returns the first row of the query result
- pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
- where
- C: DatabaseExecutor + Send + Sync,
- {
- conn.fetch_one(self.0).await
- }
- /// Returns all rows of the query result
- pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
- where
- C: DatabaseExecutor + Send + Sync,
- {
- conn.fetch_all(self.0).await
- }
- }
- /// Process a query
- #[inline(always)]
- fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, Error> {
- let mut stmt = conn.prepare_cached(&sql.sql)?;
- for (name, value) in sql.args {
- let index = stmt
- .parameter_index(&name)
- .map_err(|_| Error::MissingParameter(name.clone()))?
- .ok_or(Error::MissingParameter(name))?;
- stmt.raw_bind_parameter(index, value)?;
- }
- let columns = stmt.column_count();
- Ok(match sql.expected_response {
- ExpectedSqlResponse::AffectedRows => DbResponse::AffectedRows(stmt.raw_execute()?),
- ExpectedSqlResponse::ManyRows => {
- let mut rows = stmt.raw_query();
- let mut results = vec![];
- while let Some(row) = rows.next()? {
- results.push(
- (0..columns)
- .map(|i| row.get(i))
- .collect::<Result<Vec<_>, _>>()?,
- )
- }
- DbResponse::Rows(results)
- }
- ExpectedSqlResponse::Pluck => {
- let mut rows = stmt.raw_query();
- DbResponse::Pluck(rows.next()?.map(|row| row.get(0usize)).transpose()?)
- }
- ExpectedSqlResponse::SingleRow => {
- let mut rows = stmt.raw_query();
- let row = rows
- .next()?
- .map(|row| {
- (0..columns)
- .map(|i| row.get(i))
- .collect::<Result<Vec<_>, _>>()
- })
- .transpose()?;
- DbResponse::Row(row)
- }
- })
- }
- /// Spawns N number of threads to execute SQL statements
- ///
- /// Enable parallelism with a pool of threads.
- ///
- /// There is a main thread, which receives SQL requests and routes them to a worker thread from a
- /// fixed-size pool.
- ///
- /// By doing so, SQLite does synchronization, and Rust will only intervene when a transaction is
- /// executed. Transactions are executed in the main thread.
- fn rusqlite_spawn_worker_threads(
- inflight_requests: Arc<AtomicUsize>,
- threads: usize,
- ) -> std_mpsc::Sender<(
- PooledResource<SqliteConnectionManager>,
- InnerStatement,
- oneshot::Sender<DbResponse>,
- )> {
- let (sender, receiver) = std_mpsc::channel::<(
- PooledResource<SqliteConnectionManager>,
- InnerStatement,
- oneshot::Sender<DbResponse>,
- )>();
- let receiver = Arc::new(Mutex::new(receiver));
- for _ in 0..threads {
- let rx = receiver.clone();
- let inflight_requests = inflight_requests.clone();
- spawn(move || loop {
- while let Ok((conn, sql, reply_to)) = rx.lock().unwrap().recv() {
- tracing::info!("Execute query: {}", sql.sql);
- let result = process_query(&conn, sql);
- let _ = match result {
- Ok(ok) => reply_to.send(ok),
- Err(err) => {
- tracing::error!("Failed query with error {:?}", err);
- reply_to.send(DbResponse::Error(err))
- }
- };
- drop(conn);
- inflight_requests.fetch_sub(1, Ordering::Relaxed);
- }
- });
- }
- sender
- }
- /// # Rusqlite main worker
- ///
- /// This function takes ownership of a pool of connections to SQLite, executes SQL statements, and
- /// returns the results or number of affected rows to the caller. All communications are done
- /// through channels. This function is synchronous, but a thread pool exists to execute queries, and
- /// SQLite will coordinate data access. Transactions are executed in the main and it takes ownership
- /// of the main thread until it is finalized
- ///
- /// This is meant to be called in their thread, as it will not exit the loop until the communication
- /// channel is closed.
- fn rusqlite_worker_manager(
- mut receiver: mpsc::Receiver<DbRequest>,
- pool: Arc<Pool<SqliteConnectionManager>>,
- inflight_requests: Arc<AtomicUsize>,
- ) {
- let send_sql_to_thread =
- rusqlite_spawn_worker_threads(inflight_requests.clone(), WORKING_THREAD_POOL_SIZE);
- let mut tx_id: usize = 0;
- while let Some(request) = receiver.blocking_recv() {
- inflight_requests.fetch_add(1, Ordering::Relaxed);
- match request {
- DbRequest::Sql(sql, reply_to) => {
- let conn = match pool.get() {
- Ok(conn) => conn,
- Err(err) => {
- tracing::error!("Failed to acquire a pool connection: {:?}", err);
- inflight_requests.fetch_sub(1, Ordering::Relaxed);
- let _ = reply_to.send(DbResponse::Error(err.into()));
- continue;
- }
- };
- let _ = send_sql_to_thread.send((conn, sql, reply_to));
- continue;
- }
- DbRequest::Begin(reply_to) => {
- let (sender, mut receiver) = mpsc::channel(SQL_QUEUE_SIZE);
- let mut conn = match pool.get() {
- Ok(conn) => conn,
- Err(err) => {
- tracing::error!("Failed to acquire a pool connection: {:?}", err);
- inflight_requests.fetch_sub(1, Ordering::Relaxed);
- let _ = reply_to.send(DbResponse::Error(err.into()));
- continue;
- }
- };
- let tx = match conn.transaction() {
- Ok(tx) => tx,
- Err(err) => {
- tracing::error!("Failed to begin a transaction: {:?}", err);
- inflight_requests.fetch_sub(1, Ordering::Relaxed);
- let _ = reply_to.send(DbResponse::Error(err.into()));
- continue;
- }
- };
- // Transaction has begun successfully, send the `sender` back to the caller
- // and wait for statements to execute. On `Drop` the wrapper transaction
- // should send a `rollback`.
- let _ = reply_to.send(DbResponse::Transaction(sender));
- tx_id += 1;
- // We intentionally handle the transaction hijacking the main loop, there is
- // no point is queueing more operations for SQLite, since transaction have
- // exclusive access. In other database implementation this block of code
- // should be sent to their own thread to allow concurrency
- loop {
- let request = if let Some(request) = receiver.blocking_recv() {
- request
- } else {
- // If the receiver loop is broken (i.e no more `senders` are active) and no
- // `Commit` statement has been sent, this will trigger a `Rollback`
- // automatically
- tracing::info!("Tx {}: Transaction rollback on drop", tx_id);
- let _ = tx.rollback();
- break;
- };
- match request {
- DbRequest::Commit(reply_to) => {
- tracing::info!("Tx {}: Commit", tx_id);
- let _ = reply_to.send(match tx.commit() {
- Ok(()) => DbResponse::Ok,
- Err(err) => DbResponse::Error(err.into()),
- });
- break;
- }
- DbRequest::Rollback(reply_to) => {
- tracing::info!("Tx {}: Rollback", tx_id);
- let _ = reply_to.send(match tx.rollback() {
- Ok(()) => DbResponse::Ok,
- Err(err) => DbResponse::Error(err.into()),
- });
- break;
- }
- DbRequest::Begin(reply_to) => {
- let _ = reply_to.send(DbResponse::Unexpected);
- }
- DbRequest::Sql(sql, reply_to) => {
- tracing::info!("Tx {}: SQL {}", tx_id, sql.sql);
- let _ = match process_query(&tx, sql) {
- Ok(ok) => reply_to.send(ok),
- Err(err) => reply_to.send(DbResponse::Error(err)),
- };
- }
- }
- }
- drop(conn);
- }
- DbRequest::Commit(reply_to) => {
- let _ = reply_to.send(DbResponse::Unexpected);
- }
- DbRequest::Rollback(reply_to) => {
- let _ = reply_to.send(DbResponse::Unexpected);
- }
- }
- // If wasn't a `continue` the transaction is done by reaching this code, and we should
- // decrease the inflight_request counter
- inflight_requests.fetch_sub(1, Ordering::Relaxed);
- }
- }
- #[async_trait::async_trait]
- pub trait DatabaseExecutor {
- /// Returns the connection to the database thread (or the on-going transaction)
- fn get_queue_sender(&self) -> mpsc::Sender<DbRequest>;
- /// Executes a query and returns the affected rows
- async fn execute(&self, mut statement: InnerStatement) -> Result<usize, Error> {
- let (sender, receiver) = oneshot::channel();
- statement.expected_response = ExpectedSqlResponse::AffectedRows;
- self.get_queue_sender()
- .send(DbRequest::Sql(statement, sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::AffectedRows(n) => Ok(n),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- /// Runs the query and returns the first row or None
- async fn fetch_one(&self, mut statement: InnerStatement) -> Result<Option<Vec<Column>>, Error> {
- let (sender, receiver) = oneshot::channel();
- statement.expected_response = ExpectedSqlResponse::SingleRow;
- self.get_queue_sender()
- .send(DbRequest::Sql(statement, sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::Row(row) => Ok(row),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- /// Runs the query and returns the first row or None
- async fn fetch_all(&self, mut statement: InnerStatement) -> Result<Vec<Vec<Column>>, Error> {
- let (sender, receiver) = oneshot::channel();
- statement.expected_response = ExpectedSqlResponse::ManyRows;
- self.get_queue_sender()
- .send(DbRequest::Sql(statement, sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::Rows(rows) => Ok(rows),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- async fn pluck(&self, mut statement: InnerStatement) -> Result<Option<Column>, Error> {
- let (sender, receiver) = oneshot::channel();
- statement.expected_response = ExpectedSqlResponse::Pluck;
- self.get_queue_sender()
- .send(DbRequest::Sql(statement, sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::Pluck(value) => Ok(value),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- }
- #[inline(always)]
- pub fn query<T>(sql: T) -> Statement
- where
- T: ToString,
- {
- Statement(crate::stmt::Statement::new(sql))
- }
- impl AsyncRusqlite {
- /// Creates a new Async Rusqlite wrapper.
- pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
- let (sender, receiver) = mpsc::channel(SQL_QUEUE_SIZE);
- let inflight_requests = Arc::new(AtomicUsize::new(0));
- let inflight_requests_for_thread = inflight_requests.clone();
- spawn(move || {
- rusqlite_worker_manager(receiver, pool, inflight_requests_for_thread);
- });
- Self {
- sender,
- inflight_requests,
- }
- }
- /// Show how many inflight requests
- #[allow(dead_code)]
- pub fn inflight_requests(&self) -> usize {
- self.inflight_requests.load(Ordering::Relaxed)
- }
- /// Begins a transaction
- ///
- /// If the transaction is Drop it will trigger a rollback operation
- pub async fn begin(&self) -> Result<Transaction<'_>, Error> {
- let (sender, receiver) = oneshot::channel();
- self.sender
- .send(DbRequest::Begin(sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::Transaction(db_sender) => Ok(Transaction {
- db_sender,
- _marker: PhantomData,
- }),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- }
- impl DatabaseExecutor for AsyncRusqlite {
- #[inline(always)]
- fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
- self.sender.clone()
- }
- }
- pub struct Transaction<'conn> {
- db_sender: mpsc::Sender<DbRequest>,
- _marker: PhantomData<&'conn ()>,
- }
- impl Drop for Transaction<'_> {
- fn drop(&mut self) {
- let (sender, _) = oneshot::channel();
- let _ = self.db_sender.try_send(DbRequest::Rollback(sender));
- }
- }
- impl Transaction<'_> {
- pub async fn commit(self) -> Result<(), Error> {
- let (sender, receiver) = oneshot::channel();
- self.db_sender
- .send(DbRequest::Commit(sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::Ok => Ok(()),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- pub async fn rollback(self) -> Result<(), Error> {
- let (sender, receiver) = oneshot::channel();
- self.db_sender
- .send(DbRequest::Rollback(sender))
- .await
- .map_err(|_| Error::Communication)?;
- match receiver.await.map_err(|_| Error::Communication)? {
- DbResponse::Ok => Ok(()),
- DbResponse::Error(err) => Err(err),
- _ => Err(Error::InvalidDbResponse),
- }
- }
- }
- impl DatabaseExecutor for Transaction<'_> {
- /// Get the internal sender to the SQL queue
- #[inline(always)]
- fn get_queue_sender(&self) -> mpsc::Sender<DbRequest> {
- self.db_sender.clone()
- }
- }
|