فهرست منبع

Custom pool

Implement a simple connection pool to manage multiple connections to SQLite.

The goal is to reduce external dependencies, even with popular crates such as
r2d2, to minimize the vendor-supply attack vector.

If other parts of the code require a generic connection pool, this module can
be promoted to be an independent crate.
Cesar Rodas 5 ماه پیش
والد
کامیت
d3cb8b0d84

+ 2 - 3
crates/cdk-sqlite/Cargo.toml

@@ -16,14 +16,13 @@ default = ["mint", "wallet", "auth"]
 mint = ["cdk-common/mint"]
 wallet = ["cdk-common/wallet"]
 auth = ["cdk-common/auth"]
-sqlcipher = ["r2d2_sqlite/bundled-sqlcipher"]
+sqlcipher = ["rusqlite/bundled-sqlcipher"]
 
 [dependencies]
 async-trait.workspace = true
 cdk-common = { workspace = true, features = ["test"] }
 bitcoin.workspace = true
-r2d2_sqlite = { version = "0.19.0", features = ["bundled"] }
-r2d2 = { version = "0.8" }
+rusqlite = { version = "0.27", features = ["bundled"]}
 thiserror.workspace = true
 tokio.workspace = true
 tracing.workspace = true

+ 67 - 26
crates/cdk-sqlite/src/common.rs

@@ -1,22 +1,37 @@
-use r2d2::{Pool, PooledConnection};
-use r2d2_sqlite::rusqlite::params;
-use r2d2_sqlite::SqliteConnectionManager;
+use std::sync::Arc;
 
