Преглед на файлове

Improve the sql database traits

Cesar Rodas преди 6 месеца
родител
ревизия
6ba3899ec0
променени са 3 файла, в които са добавени 81 реда и са изтрити 44 реда
  1. 3 1
      crates/cdk-sql-common/src/database.rs
  2. 71 39
      crates/cdk-sql-common/src/mint/mod.rs
  3. 7 4
      crates/cdk-sql-common/src/pool.rs

+ 3 - 1
crates/cdk-sql-common/src/database.rs

@@ -49,5 +49,7 @@ pub trait DatabaseConnector: Debug + DatabaseExecutor + Send + Sync {
         Self: 'a;
 
     /// Begin a new transaction
-    async fn begin(&self) -> Result<Self::Transaction<'_>, Error>;
+    async fn begin(&mut self) -> Result<Self::Transaction<'_>, Error>
+    where
+        Self: Sized;
 }

+ 71 - 39
crates/cdk-sql-common/src/mint/mod.rs

@@ -11,6 +11,7 @@
 use std::collections::HashMap;
 use std::marker::PhantomData;
 use std::str::FromStr;
+use std::sync::Arc;
 
 use async_trait::async_trait;
 use bitcoin::bip32::DerivationPath;
@@ -39,6 +40,7 @@ use uuid::Uuid;
 
 use crate::common::migrate;
 use crate::database::{DatabaseConnector, DatabaseExecutor, DatabaseTransaction};
+use crate::pool::{Pool, ResourceManager};
 use crate::stmt::{query, Column};
 use crate::{
     column_as_nullable_number, column_as_nullable_string, column_as_number, column_as_string,
@@ -57,11 +59,12 @@ pub use auth::SQLMintAuthDatabase;
 
 /// Mint SQL Database
 #[derive(Debug, Clone)]
-pub struct SQLMintDatabase<DB>
+pub struct SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
-    db: DB,
+    db: Arc<Pool<T>>,
 }
 
 /// SQL Transaction Writer
@@ -115,9 +118,10 @@ where
     Ok(())
 }
 
-impl<DB> SQLMintDatabase<DB>
+impl<T, DB> SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
     /// Creates a new instance
     pub async fn new<X>(db: X) -> Result<Self, Error>
@@ -130,7 +134,7 @@ where
     }
 
     /// Migrate
-    async fn migrate(conn: &DB) -> Result<(), Error> {
+    async fn migrate(mut conn: DB) -> Result<(), Error> {
         let tx = conn.begin().await?;
         migrate(&tx, DB::name(), MIGRATIONS).await?;
         tx.commit().await?;
@@ -142,9 +146,10 @@ where
     where
         R: serde::de::DeserializeOwned,
     {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         let value = column_as_string!(query(r#"SELECT value FROM config WHERE id = :id LIMIT 1"#)?
             .bind("id", id.to_owned())
-            .pluck(&self.db)
+            .pluck(&*conn)
             .await?
             .ok_or(Error::UnknownQuoteTTL)?);
 
@@ -288,11 +293,13 @@ where
     type Err = Error;
 
     async fn commit(self: Box<Self>) -> Result<(), Error> {
-        Ok(self.inner.commit().await?)
+        self.inner.commit().await?;
+        Ok(())
     }
 
     async fn rollback(self: Box<Self>) -> Result<(), Error> {
-        Ok(self.inner.rollback().await?)
+        self.inner.rollback().await?;
+        Ok(())
     }
 }
 
@@ -416,26 +423,29 @@ where
 }
 
 #[async_trait]
-impl<DB> MintKeysDatabase for SQLMintDatabase<DB>
+impl<T, DB> MintKeysDatabase for SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
     type Err = Error;
 
     async fn begin_transaction<'a>(
         &'a self,
     ) -> Result<Box<dyn MintKeyDatabaseTransaction<'a, Error> + Send + Sync + 'a>, Error> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(Box::new(SQLTransaction {
-            inner: self.db.begin().await?,
+            inner: conn.begin().await?,
             _phantom: PhantomData,
         }))
     }
 
     async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(
             query(r#" SELECT id FROM keyset WHERE active = 1 AND unit IS :unit"#)?
                 .bind("unit", unit.to_string())
-                .pluck(&self.db)
+                .pluck(&*conn)
                 .await?
                 .map(|id| match id {
                     Column::Text(text) => Ok(Id::from_str(&text)?),
@@ -447,8 +457,9 @@ where
     }
 
     async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(r#"SELECT id, unit FROM keyset WHERE active = 1"#)?
-            .fetch_all(&self.db)
+            .fetch_all(&*conn)
             .await?
             .into_iter()
             .map(|row| {
@@ -461,6 +472,7 @@ where
     }
 
     async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"SELECT
                 id,
@@ -477,13 +489,14 @@ where
                 WHERE id=:id"#,
         )?
         .bind("id", id.to_string())
-        .fetch_one(&self.db)
+        .fetch_one(&*conn)
         .await?
         .map(sql_row_to_keyset_info)
         .transpose()?)
     }
 
     async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"SELECT
                 id,
@@ -499,7 +512,7 @@ where
                 keyset
             "#,
         )?
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(sql_row_to_keyset_info)
@@ -1028,15 +1041,18 @@ VALUES (:quote_id, :amount, :timestamp);
 }
 
 #[async_trait]
