浏览代码

Enable parallelism with a pool of threads.

There is a main thread, which receives SQL requests and routes them to a worker
thread from a fixed-size pool.

By doing so, SQLite does synchronization, and Rust will only intervene when a
transaction is executed. Transactions are executed in the main thread.
Cesar Rodas 4 月之前
父节点
当前提交
9cf763e8f7
共有 2 个文件被更改,包括 85 次插入21 次删除
  1. 83 19
      crates/cdk-sqlite/src/mint/async_rusqlite.rs
  2. 2 2
      crates/cdk-sqlite/src/mint/auth/mod.rs

+ 83 - 19
crates/cdk-sqlite/src/mint/async_rusqlite.rs

@@ -1,7 +1,6 @@
 use std::marker::PhantomData;
-use std::sync::Arc;
-//use std::sync::atomic::AtomicUsize;
-//use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{mpsc as std_mpsc, Arc, Mutex};
 use std::thread::spawn;
 
 use rusqlite::Connection;
@@ -9,15 +8,16 @@ use tokio::sync::{mpsc, oneshot};
 
 use crate::common::SqliteConnectionManager;
 use crate::mint::Error;
-use crate::pool::Pool;
+use crate::pool::{Pool, PooledResource};
 use crate::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement, Value};
 
-const BUFFER_REQUEST_SIZE: usize = 10_000;
+const SQL_QUEUE_SIZE: usize = 10_000;
+const WORKING_THREAD_POOL_SIZE: usize = 5;
 
 #[derive(Debug, Clone)]
 pub struct AsyncRusqlite {
     sender: mpsc::Sender<DbRequest>,
-    //inflight_requests: Arc<AtomicUsize>,
+    inflight_requests: Arc<AtomicUsize>,
 }
 
 /// Internal request for the database thread
@@ -138,6 +138,52 @@ fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, E
     })
 }
 
+/// Spawns N number of threads to execute SQL statements
+///
+/// Enable parallelism with a pool of threads.
+///
+/// There is a main thread, which receives SQL requests and routes them to a worker thread from a
+/// fixed-size pool.
+///
+/// By doing so, SQLite does synchronization, and Rust will only intervene when a transaction is
+/// executed. Transactions are executed in the main thread.
+fn rusqlite_spawn_worker_threads(
+    inflight_requests: Arc<AtomicUsize>,
+    threads: usize,
+) -> std_mpsc::Sender<(
+    PooledResource<SqliteConnectionManager>,
+    InnerStatement,
+    oneshot::Sender<DbResponse>,
+)> {
+    let (sender, receiver) = std_mpsc::channel::<(
+        PooledResource<SqliteConnectionManager>,
+        InnerStatement,
+        oneshot::Sender<DbResponse>,
+    )>();
+    let receiver = Arc::new(Mutex::new(receiver));
+
+    for _ in 0..threads {
+        let rx = receiver.clone();
+        let inflight_requests = inflight_requests.clone();
+        spawn(move || loop {
+            while let Ok((conn, sql, reply_to)) = rx.lock().unwrap().recv() {
+                tracing::info!("Execute query: {}", sql.sql);
+                let result = process_query(&conn, sql);
+                let _ = match result {
+                    Ok(ok) => reply_to.send(ok),
+                    Err(err) => {
+                        tracing::error!("Failed query with error {:?}", err);
+                        reply_to.send(DbResponse::Error(err))
+                    }
+                };
+                drop(conn);
+                inflight_requests.fetch_sub(1, Ordering::Relaxed);
+            }
+        });
+    }
+    sender
+}
+
 /// # Rusqlite main worker
 ///
 /// This function takes ownership of a pool of connections to SQLite, executes SQL statements, and
@@ -148,34 +194,37 @@ fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, E
 ///
 /// This is meant to be called in their thread, as it will not exit the loop until the communication
 /// channel is closed.