-/// Create a configured rusqlite connection to a SQLite database.
-/// For SQLCipher support, enable the "sqlcipher" feature and pass a password.
-pub fn create_sqlite_pool(
-    path: &str,
-    #[cfg(feature = "sqlcipher")] password: String,
-) -> Result<Pool<SqliteConnectionManager>, r2d2::Error> {
-    let (manager, is_memory) = if path.contains(":memory:") {
-        (SqliteConnectionManager::memory(), true)
-    } else {
-        (SqliteConnectionManager::file(path), false)
-    };
+use rusqlite::{params, Connection};
+
+use crate::pool::{Pool, ResourceManager};
+
+#[derive(Debug)]
+pub(crate) struct Config {
+    path: Option<String>,
+    password: Option<String>,
+}
+
+#[derive(Debug)]
+pub(crate) struct SqliteConnectionManager;
+
+impl ResourceManager for SqliteConnectionManager {
+    type Config = Config;
+
+    type Resource = Connection;
+
+    type Error = rusqlite::Error;
 
-    let manager = manager.with_init(move |conn| {
-        #[cfg(feature = "sqlcipher")]
-        conn.execute_batch(&format!("pragma key = {};", password))?;
+    fn new_resource(
+        config: &Self::Config,
+    ) -> Result<Self::Resource, crate::pool::Error<Self::Error>> {
+        let conn = if let Some(path) = config.path.as_ref() {
+            Connection::open(path)?
+        } else {
+            Connection::open_in_memory()?
+        };
+
+        if let Some(password) = config.password.as_ref() {
+            conn.execute_batch(&format!("pragma key = {password};"))?;
+        }
 
         conn.execute_batch(
             r#"
@@ -28,19 +43,45 @@ pub fn create_sqlite_pool(
             "#,
         )?;
 
-        Ok(())
-    });
+        Ok(conn)
+    }
+}
+
+/// Create a configured rusqlite connection to a SQLite database.
+/// For SQLCipher support, enable the "sqlcipher" feature and pass a password.
+pub fn create_sqlite_pool(
+    path: &str,
+    #[cfg(feature = "sqlcipher")] password: String,
+) -> Arc<Pool<SqliteConnectionManager>> {
+    #[cfg(feature = "sqlcipher")]
+    let password = Some(password);
+
+    #[cfg(not(feature = "sqlcipher"))]
+    let password = None;
+
+    let (config, max_size) = if path.contains(":memory:") {
+        (
+            Config {
+                path: None,
+                password,
+            },
+            1,
+        )
+    } else {
+        (
+            Config {
+                path: Some(path.to_owned()),
+                password,
+            },
+            20,
+        )
+    };
 
-    r2d2::Pool::builder()
-        .max_size(if is_memory { 1 } else { 20 })
-        .build(manager)
+    Pool::new(config, max_size)
 }
 
 /// Migrates the migration generated by `build.rs`
-pub fn migrate(
-    mut conn: PooledConnection<SqliteConnectionManager>,
-    migrations: &[(&str, &str)],
-) -> Result<(), r2d2_sqlite::rusqlite::Error> {
+pub fn migrate(conn: &mut Connection, migrations: &[(&str, &str)]) -> Result<(), rusqlite::Error> {
     let tx = conn.transaction()?;
     tx.execute(
         r#"

+ 1 - 0
crates/cdk-sqlite/src/lib.rs

@@ -5,6 +5,7 @@
 
 mod common;
 mod macros;
+mod pool;
 mod stmt;
 
 #[cfg(feature = "mint")]

+ 6 - 4
crates/cdk-sqlite/src/mint/async_rusqlite.rs

@@ -1,13 +1,15 @@
 use std::marker::PhantomData;
+use std::sync::Arc;
 //use std::sync::atomic::AtomicUsize;
 //use std::sync::Arc;
 use std::thread::spawn;
 
-use r2d2_sqlite::rusqlite::Connection;
-use r2d2_sqlite::SqliteConnectionManager;
+use rusqlite::Connection;
 use tokio::sync::{mpsc, oneshot};
 
+use crate::common::SqliteConnectionManager;
 use crate::mint::Error;
+use crate::pool::Pool;
 use crate::stmt::{Column, ExpectedSqlResponse, Statement as InnerStatement, Value};
 
 const BUFFER_REQUEST_SIZE: usize = 10_000;
@@ -137,7 +139,7 @@ fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, E
 
 fn rusqlite_worker(
     mut receiver: mpsc::Receiver<DbRequest>,
-    pool: r2d2::Pool<SqliteConnectionManager>,
+    pool: Arc<Pool<SqliteConnectionManager>>,
 ) {
     while let Some(request) = receiver.blocking_recv() {
         match request {
@@ -306,7 +308,7 @@ pub fn query<T: ToString>(sql: T) -> Statement {
 }
 
 impl AsyncRusqlite {
-    pub fn new(pool: r2d2::Pool<SqliteConnectionManager>) -> Self {
+    pub fn new(pool: Arc<Pool<SqliteConnectionManager>>) -> Self {
         let (sender, receiver) = mpsc::channel(BUFFER_REQUEST_SIZE);
         spawn(move || {
             rusqlite_worker(receiver, pool);

+ 4 - 3
crates/cdk-sqlite/src/mint/auth/mod.rs

@@ -1,6 +1,7 @@
 //! SQLite Mint Auth
 
 use std::collections::HashMap;
+use std::ops::DerefMut;
 use std::path::Path;
 use std::str::FromStr;
 
@@ -34,10 +35,10 @@ impl MintSqliteAuthDatabase {
         let pool = create_sqlite_pool(
             path.as_ref().to_str().ok_or(Error::InvalidDbPath)?,
             "".to_owned(),
-        )?;
+        );
         #[cfg(not(feature = "sqlcipher"))]
-        let pool = create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?)?;
-        migrate(pool.get()?, migrations::MIGRATIONS)?;
+        let pool = create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?);
+        migrate(pool.get()?.deref_mut(), migrations::MIGRATIONS)?;
 
         Ok(Self {
             pool: AsyncRusqlite::new(pool),

+ 2 - 2
crates/cdk-sqlite/src/mint/error.rs

@@ -7,11 +7,11 @@ use thiserror::Error;
 pub enum Error {
     /// SQLX Error
     #[error(transparent)]
-    Sqlite(#[from] r2d2_sqlite::rusqlite::Error),
+    Sqlite(#[from] rusqlite::Error),
 
     /// Pool error
     #[error(transparent)]
-    Pool(#[from] r2d2::Error),
+    Pool(#[from] crate::pool::Error<rusqlite::Error>),
     /// Invalid UUID
     #[error("Invalid UUID: {0}")]
     InvalidUuid(String),

+ 5 - 4
crates/cdk-sqlite/src/mint/mod.rs

@@ -1,6 +1,7 @@
 //! SQLite Mint
 
 use std::collections::HashMap;
+use std::ops::DerefMut;
 use std::path::Path;
 use std::str::FromStr;
 
@@ -55,8 +56,8 @@ impl MintSqliteDatabase {
     /// Create new [`MintSqliteDatabase`]
     #[cfg(not(feature = "sqlcipher"))]
     pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
-        let pool = create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?)?;
-        migrate(pool.get()?, migrations::MIGRATIONS)?;
+        let pool = create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?);
+        migrate(pool.get()?.deref_mut(), migrations::MIGRATIONS)?;
 
         Ok(Self {
             pool: async_rusqlite::AsyncRusqlite::new(pool),
@@ -69,8 +70,8 @@ impl MintSqliteDatabase {
         let pool = create_sqlite_pool(
             path.as_ref().to_str().ok_or(Error::InvalidDbPath)?,
             password,
-        )?;
-        migrate(pool.get()?, migrations::MIGRATIONS)?;
+        );
+        migrate(pool.get()?.deref_mut(), migrations::MIGRATIONS)?;
 
         Ok(Self {
             pool: async_rusqlite::AsyncRusqlite::new(pool),

+ 163 - 0
crates/cdk-sqlite/src/pool.rs

@@ -0,0 +1,163 @@
+//! Very simple connection pool, to avoid an external dependency on r2d2 and other crates. If this
+//! endup work it can be re-used in other parts of the project and may be promoted to its own
+//! generic crate
+
+use std::fmt::Debug;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Condvar, Mutex};
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error<E> {
+    /// Mutex Poison Error
+    #[error("Internal: PoisonError")]
+    PoisonError,
+
+    /// Internal database error
+    #[error(transparent)]
+    ResourceError(#[from] E),
+}
+
+/// Trait to manage resources
+pub trait ResourceManager: Debug {
+    type Resource: Debug;
+
+    type Config: Debug;
+
+    type Error: Debug;
+
+    /// Creates a new resource with a given config
+    fn new_resource(config: &Self::Config) -> Result<Self::Resource, Error<Self::Error>>;
+
+    /// The object is dropped
+    fn drop(_resource: Self::Resource) {}
+}
+
+/// Generic connection pool of resources R
+#[derive(Debug)]
+pub struct Pool<RM>
+where
+    RM: ResourceManager,
+{
+    config: RM::Config,
+    queue: Mutex<Vec<RM::Resource>>,
+    in_use: AtomicUsize,
+    max_size: usize,
+    waiter: Condvar,
+}
+
+pub struct WrappedResource<RM>
+where
+    RM: ResourceManager,
+{
+    resource: Option<RM::Resource>,
+    pool: Arc<Pool<RM>>,
+}
+
+impl<RM> Drop for WrappedResource<RM>
+where
+    RM: ResourceManager,
+{
+    fn drop(&mut self) {
+        if let Some(resource) = self.resource.take() {
+            let mut active_resource = self.pool.queue.lock().expect("active_resource");
+            active_resource.push(resource);
+            self.pool.in_use.fetch_sub(1, Ordering::AcqRel);
+
+            // Notify a waiting thread
+            self.pool.waiter.notify_one();
+        }
+    }
+}
+
+impl<RM> Deref for WrappedResource<RM>
+where
+    RM: ResourceManager,
+{
+    type Target = RM::Resource;
+
+    fn deref(&self) -> &Self::Target {
+        self.resource.as_ref().expect("resource already dropped")
+    }
+}
+
+impl<RM> DerefMut for WrappedResource<RM>
+where
+    RM: ResourceManager,
+{
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.resource.as_mut().expect("resource already dropped")
+    }
+}
+
+impl<RM> Pool<RM>
+where
+    RM: ResourceManager,
+{
+    /// Creates a new pool
+    pub fn new(config: RM::Config, max_size: usize) -> Arc<Self> {
+        Arc::new(Self {
+            config,
+            queue: Default::default(),
+            in_use: Default::default(),
+            waiter: Default::default(),
+            max_size,
+        })
+    }
+
+    pub fn get(self: &Arc<Self>) -> Result<WrappedResource<RM>, Error<RM::Error>> {
+        let mut resources = self.queue.lock().map_err(|_| Error::PoisonError)?;
+
+        loop {
+            if let Some(resource) = resources.pop() {
+                drop(resources);
+                self.in_use.fetch_add(1, Ordering::AcqRel);
+
+                return Ok(WrappedResource {
+                    resource: Some(resource),
+                    pool: self.clone(),
+                });
+            }
+
+            if self.in_use.load(Ordering::Relaxed) < self.max_size {
+                drop(resources);
+                self.in_use.fetch_add(1, Ordering::AcqRel);
+
+                return Ok(WrappedResource {
+                    resource: Some(RM::new_resource(&self.config)?),
+                    pool: self.clone(),
+                });
+            }
+
+            resources = self
+                .waiter
+                .wait(resources)
+                .map_err(|_| Error::PoisonError)?;
+        }
+    }
+}
+
+impl<RM> Drop for Pool<RM>
+where
+    RM: ResourceManager,
+{
+    fn drop(&mut self) {
+        if let Ok(mut resources) = self.queue.lock() {
+            loop {
+                while let Some(resource) = resources.pop() {
+                    RM::drop(resource);
+                }
+
+                if self.in_use.load(Ordering::Relaxed) == 0 {
+                    break;
+                }
+
+                resources = if let Ok(resources) = self.waiter.wait(resources) {
+                    resources
+                } else {
+                    break;
+                };
+            }
+        }
+    }
+}

+ 11 - 10
crates/cdk-sqlite/src/stmt.rs

@@ -1,11 +1,12 @@
-use r2d2::PooledConnection;
-use r2d2_sqlite::rusqlite::{self, CachedStatement};
-use r2d2_sqlite::SqliteConnectionManager;
+use rusqlite::{self, CachedStatement};
 
-pub type Value = r2d2_sqlite::rusqlite::types::Value;
+use crate::common::SqliteConnectionManager;
+use crate::pool::WrappedResource;
+
+pub type Value = rusqlite::types::Value;
 
 /// The Column type
-pub type Column = r2d2_sqlite::rusqlite::types::Value;
+pub type Column = rusqlite::types::Value;
 
 /// Expected Sql response
 #[derive(Debug, Clone, Copy, Default)]
@@ -86,7 +87,7 @@ impl Statement {
 
     fn get_stmt(
         self,
-        conn: &PooledConnection<SqliteConnectionManager>,
+        conn: &WrappedResource<SqliteConnectionManager>,
     ) -> rusqlite::Result<CachedStatement<'_>> {
         let mut stmt = conn.prepare_cached(&self.sql)?;
         for (name, value) in self.args {
@@ -104,7 +105,7 @@ impl Statement {
     /// Executes a query and returns the affected rows
     pub fn plunk(
         self,
-        conn: &PooledConnection<SqliteConnectionManager>,
+        conn: &WrappedResource<SqliteConnectionManager>,
     ) -> rusqlite::Result<Option<Value>> {
         let mut stmt = self.get_stmt(conn)?;
         let mut rows = stmt.raw_query();
@@ -114,7 +115,7 @@ impl Statement {
     /// Executes a query and returns the affected rows
     pub fn execute(
         self,
-        conn: &PooledConnection<SqliteConnectionManager>,
+        conn: &WrappedResource<SqliteConnectionManager>,
     ) -> rusqlite::Result<usize> {
         self.get_stmt(conn)?.raw_execute()
     }
@@ -122,7 +123,7 @@ impl Statement {
     /// Runs the query and returns the first row or None
     pub fn fetch_one(
         self,
-        conn: &PooledConnection<SqliteConnectionManager>,
+        conn: &WrappedResource<SqliteConnectionManager>,
     ) -> rusqlite::Result<Option<Vec<Column>>> {
         let mut stmt = self.get_stmt(conn)?;
         let columns = stmt.column_count();
@@ -139,7 +140,7 @@ impl Statement {
     /// Runs the query and returns the first row or None
     pub fn fetch_all(
         self,
-        conn: &PooledConnection<SqliteConnectionManager>,
+        conn: &WrappedResource<SqliteConnectionManager>,
     ) -> rusqlite::Result<Vec<Vec<Column>>> {
         let mut stmt = self.get_stmt(conn)?;
         let columns = stmt.column_count();

+ 2 - 2
crates/cdk-sqlite/src/wallet/error.rs

@@ -7,10 +7,10 @@ use thiserror::Error;
 pub enum Error {
     /// SQLX Error
     #[error(transparent)]
-    Sqlite(#[from] r2d2_sqlite::rusqlite::Error),
+    Sqlite(#[from] rusqlite::Error),
     /// Pool error
     #[error(transparent)]
-    Pool(#[from] r2d2::Error),
+    Pool(#[from] crate::pool::Error<rusqlite::Error>),
 
     /// Missing columns
     #[error("Not enough elements: expected {0}, got {1}")]

+ 8 - 7
crates/cdk-sqlite/src/wallet/mod.rs

@@ -1,8 +1,10 @@
 //! SQLite Wallet Database
 
 use std::collections::HashMap;
+use std::ops::DerefMut;
 use std::path::Path;
 use std::str::FromStr;
+use std::sync::Arc;
 
 use async_trait::async_trait;
 use cdk_common::common::ProofInfo;
@@ -16,11 +18,10 @@ use cdk_common::{
     SecretKey, SpendingConditions, State,
 };
 use error::Error;
-use r2d2::Pool;
-use r2d2_sqlite::SqliteConnectionManager;
 use tracing::instrument;
 
-use crate::common::{create_sqlite_pool, migrate};
+use crate::common::{create_sqlite_pool, migrate, SqliteConnectionManager};
+use crate::pool::Pool;
 use crate::stmt::{Column, Statement};
 use crate::{
     column_as_binary, column_as_nullable_binary, column_as_nullable_number,
@@ -36,7 +37,7 @@ mod migrations;
 /// Wallet SQLite Database
 #[derive(Debug, Clone)]
 pub struct WalletSqliteDatabase {
-    pool: Pool<SqliteConnectionManager>,
+    pool: Arc<Pool<SqliteConnectionManager>>,
 }
 
 impl WalletSqliteDatabase {
@@ -44,7 +45,7 @@ impl WalletSqliteDatabase {
     #[cfg(not(feature = "sqlcipher"))]
     pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
         let db = Self {
-            pool: create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?)?,
+            pool: create_sqlite_pool(path.as_ref().to_str().ok_or(Error::InvalidDbPath)?),
         };
         db.migrate()?;
         Ok(db)
@@ -57,7 +58,7 @@ impl WalletSqliteDatabase {
             pool: create_sqlite_pool(
                 path.as_ref().to_str().ok_or(Error::InvalidDbPath)?,
                 password,
-            )?,
+            ),
         };
         db.migrate()?;
         Ok(db)
@@ -65,7 +66,7 @@ impl WalletSqliteDatabase {
 
     /// Migrate [`WalletSqliteDatabase`]
     fn migrate(&self) -> Result<(), Error> {
-        migrate(self.pool.get()?, migrations::MIGRATIONS)?;
+        migrate(self.pool.get()?.deref_mut(), migrations::MIGRATIONS)?;
         Ok(())
     }
 }