-impl<DB> MintQuotesDatabase for SQLMintDatabase<DB>
+impl<T, DB> MintQuotesDatabase for SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
     type Err = Error;
 
     async fn get_mint_quote(&self, quote_id: &Uuid) -> Result<Option<MintQuote>, Self::Err> {
-        let payments = get_mint_quote_payments(&self.db, quote_id).await?;
-        let issuance = get_mint_quote_issuance(&self.db, quote_id).await?;
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
+
+        let payments = get_mint_quote_payments(&*conn, quote_id).await?;
+        let issuance = get_mint_quote_issuance(&*conn, quote_id).await?;
 
         Ok(query(
             r#"
@@ -1058,7 +1074,7 @@ where
             WHERE id = :id"#,
         )?
         .bind("id", quote_id.as_hyphenated().to_string())
-        .fetch_one(&self.db)
+        .fetch_one(&*conn)
         .await?
         .map(|row| sql_row_to_mint_quote(row, payments, issuance))
         .transpose()?)
@@ -1068,6 +1084,7 @@ where
         &self,
         request: &str,
     ) -> Result<Option<MintQuote>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         let mut mint_quote = query(
             r#"
             SELECT
@@ -1088,14 +1105,14 @@ where
             WHERE request = :request"#,
         )?
         .bind("request", request.to_owned())
-        .fetch_one(&self.db)
+        .fetch_one(&*conn)
         .await?
         .map(|row| sql_row_to_mint_quote(row, vec![], vec![]))
         .transpose()?;
 
         if let Some(quote) = mint_quote.as_mut() {
-            let payments = get_mint_quote_payments(&self.db, &quote.id).await?;
-            let issuance = get_mint_quote_issuance(&self.db, &quote.id).await?;
+            let payments = get_mint_quote_payments(&*conn, &quote.id).await?;
+            let issuance = get_mint_quote_issuance(&*conn, &quote.id).await?;
             quote.issuance = issuance;
             quote.payments = payments;
         }
@@ -1107,6 +1124,7 @@ where
         &self,
         request_lookup_id: &PaymentIdentifier,
     ) -> Result<Option<MintQuote>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         let mut mint_quote = query(
             r#"
             SELECT
@@ -1130,15 +1148,15 @@ where
         )?
         .bind("request_lookup_id", request_lookup_id.to_string())
         .bind("request_lookup_id_kind", request_lookup_id.kind())
-        .fetch_one(&self.db)
+        .fetch_one(&*conn)
         .await?
         .map(|row| sql_row_to_mint_quote(row, vec![], vec![]))
         .transpose()?;
 
         // TODO: these should use an sql join so they can be done in one query
         if let Some(quote) = mint_quote.as_mut() {
-            let payments = get_mint_quote_payments(&self.db, &quote.id).await?;
-            let issuance = get_mint_quote_issuance(&self.db, &quote.id).await?;
+            let payments = get_mint_quote_payments(&*conn, &quote.id).await?;
+            let issuance = get_mint_quote_issuance(&*conn, &quote.id).await?;
             quote.issuance = issuance;
             quote.payments = payments;
         }
@@ -1147,6 +1165,7 @@ where
     }
 
     async fn get_mint_quotes(&self) -> Result<Vec<MintQuote>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         let mut mint_quotes = query(
             r#"
             SELECT
@@ -1166,15 +1185,15 @@ where
                 mint_quote
             "#,
         )?
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(|row| sql_row_to_mint_quote(row, vec![], vec![]))
         .collect::<Result<Vec<_>, _>>()?;
 
         for quote in mint_quotes.as_mut_slice() {
-            let payments = get_mint_quote_payments(&self.db, &quote.id).await?;
-            let issuance = get_mint_quote_issuance(&self.db, &quote.id).await?;
+            let payments = get_mint_quote_payments(&*conn, &quote.id).await?;
+            let issuance = get_mint_quote_issuance(&*conn, &quote.id).await?;
             quote.issuance = issuance;
             quote.payments = payments;
         }
@@ -1183,6 +1202,7 @@ where
     }
 
     async fn get_melt_quote(&self, quote_id: &Uuid) -> Result<Option<mint::MeltQuote>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"
             SELECT
@@ -1207,13 +1227,14 @@ where
             "#,
         )?
         .bind("id", quote_id.as_hyphenated().to_string())
-        .fetch_one(&self.db)
+        .fetch_one(&*conn)
         .await?
         .map(sql_row_to_melt_quote)
         .transpose()?)
     }
 
     async fn get_melt_quotes(&self) -> Result<Vec<mint::MeltQuote>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"
             SELECT
