Selaa lähdekoodia

Add transaction to the auth db traits

Cesar Rodas 3 kuukautta sitten
vanhempi
säilyke
fde234387c

+ 46 - 28
crates/cdk-common/src/database/mint/auth/mod.rs

@@ -5,61 +5,79 @@ use std::collections::HashMap;
 use async_trait::async_trait;
 use cashu::{AuthRequired, ProtectedEndpoint};
 
+use super::DbTransactionFinalizer;
 use crate::database::Error;
 use crate::mint::MintKeySetInfo;
 use crate::nuts::nut07::State;
 use crate::nuts::{AuthProof, BlindSignature, Id, PublicKey};
 
-/// Mint Database trait
+/// Mint Database transaction
 #[async_trait]
-pub trait MintAuthDatabase {
-    /// Mint Database Error
-    type Err: Into<Error> + From<Error>;
+pub trait MintAuthTransaction<Error>: DbTransactionFinalizer<Err = Error> {
     /// Add Active Keyset
-    async fn set_active_keyset(&self, id: Id) -> Result<(), Self::Err>;
-    /// Get Active Keyset
-    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err>;
+    async fn set_active_keyset(&mut self, id: Id) -> Result<(), Error>;
 
     /// Add [`MintKeySetInfo`]
-    async fn add_keyset_info(&self, keyset: MintKeySetInfo) -> Result<(), Self::Err>;
-    /// Get [`MintKeySetInfo`]
-    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err>;
-    /// Get [`MintKeySetInfo`]s
-    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err>;
+    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), Error>;
 
     /// Add spent [`AuthProof`]
-    async fn add_proof(&self, proof: AuthProof) -> Result<(), Self::Err>;
-    /// Get [`AuthProof`] state
-    async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err>;
+    async fn add_proof(&mut self, proof: AuthProof) -> Result<(), Error>;
+
     /// Update [`AuthProof`]s state
     async fn update_proof_state(
-        &self,
+        &mut self,
         y: &PublicKey,
         proofs_state: State,
-    ) -> Result<Option<State>, Self::Err>;
+    ) -> Result<Option<State>, Error>;
 
     /// Add [`BlindSignature`]
     async fn add_blind_signatures(
-        &self,
+        &mut self,
         blinded_messages: &[PublicKey],
         blind_signatures: &[BlindSignature],
-    ) -> Result<(), Self::Err>;
-    /// Get [`BlindSignature`]s
-    async fn get_blind_signatures(
-        &self,
-        blinded_messages: &[PublicKey],
-    ) -> Result<Vec<Option<BlindSignature>>, Self::Err>;
+    ) -> Result<(), Error>;
 
     /// Add protected endpoints
     async fn add_protected_endpoints(
-        &self,
+        &mut self,
         protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Error>;
+
     /// Removed Protected endpoints
     async fn remove_protected_endpoints(
-        &self,
+        &mut self,
         protected_endpoints: Vec<ProtectedEndpoint>,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<(), Error>;
+}
+
+/// Mint Database trait
+#[async_trait]
+pub trait MintAuthDatabase {
+    /// Mint Database Error
+    type Err: Into<Error> + From<Error>;
+
+    /// Begins a transaction
+    async fn begin_transaction<'a>(
+        &'a self,
+    ) -> Result<Box<dyn MintAuthTransaction<Self::Err> + Send + Sync + 'a>, Self::Err>;
+
+    /// Get Active Keyset
+    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err>;
+
+    /// Get [`MintKeySetInfo`]
+    async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err>;
+    /// Get [`MintKeySetInfo`]s
+    async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err>;
+
+    /// Get [`AuthProof`] state
+    async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err>;
+
+    /// Get [`BlindSignature`]s
+    async fn get_blind_signatures(
+        &self,
+        blinded_messages: &[PublicKey],
+    ) -> Result<Vec<Option<BlindSignature>>, Self::Err>;
+
     /// Get auth for protected_endpoint
     async fn get_auth_for_endpoint(
         &self,

+ 1 - 1
crates/cdk-common/src/database/mint/mod.rs

@@ -21,7 +21,7 @@ mod auth;
 pub mod test;
 
 #[cfg(feature = "auth")]
-pub use auth::MintAuthDatabase;
+pub use auth::{MintAuthDatabase, MintAuthTransaction};
 
 /// KeysDatabaseWriter
 #[async_trait]

+ 2 - 2
crates/cdk-common/src/database/mod.rs

@@ -5,8 +5,6 @@ pub mod mint;
 #[cfg(feature = "wallet")]
 mod wallet;
 
-#[cfg(all(feature = "mint", feature = "auth"))]
-pub use mint::MintAuthDatabase;
 #[cfg(feature = "mint")]
 pub use mint::{
     Database as MintDatabase, DbTransactionFinalizer as MintDbWriterFinalizer,
@@ -16,6 +14,8 @@ pub use mint::{
     SignaturesDatabase as MintSignaturesDatabase,
     SignaturesTransaction as MintSignatureTransaction, Transaction as MintTransaction,
 };
+#[cfg(all(feature = "mint", feature = "auth"))]
+pub use mint::{MintAuthDatabase, MintAuthTransaction};
 #[cfg(feature = "wallet")]
 pub use wallet::Database as WalletDatabase;
 

+ 6 - 6
crates/cdk-integration-tests/src/init_auth_mint.rs

@@ -71,9 +71,9 @@ where
                 acc
             });
 
-    auth_database
-        .add_protected_endpoints(blind_auth_endpoints)
-        .await?;
+    let mut tx = auth_database.begin_transaction().await?;
+
+    tx.add_protected_endpoints(blind_auth_endpoints).await?;
 
     let mut clear_auth_endpoint = HashMap::new();
     clear_auth_endpoint.insert(
@@ -81,9 +81,9 @@ where
         AuthRequired::Clear,
     );
 
-    auth_database
-        .add_protected_endpoints(clear_auth_endpoint)
-        .await?;
+    tx.add_protected_endpoints(clear_auth_endpoint).await?;
+
+    tx.commit().await?;
 
     mint_builder = mint_builder.with_auth_localstore(Arc::new(auth_database));
 

+ 5 - 6
crates/cdk-mintd/src/main.rs

@@ -526,12 +526,11 @@ async fn main() -> anyhow::Result<()> {
 
         mint_builder = mint_builder.set_blind_auth_settings(auth_settings.mint_max_bat);
 
-        auth_localstore
-            .remove_protected_endpoints(unprotected_endpoints)
-            .await?;
-        auth_localstore
-            .add_protected_endpoints(protected_endpoints)
-            .await?;
+        let mut tx = auth_localstore.begin_transaction().await?;
+
+        tx.remove_protected_endpoints(unprotected_endpoints).await?;
+        tx.add_protected_endpoints(protected_endpoints).await?;
+        tx.commit().await?;
     }
 
     let mint = mint_builder.build().await?;

+ 2 - 2
crates/cdk-signatory/src/common.rs

@@ -25,7 +25,7 @@ pub async fn init_keysets(
     // Get keysets info from DB
     let keysets_infos = localstore.get_keyset_infos().await?;
 
-    let mut tx = localstore.begin_transaction().await.expect("begin");
+    let mut tx = localstore.begin_transaction().await?;
     if !keysets_infos.is_empty() {
         tracing::debug!("Setting all saved keysets to inactive");
         for keyset in keysets_infos.clone() {
@@ -114,7 +114,7 @@ pub async fn init_keysets(
         }
     }
 
-    tx.commit().await.expect("commit");
+    tx.commit().await?;
 
     Ok((active_keysets, active_keyset_units))
 }

+ 154 - 154
crates/cdk-sqlite/src/mint/auth/mod.rs

@@ -6,14 +6,14 @@ use std::path::Path;
 use std::str::FromStr;
 
 use async_trait::async_trait;
-use cdk_common::database::{self, MintAuthDatabase};
+use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
 use cdk_common::mint::MintKeySetInfo;
 use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
 use cdk_common::{AuthRequired, ProtectedEndpoint};
 use tracing::instrument;
 
 use super::async_rusqlite::AsyncRusqlite;
-use super::{sqlite_row_to_blind_signature, sqlite_row_to_keyset_info};
+use super::{sqlite_row_to_blind_signature, sqlite_row_to_keyset_info, SqliteTransaction};
 use crate::column_as_string;
 use crate::common::{create_sqlite_pool, migrate};
 use crate::mint::async_rusqlite::query;
@@ -56,11 +56,9 @@ impl MintSqliteAuthDatabase {
 }
 
 #[async_trait]
-impl MintAuthDatabase for MintSqliteAuthDatabase {
-    type Err = database::Error;
-
+impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
     #[instrument(skip(self))]
-    async fn set_active_keyset(&self, id: Id) -> Result<(), Self::Err> {
+    async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
         tracing::info!("Setting auth keyset {id} active");
         query(
             r#"
@@ -72,30 +70,13 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
             "#,
         )
         .bind(":id", id.to_string())
-        .execute(&self.pool)
+        .execute(&self.transaction)
         .await?;
 
         Ok(())
     }
 
-    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
-        Ok(query(
-            r#"
-            SELECT
-                id
-            FROM
-                keyset
-            WHERE
-                active = 1;
-            "#,
-        )
-        .pluck(&self.pool)
-        .await?
-        .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
-        .transpose()?)
-    }
-
-    async fn add_keyset_info(&self, keyset: MintKeySetInfo) -> Result<(), Self::Err> {
+    async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
         query(
             r#"
         INSERT INTO
@@ -125,12 +106,159 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
         .bind(":derivation_path", keyset.derivation_path.to_string())
         .bind(":max_order", keyset.max_order)
         .bind(":derivation_path_index", keyset.derivation_path_index)
-        .execute(&self.pool)
+        .execute(&self.transaction)
         .await?;
 
         Ok(())
     }
 
+    async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
+        if let Err(err) = query(
+            r#"
+                INSERT INTO proof
+                (y, keyset_id, secret, c, state)
+                VALUES
+                (:y, :keyset_id, :secret, :c, :state)
+                "#,
+        )
+        .bind(":y", proof.y()?.to_bytes().to_vec())
+        .bind(":keyset_id", proof.keyset_id.to_string())
+        .bind(":secret", proof.secret.to_string())
+        .bind(":c", proof.c.to_bytes().to_vec())
+        .bind(":state", "UNSPENT".to_string())
+        .execute(&self.transaction)
+        .await
+        {
+            tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
+        }
+        Ok(())
+    }
+
+    async fn update_proof_state(
+        &mut self,
+        y: &PublicKey,
+        proofs_state: State,
+    ) -> Result<Option<State>, Self::Err> {
+        let current_state = query(r#"SELECT state FROM proof WHERE y = :y"#)
+            .bind(":y", y.to_bytes().to_vec())
+            .pluck(&self.transaction)
+            .await?
+            .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
+            .transpose()?;
+
+        query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)
+            .bind(":y", y.to_bytes().to_vec())
+            .bind(
+                ":state",
+                current_state.as_ref().map(|state| state.to_string()),
+            )
+            .bind(":new_state", proofs_state.to_string())
+            .execute(&self.transaction)
+            .await?;
+
+        Ok(current_state)
+    }
+
+    async fn add_blind_signatures(
+        &mut self,
+        blinded_messages: &[PublicKey],
+        blind_signatures: &[BlindSignature],
+    ) -> Result<(), database::Error> {
+        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
+            query(
+                r#"
+                       INSERT
+                       INTO blind_signature
+                       (y, amount, keyset_id, c)
+                       VALUES
+                       (:y, :amount, :keyset_id, :c)
+                   "#,
+            )
+            .bind(":y", message.to_bytes().to_vec())
+            .bind(":amount", u64::from(signature.amount) as i64)
+            .bind(":keyset_id", signature.keyset_id.to_string())
+            .bind(":c", signature.c.to_bytes().to_vec())
+            .execute(&self.transaction)
+            .await?;
+        }
+
+        Ok(())
+    }
+
+    async fn add_protected_endpoints(
+        &mut self,
+        protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
+    ) -> Result<(), database::Error> {
+        for (endpoint, auth) in protected_endpoints.iter() {
+            if let Err(err) = query(
+                r#"
+                 INSERT OR REPLACE INTO protected_endpoints
+                 (endpoint, auth)
+                 VALUES (:endpoint, :auth);
+                 "#,
+            )
+            .bind(":endpoint", serde_json::to_string(endpoint)?)
+            .bind(":auth", serde_json::to_string(auth)?)
+            .execute(&self.transaction)
+            .await
+            {
+                tracing::debug!(
+                    "Attempting to add protected endpoint. Skipping.... {:?}",
+                    err
+                );
+            }
+        }
+
+        Ok(())
+    }
+    async fn remove_protected_endpoints(
+        &mut self,
+        protected_endpoints: Vec<ProtectedEndpoint>,
+    ) -> Result<(), database::Error> {
+        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)
+            .bind_vec(
+                ":endpoints",
+                protected_endpoints
+                    .iter()
+                    .map(serde_json::to_string)
+                    .collect::<Result<_, _>>()?,
+            )
+            .execute(&self.transaction)
+            .await?;
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl MintAuthDatabase for MintSqliteAuthDatabase {
+    type Err = database::Error;
+
+    async fn begin_transaction<'a>(
+        &'a self,
+    ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
+    {
+        Ok(Box::new(SqliteTransaction {
+            transaction: self.pool.begin().await?,
+        }))
+    }
+
+    async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
+        Ok(query(
+            r#"
+            SELECT
+                id
+            FROM
+                keyset
+            WHERE
+                active = 1;
+            "#,
+        )
+        .pluck(&self.pool)
+        .await?
+        .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
+        .transpose()?)
+    }
+
     async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
         Ok(query(
             r#"SELECT
@@ -177,28 +305,6 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
         .collect::<Result<Vec<_>, _>>()?)
     }
 
-    async fn add_proof(&self, proof: AuthProof) -> Result<(), Self::Err> {
-        if let Err(err) = query(
-            r#"
-            INSERT INTO proof
-            (y, keyset_id, secret, c, state)
-            VALUES
-            (:y, :keyset_id, :secret, :c, :state)
-            "#,
-        )
-        .bind(":y", proof.y()?.to_bytes().to_vec())
-        .bind(":keyset_id", proof.keyset_id.to_string())
-        .bind(":secret", proof.secret.to_string())
-        .bind(":c", proof.c.to_bytes().to_vec())
-        .bind(":state", "UNSPENT".to_string())
-        .execute(&self.pool)
-        .await
-        {
-            tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
-        }
-        Ok(())
-    }
-
     async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
         let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)
             .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
