浏览代码

Merge pull request #730 from crodas/fix/race-condition-state-update

Fix race conditions with proof state updates.
C 6 月之前
父节点
当前提交
15e10c0e90

+ 1 - 0
crates/cdk-common/Cargo.toml

@@ -13,6 +13,7 @@ readme = "README.md"
 [features]
 default = ["mint", "wallet"]
 swagger = ["dep:utoipa", "cashu/swagger"]
+test = []
 bench = []
 wallet = ["cashu/wallet"]
 mint = ["cashu/mint", "dep:uuid"]

+ 3 - 0
crates/cdk-common/src/database/mint/mod.rs

@@ -17,6 +17,9 @@ use crate::nuts::{
 #[cfg(feature = "auth")]
 mod auth;
 
+#[cfg(feature = "test")]
+pub mod test;
+
 #[cfg(feature = "auth")]
 pub use auth::MintAuthDatabase;
 

+ 83 - 0
crates/cdk-common/src/database/mint/test.rs

@@ -0,0 +1,83 @@
+//! Macro with default tests
+//!
+//! This set is generic and checks the default and expected behaviour for a mint database
+//! implementation
+use std::fmt::Debug;
+use std::str::FromStr;
+
+use cashu::secret::Secret;
+use cashu::{Amount, CurrencyUnit, SecretKey};
+
+use super::*;
+use crate::mint::MintKeySetInfo;
+
+#[inline]
+async fn setup_keyset<E: Debug, DB: Database<E>>(db: &DB) -> Id {
+    let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
+    let keyset_info = MintKeySetInfo {
+        id: keyset_id,
+        unit: CurrencyUnit::Sat,
+        active: true,
+        valid_from: 0,
+        valid_to: None,
+        derivation_path: bitcoin::bip32::DerivationPath::from_str("m/0'/0'/0'").unwrap(),
+        derivation_path_index: Some(0),
+        max_order: 32,
+        input_fee_ppk: 0,
+    };
+    db.add_keyset_info(keyset_info).await.unwrap();
+    keyset_id
+}
+
+/// State transition test
+pub async fn state_transition<E: Debug, DB: Database<E>>(db: DB) {
+    let keyset_id = setup_keyset(&db).await;
+
+    let proofs = vec![
+        Proof {
+            amount: Amount::from(100),
+            keyset_id,
+            secret: Secret::generate(),
+            c: SecretKey::generate().public_key(),
+            witness: None,
+            dleq: None,
+        },
+        Proof {
+            amount: Amount::from(200),
+            keyset_id,
+            secret: Secret::generate(),
+            c: SecretKey::generate().public_key(),
+            witness: None,
+            dleq: None,
+        },
+    ];
+
+    // Add proofs to database
+    db.add_proofs(proofs.clone(), None).await.unwrap();
+
+    // Mark one proof as `pending`
+    assert!(db
+        .update_proofs_states(&[proofs[0].y().unwrap()], State::Pending)
+        .await
+        .is_ok());
+
+    // Attempt to select the `pending` proof, as `pending` again (which should fail)
+    assert!(db
+        .update_proofs_states(&[proofs[0].y().unwrap()], State::Pending)
+        .await
+        .is_err());
+}
+
+/// Unit test that is expected to be passed for a correct database implementation
+#[macro_export]
+macro_rules! mint_db_test {
+    ($make_db_fn:ident) => {
+        mint_db_test!(state_transition, $make_db_fn);
+    };
+    ($name:ident, $make_db_fn:ident) => {
+        #[tokio::test]
+        async fn $name() {
+            cdk_common::database::mint::test::$name($make_db_fn().await).await;
+        }
+    };
+}

+ 15 - 1
crates/cdk-common/src/database/mod.rs

@@ -1,7 +1,7 @@
 //! CDK Database
 
 #[cfg(feature = "mint")]
-mod mint;
+pub mod mint;
 #[cfg(feature = "wallet")]
 mod wallet;
 
@@ -53,4 +53,18 @@ pub enum Error {
     /// Invalid keyset
     #[error("Unknown or invalid keyset")]
     InvalidKeysetId,
+    #[cfg(feature = "mint")]
+    /// Invalid state transition
+    #[error("Invalid state transition")]
+    InvalidStateTransition(crate::state::Error),
+}
+
+#[cfg(feature = "mint")]
+impl From<crate::state::Error> for Error {
+    fn from(state: crate::state::Error) -> Self {
+        match state {
+            crate::state::Error::AlreadySpent => Error::AttemptUpdateSpentProof,
+            _ => Error::InvalidStateTransition(state),
+        }
+    }
 }

+ 22 - 1
crates/cdk-common/src/error.rs

@@ -317,7 +317,7 @@ pub enum Error {
     NUT22(#[from] crate::nuts::nut22::Error),
     /// Database Error
     #[error(transparent)]
-    Database(#[from] crate::database::Error),
+    Database(crate::database::Error),
     /// Payment Error
     #[error(transparent)]
     #[cfg(feature = "mint")]
@@ -502,6 +502,27 @@ impl From<Error> for ErrorResponse {
     }
 }
 
+#[cfg(feature = "mint")]
+impl From<crate::database::Error> for Error {
+    fn from(db_error: crate::database::Error) -> Self {
+        match db_error {
+            crate::database::Error::InvalidStateTransition(state) => match state {
+                crate::state::Error::Pending => Self::TokenPending,
+                crate::state::Error::AlreadySpent => Self::TokenAlreadySpent,
+                state => Self::Database(crate::database::Error::InvalidStateTransition(state)),
+            },
+            db_error => Self::Database(db_error),
+        }
+    }
+}
+
+#[cfg(not(feature = "mint"))]
+impl From<crate::database::Error> for Error {
+    fn from(db_error: crate::database::Error) -> Self {
+        Self::Database(db_error)
+    }
+}
+
 impl From<ErrorResponse> for Error {
     fn from(err: ErrorResponse) -> Error {
         match err.code {

+ 2 - 0
crates/cdk-common/src/lib.rs

@@ -16,6 +16,8 @@ pub mod mint;
 #[cfg(feature = "mint")]
 pub mod payment;
 pub mod pub_sub;
+#[cfg(feature = "mint")]
+pub mod state;
 pub mod subscription;
 #[cfg(feature = "wallet")]
 pub mod wallet;

+ 39 - 0
crates/cdk-common/src/state.rs

@@ -0,0 +1,39 @@
+//! State transition rules
+
+use cashu::State;
+
+/// State transition Error
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    /// Pending Token
+    #[error("Token already pending for another update")]
+    Pending,
+    /// Already spent
+    #[error("Token already spent")]
+    AlreadySpent,
+    /// Invalid transition
+    #[error("Invalid transition: From {0} to {1}")]
+    InvalidTransition(State, State),
+}
+
+#[inline]
+/// Check if the state transition is allowed
+pub fn check_state_transition(current_state: State, new_state: State) -> Result<(), Error> {
+    let is_valid_transition = match current_state {
+        State::Unspent => matches!(new_state, State::Pending | State::Spent),
+        State::Pending => matches!(new_state, State::Unspent | State::Spent),
+        // Any other state shouldn't be updated by the mint, and the wallet does not use this
+        // function
+        _ => false,
+    };
+
+    if !is_valid_transition {
+        Err(match current_state {
+            State::Pending => Error::Pending,
+            State::Spent => Error::AlreadySpent,
+            _ => Error::InvalidTransition(current_state, new_state),
+        })
+    } else {
+        Ok(())
+    }
+}

+ 1 - 1
crates/cdk-redb/Cargo.toml

@@ -19,7 +19,7 @@ auth = ["cdk-common/auth"]
 
 [dependencies]
 async-trait.workspace = true
-cdk-common.workspace = true
+cdk-common = { workspace = true, features = ["test"] }
 redb = "2.4.0"
 thiserror.workspace = true
 tracing.workspace = true

+ 15 - 8
crates/cdk-redb/src/mint/mod.rs

@@ -15,6 +15,7 @@ use cdk_common::database::{
 use cdk_common::dhke::hash_to_curve;
 use cdk_common::mint::{self, MintKeySetInfo, MintQuote};
 use cdk_common::nut00::ProofsMethods;
+use cdk_common::state::check_state_transition;
 use cdk_common::util::unix_time;
 use cdk_common::{
     BlindSignature, CurrencyUnit, Id, MeltBolt11Request, MeltQuoteState, MintInfo, MintQuoteState,
@@ -787,21 +788,19 @@ impl MintProofsDatabase for MintRedbDatabase {
                 for y in ys {
                     let current_state = match table.get(y.to_bytes()).map_err(Error::from)? {
                         Some(state) => {
-                            Some(serde_json::from_str(state.value()).map_err(Error::from)?)
+                            let current_state =
+                                serde_json::from_str(state.value()).map_err(Error::from)?;
+                            check_state_transition(current_state, proofs_state)?;
+                            Some(current_state)
                         }
                         None => None,
                     };
+
                     states.push(current_state);
                 }
             }
         }
 
-        // Check if any proofs are spent
-        if states.contains(&Some(State::Spent)) {
-            write_txn.abort().map_err(Error::from)?;
-            return Err(database::Error::AttemptUpdateSpentProof);
-        }
-
         {
             let mut table = write_txn
                 .open_table(PROOFS_STATE_TABLE)
@@ -1007,7 +1006,7 @@ impl MintDatabase<database::Error> for MintRedbDatabase {
 #[cfg(test)]
 mod tests {
     use cdk_common::secret::Secret;
-    use cdk_common::{Amount, SecretKey};
+    use cdk_common::{mint_db_test, Amount, SecretKey};
     use tempfile::tempdir;
 
     use super::*;
@@ -1136,4 +1135,12 @@ mod tests {
         assert_eq!(states[0], Some(State::Spent));
         assert_eq!(states[1], Some(State::Unspent));
     }
+
+    async fn provide_db() -> MintRedbDatabase {
+        let tmp_dir = tempdir().unwrap();
+
+        MintRedbDatabase::new(&tmp_dir.path().join("mint.redb")).unwrap()
+    }
+
+    mint_db_test!(provide_db);
 }

+ 2 - 2
crates/cdk-sqlite/Cargo.toml

@@ -20,11 +20,11 @@ sqlcipher = ["libsqlite3-sys"]
 
 [dependencies]
 async-trait.workspace = true
-cdk-common.workspace = true
+cdk-common = { workspace = true, features = ["test"] }
 bitcoin.workspace = true
 sqlx = { version = "0.7.4", default-features = false, features = [
     "runtime-tokio-rustls",
-    "sqlite", 
+    "sqlite",
     "macros",
     "migrate",
     "uuid",

+ 11 - 9
crates/cdk-sqlite/src/mint/mod.rs

@@ -15,6 +15,7 @@ use cdk_common::mint::{self, MintKeySetInfo, MintQuote};
 use cdk_common::nut00::ProofsMethods;
 use cdk_common::nut05::QuoteState;
 use cdk_common::secret::Secret;
+use cdk_common::state::check_state_transition;
 use cdk_common::util::unix_time;
 use cdk_common::{
     Amount, BlindSignature, BlindSignatureDleq, CurrencyUnit, Id, MeltBolt11Request,
@@ -1311,10 +1312,8 @@ WHERE keyset_id=?;
 
         let states = current_states.values().collect::<HashSet<_>>();
 
-        if states.contains(&State::Spent) {
-            transaction.rollback().await.map_err(Error::from)?;
-            tracing::warn!("Attempted to update state of spent proof");
-            return Err(database::Error::AttemptUpdateSpentProof);
+        for state in states {
+            check_state_transition(*state, proofs_state)?;
         }
 
         // If no proofs are spent, proceed with update
@@ -1843,7 +1842,7 @@ fn sqlite_row_to_melt_request(
 #[cfg(test)]
 mod tests {
     use cdk_common::mint::MintKeySetInfo;
-    use cdk_common::Amount;
+    use cdk_common::{mint_db_test, Amount};
 
     use super::*;
 
@@ -1963,10 +1962,7 @@ mod tests {
 
         // Try to update both proofs - should fail because one is spent
         let result = db
-            .update_proofs_states(
-                &[proofs[0].y().unwrap(), proofs[1].y().unwrap()],
-                State::Reserved,
-            )
+            .update_proofs_states(&[proofs[0].y().unwrap()], State::Unspent)
             .await;
 
         assert!(result.is_err());
@@ -1985,4 +1981,10 @@ mod tests {
         assert_eq!(states[0], Some(State::Spent));
         assert_eq!(states[1], Some(State::Unspent));
     }
+
+    async fn provide_db() -> MintSqliteDatabase {
+        memory::empty().await.unwrap()
+    }
+
+    mint_db_test!(provide_db);
 }

+ 1 - 1
crates/cdk-sqlite/src/wallet/mod.rs

@@ -1116,7 +1116,7 @@ mod tests {
             .await
             .unwrap();
 
-        db.migrate().await;
+        db.migrate().await.unwrap();
 
         let mint_info = MintInfo::new().description("test");
         let mint_url = MintUrl::from_str("https://mint.xyz").unwrap();