-fn rusqlite_worker(
+fn rusqlite_worker_manager(
     mut receiver: mpsc::Receiver<DbRequest>,
-
     pool: Arc<Pool<SqliteConnectionManager>>,
+    inflight_requests: Arc<AtomicUsize>,
 ) {
+    let send_sql_to_thread =
+        rusqlite_spawn_worker_threads(inflight_requests.clone(), WORKING_THREAD_POOL_SIZE);
+
     while let Some(request) = receiver.blocking_recv() {
+        inflight_requests.fetch_add(1, Ordering::Relaxed);
         match request {
             DbRequest::Sql(sql, reply_to) => {
                 let conn = match pool.get() {
                     Ok(conn) => conn,
                     Err(err) => {
+                        tracing::error!("Failed to acquire a pool connection: {:?}", err);
+                        inflight_requests.fetch_sub(1, Ordering::Relaxed);
                         let _ = reply_to.send(DbResponse::Error(err.into()));
                         continue;
                     }
                 };
 
-                let result = process_query(&conn, sql);
-                let _ = match result {
-                    Ok(ok) => reply_to.send(ok),
-                    Err(err) => reply_to.send(DbResponse::Error(err)),
-                };
-                drop(conn);
+                let _ = send_sql_to_thread.send((conn, sql, reply_to));
             }
             DbRequest::Begin(reply_to) => {
-                let (sender, mut receiver) = mpsc::channel(BUFFER_REQUEST_SIZE);
+                let (sender, mut receiver) = mpsc::channel(SQL_QUEUE_SIZE);
                 let mut conn = match pool.get() {
                     Ok(conn) => conn,
                     Err(err) => {
+                        tracing::error!("Failed to acquire a pool connection: {:?}", err);
+                        inflight_requests.fetch_sub(1, Ordering::Relaxed);
                         let _ = reply_to.send(DbResponse::Error(err.into()));
                         continue;
                     }
@@ -184,6 +233,8 @@ fn rusqlite_worker(
                 let tx = match conn.transaction() {
                     Ok(tx) => tx,
                     Err(err) => {
+                        tracing::error!("Failed to begin a transaction: {:?}", err);
+                        inflight_requests.fetch_sub(1, Ordering::Relaxed);
                         let _ = reply_to.send(DbResponse::Error(err.into()));
                         continue;
                     }
@@ -192,7 +243,6 @@ fn rusqlite_worker(
                 // Transaction has begun successfully, send the `sender` back to the caller
                 // and wait for statements to execute. On `Drop` the wrapper transaction
                 // should send a `rollback`.
-                //
                 let _ = reply_to.send(DbResponse::Transaction(sender));
 
                 // We intentionally handle the transaction hijacking the main loop, there is
@@ -206,6 +256,7 @@ fn rusqlite_worker(
                         // If the receiver loop is broken (i.e no more `senders` are active) and no
                         // `Commit` statement has been sent, this will trigger a `Rollback`
                         // automatically
+                        tracing::info!("Transaction rollback on drop");
                         let _ = tx.rollback();
                         break;
                     };
@@ -246,6 +297,10 @@ fn rusqlite_worker(
                 let _ = reply_to.send(DbResponse::Unexpected);
             }
         }
+
+        // If wasn't a `continue` the transaction is done by reaching this code, and we should
+        // decrease the inflight_request counter
+        inflight_requests.fetch_sub(1, Ordering::Relaxed);
     }
 }
 
@@ -324,18 +379,27 @@ pub fn query<T: ToString>(sql: T) -> Statement {
 }
 
 impl AsyncRusqlite {
+    /// Creates a new Async Rusqlite wrapper.
     pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
-        let (sender, receiver) = mpsc::channel(BUFFER_REQUEST_SIZE);
+        let (sender, receiver) = mpsc::channel(SQL_QUEUE_SIZE);
+        let inflight_requests = Arc::new(AtomicUsize::new(0));
+        let inflight_requests_for_thread = inflight_requests.clone();
         spawn(move || {
-            rusqlite_worker(receiver, pool);
+            rusqlite_worker_manager(receiver, pool, inflight_requests_for_thread);
         });
 
         Self {
             sender,
-            //inflight_requests: Arc::new(0.into()),
+            inflight_requests,
         }
     }
 
+    /// Show how many inflight requests
+    #[allow(dead_code)]
+    pub fn inflight_requests(&self) -> usize {
+        self.inflight_requests.load(Ordering::Relaxed)
+    }
+
     /// Begins a transaction
     ///
     /// If the transaction is Drop it will trigger a rollback operation

+ 2 - 2
crates/cdk-sqlite/src/mint/auth/mod.rs

@@ -239,11 +239,11 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
     async fn add_blind_signatures(
         &self,
         blinded_messages: &[PublicKey],
-        blinded_signatures: &[BlindSignature],
+        blind_signatures: &[BlindSignature],
     ) -> Result<(), Self::Err> {
         let transaction = self.pool.begin().await?;
 
-        for (message, signature) in blinded_messages.iter().zip(blinded_signatures) {
+        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
             query(
                 r#"
                     INSERT