@@ -216,65 +322,6 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
         Ok(ys.iter().map(|y| current_states.remove(y)).collect())
     }
 
-    async fn update_proof_state(
-        &self,
-        y: &PublicKey,
-        proofs_state: State,
-    ) -> Result<Option<State>, Self::Err> {
-        let transaction = self.pool.begin().await?;
-
-        let current_state = query(r#"SELECT state FROM proof WHERE y = :y"#)
-            .bind(":y", y.to_bytes().to_vec())
-            .pluck(&transaction)
-            .await?
-            .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
-            .transpose()?;
-
-        query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)
-            .bind(":y", y.to_bytes().to_vec())
-            .bind(
-                ":state",
-                current_state.as_ref().map(|state| state.to_string()),
-            )
-            .bind(":new_state", proofs_state.to_string())
-            .execute(&transaction)
-            .await?;
-
-        transaction.commit().await?;
-
-        Ok(current_state)
-    }
-
-    async fn add_blind_signatures(
-        &self,
-        blinded_messages: &[PublicKey],
-        blind_signatures: &[BlindSignature],
-    ) -> Result<(), Self::Err> {
-        let transaction = self.pool.begin().await?;
-
-        for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
-            query(
-                r#"
-                    INSERT
-                    INTO blind_signature
-                    (y, amount, keyset_id, c)
-                    VALUES
-                    (:y, :amount, :keyset_id, :c)
-                "#,
-            )
-            .bind(":y", message.to_bytes().to_vec())
-            .bind(":amount", u64::from(signature.amount) as i64)
-            .bind(":keyset_id", signature.keyset_id.to_string())
-            .bind(":c", signature.c.to_bytes().to_vec())
-            .execute(&transaction)
-            .await?;
-        }
-
-        transaction.commit().await?;
-
-        Ok(())
-    }
-
     async fn get_blind_signatures(
         &self,
         blinded_messages: &[PublicKey],
@@ -319,53 +366,6 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
             .collect())
     }
 