@@ -1235,7 +1256,7 @@ where
                 melt_quote
             "#,
         )?
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(sql_row_to_melt_quote)
@@ -1244,13 +1265,15 @@ where
 }
 
 #[async_trait]
-impl<DB> MintProofsDatabase for SQLMintDatabase<DB>
+impl<T, DB> MintProofsDatabase for SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
     type Err = Error;
 
     async fn get_proofs_by_ys(&self, ys: &[PublicKey]) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         let mut proofs = query(
             r#"
             SELECT
@@ -1267,7 +1290,7 @@ where
             "#,
         )?
         .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(|mut row| {
@@ -1286,6 +1309,7 @@ where
     }
 
     async fn get_proof_ys_by_quote_id(&self, quote_id: &Uuid) -> Result<Vec<PublicKey>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"
             SELECT
@@ -1301,7 +1325,7 @@ where
             "#,
         )?
         .bind("quote_id", quote_id.as_hyphenated().to_string())
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(sql_row_to_proof)
@@ -1310,7 +1334,8 @@ where
     }
 
     async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
-        let mut current_states = get_current_states(&self.db, ys).await?;
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
+        let mut current_states = get_current_states(&*conn, ys).await?;
 
         Ok(ys.iter().map(|y| current_states.remove(y)).collect())
     }
@@ -1319,6 +1344,7 @@ where
         &self,
         keyset_id: &Id,
     ) -> Result<(Proofs, Vec<Option<State>>), Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"
             SELECT
@@ -1335,7 +1361,7 @@ where
             "#,
         )?
         .bind("keyset_id", keyset_id.to_string())
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(sql_row_to_proof_with_state)
@@ -1436,9 +1462,10 @@ where
 }
 
 #[async_trait]
-impl<DB> MintSignaturesDatabase for SQLMintDatabase<DB>
+impl<T, DB> MintSignaturesDatabase for SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
     type Err = Error;
 
@@ -1446,6 +1473,7 @@ where
         &self,
         blinded_messages: &[PublicKey],
     ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         let mut blinded_signatures = query(
             r#"SELECT
                 keyset_id,
@@ -1466,7 +1494,7 @@ where
                 .map(|b_| b_.to_bytes().to_vec())
                 .collect(),
         )
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(|mut row| {
@@ -1490,6 +1518,7 @@ where
         &self,
         keyset_id: &Id,
     ) -> Result<Vec<BlindSignature>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"
             SELECT
@@ -1505,7 +1534,7 @@ where
             "#,
         )?
         .bind("keyset_id", keyset_id.to_string())
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(sql_row_to_blind_signature)
@@ -1517,6 +1546,7 @@ where
         &self,
         quote_id: &Uuid,
     ) -> Result<Vec<BlindSignature>, Self::Err> {
+        let conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(query(
             r#"
             SELECT
@@ -1532,7 +1562,7 @@ where
             "#,
         )?
         .bind("quote_id", quote_id.to_string())
-        .fetch_all(&self.db)
+        .fetch_all(&*conn)
         .await?
         .into_iter()
         .map(sql_row_to_blind_signature)
@@ -1541,15 +1571,17 @@ where
 }
 
 #[async_trait]
-impl<DB> MintDatabase<Error> for SQLMintDatabase<DB>
+impl<T, DB> MintDatabase<Error> for SQLMintDatabase<T, DB>
 where
     DB: DatabaseConnector,
+    T: ResourceManager<Resource = DB>,
 {
     async fn begin_transaction<'a>(
         &'a self,
     ) -> Result<Box<dyn database::MintTransaction<'a, Error> + Send + Sync + 'a>, Error> {
+        let mut conn = self.db.get().map_err(|e| Error::Database(Box::new(e)))?;
         Ok(Box::new(SQLTransaction {
-            inner: self.db.begin().await?,
+            inner: conn.begin().await?,
             _phantom: PhantomData,
         }))
     }

+ 7 - 4
crates/cdk-sql-common/src/pool.rs

@@ -10,7 +10,10 @@ use std::time::Duration;
 
 /// Pool error
 #[derive(thiserror::Error, Debug)]
-pub enum Error<E> {
+pub enum Error<E>
+where
+    E: std::error::Error + Send + Sync + 'static,
+{
     /// Mutex Poison Error
     #[error("Internal: PoisonError")]
     Poison,
@@ -27,13 +30,13 @@ pub enum Error<E> {
 /// Trait to manage resources
 pub trait ResourceManager: Debug {
     /// The resource to be pooled
-    type Resource: Debug;
+    type Resource: Debug + Send + Sync;
 
     /// The configuration that is needed in order to create the resource
-    type Config: Clone + Debug;
+    type Config: Clone + Debug + Send + Sync;
 
     /// The error the resource may return when creating a new instance
-    type Error: Debug;
+    type Error: Debug + std::error::Error + Send + Sync + 'static;
 
     /// Creates a new resource with a given config.
     ///