-    async fn add_protected_endpoints(
-        &self,
-        protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
-    ) -> Result<(), Self::Err> {
-        let transaction = self.pool.begin().await?;
-
-        for (endpoint, auth) in protected_endpoints.iter() {
-            if let Err(err) = query(
-                r#"
-                INSERT OR REPLACE INTO protected_endpoints
-                (endpoint, auth)
-                VALUES (:endpoint, :auth);
-                "#,
-            )
-            .bind(":endpoint", serde_json::to_string(endpoint)?)
-            .bind(":auth", serde_json::to_string(auth)?)
-            .execute(&transaction)
-            .await
-            {
-                tracing::debug!(
-                    "Attempting to add protected endpoint. Skipping.... {:?}",
-                    err
-                );
-            }
-        }
-
-        transaction.commit().await?;
-
-        Ok(())
-    }
-    async fn remove_protected_endpoints(
-        &self,
-        protected_endpoints: Vec<ProtectedEndpoint>,
-    ) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)
-            .bind_vec(
-                ":endpoints",
-                protected_endpoints
-                    .iter()
-                    .map(serde_json::to_string)
-                    .collect::<Result<_, _>>()?,
-            )
-            .execute(&self.pool)
-            .await?;
-        Ok(())
-    }
-
     async fn get_auth_for_endpoint(
         &self,
         protected_endpoint: ProtectedEndpoint,

+ 1 - 24
crates/cdk-sqlite/src/mint/mod.rs

@@ -366,7 +366,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
     async fn add_melt_quote(&mut self, quote: mint::MeltQuote) -> Result<(), Self::Err> {
         query(
             r#"
-            INSERT INTO melt_quote
+            INSERT OR REPLACE INTO melt_quote
             (
                 id, unit, amount, request, fee_reserve, state,
                 expiry, payment_preimage, request_lookup_id, msat_to_pay,
@@ -378,29 +378,6 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
                 :expiry, :payment_preimage, :request_lookup_id, :msat_to_pay,
                 :created_time, :paid_time
             )
-            ON CONFLICT(id) DO UPDATE SET
-                unit = excluded.unit,
-                amount = excluded.amount,
-                request = excluded.request,
-                fee_reserve = excluded.fee_reserve,
-                state = excluded.state,
-                expiry = excluded.expiry,
-                payment_preimage = excluded.payment_preimage,
-                request_lookup_id = excluded.request_lookup_id,
-                msat_to_pay = excluded.msat_to_pay,
-                created_time = excluded.created_time,
-                paid_time = excluded.paid_time
-            ON CONFLICT(request_lookup_id) DO UPDATE SET
-                unit = excluded.unit,
-                amount = excluded.amount,
-                request = excluded.request,
-                fee_reserve = excluded.fee_reserve,
-                state = excluded.state,
-                expiry = excluded.expiry,
-                payment_preimage = excluded.payment_preimage,
-                id = excluded.id,
-                created_time = excluded.created_time,
-                paid_time = excluded.paid_time;
         "#,
         )
         .bind(":id", quote.id.to_string())

+ 9 - 8
crates/cdk/src/mint/auth/mod.rs

@@ -163,17 +163,16 @@ impl Mint {
             err
         })?;
 
+        let mut tx = auth_localstore.begin_transaction().await?;
+
         // Add proof to the database
-        auth_localstore
-            .add_proof(proof.clone())
-            .await
-            .map_err(|err| {
-                tracing::error!("Failed to add proof to database: {:?}", err);
-                err
-            })?;
+        tx.add_proof(proof.clone()).await.map_err(|err| {
+            tracing::error!("Failed to add proof to database: {:?}", err);
+            err
+        })?;
 
         // Update proof state to spent
-        let state = match auth_localstore.update_proof_state(&y, State::Spent).await {
+        let state = match tx.update_proof_state(&y, State::Spent).await {
             Ok(state) => {
                 tracing::debug!(
                     "Successfully updated proof state to SPENT, previous state: {:?}",
@@ -205,6 +204,8 @@ impl Mint {
             }
         };
 
+        tx.commit().await?;
+
         Ok(())
     }