Kaynağa Gözat

refactor: Add state check before deleting proofs to prevent removing spent proofs

thesimplekid (aider) 1 ay önce
ebeveyn
işleme
d41d3a7c94

+ 1 - 1
Cargo.toml

@@ -36,7 +36,7 @@ lightning-invoice = { version = "0.32.0", features = ["serde", "std"] }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 thiserror = { version = "1" }
-tokio = { version = "1", default-features = false }
+tokio = { version = "1", default-features = false, features = ["rt", "macros", "test-util"] }
 tokio-util = { version = "0.7.11", default-features = false }
 tower-http = { version = "0.6.1", features = ["compression-full", "decompression-full", "cors", "trace"] }
 tokio-tungstenite = { version = "0.26.0", default-features = false }

+ 6 - 0
crates/cdk-common/src/database/mod.rs

@@ -31,4 +31,10 @@ pub enum Error {
     /// Unknown Quote
     #[error("Unknown Quote")]
     UnknownQuote,
+    /// Attempt to remove spent proof
+    #[error("Attempt to remove spent proof")]
+    AttemptRemoveSpentProof,
+    /// Attempt to update state of spent proof
+    #[error("Attempt to update state of spent proof")]
+    AttemptUpdateSpentProof,
 }

+ 4 - 0
crates/cdk-redb/Cargo.toml

@@ -25,3 +25,7 @@ serde.workspace = true
 serde_json.workspace = true
 lightning-invoice.workspace = true
 uuid.workspace = true
+
+[dev-dependencies]
+tempfile = "3.17.1"
+tokio.workspace = true

+ 181 - 26
crates/cdk-redb/src/mint/mod.rs

@@ -1,7 +1,7 @@
 //! SQLite Storage for CDK
 
 use std::cmp::Ordering;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::path::Path;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -558,22 +558,36 @@ impl MintDatabase for MintRedbDatabase {
     ) -> Result<(), Self::Err> {
         let write_txn = self.db.begin_write().map_err(Error::from)?;
 
-        {
-            let mut proofs_table = write_txn.open_table(PROOFS_TABLE).map_err(Error::from)?;
-
-            for y in ys {
-                proofs_table.remove(&y.to_bytes()).map_err(Error::from)?;
-            }
-        }
+        let mut states: HashSet<State> = HashSet::new();
 
         {
             let mut proof_state_table = write_txn
                 .open_table(PROOFS_STATE_TABLE)
                 .map_err(Error::from)?;
             for y in ys {
-                proof_state_table
+                let state = proof_state_table
                     .remove(&y.to_bytes())
                     .map_err(Error::from)?;
+
+                if let Some(state) = state {
+                    let state: State = serde_json::from_str(state.value()).map_err(Error::from)?;
+
+                    states.insert(state);
+                }
+            }
+        }
+
+        if states.contains(&State::Spent) {
+            tracing::warn!("Db attempted to remove spent proof");
+            write_txn.abort().map_err(Error::from)?;
+            return Err(Self::Err::AttemptRemoveSpentProof);
+        }
+
+        {
+            let mut proofs_table = write_txn.open_table(PROOFS_TABLE).map_err(Error::from)?;
+
+            for y in ys {
+                proofs_table.remove(&y.to_bytes()).map_err(Error::from)?;
             }
         }
 
@@ -684,37 +698,44 @@ impl MintDatabase for MintRedbDatabase {
         let write_txn = self.db.begin_write().map_err(Error::from)?;
 
         let mut states = Vec::with_capacity(ys.len());
-
-        let state_str = serde_json::to_string(&proofs_state).map_err(Error::from)?;
-
         {
-            let mut table = write_txn
+            let table = write_txn
                 .open_table(PROOFS_STATE_TABLE)
                 .map_err(Error::from)?;
-
-            for y in ys {
-                let current_state;
-                {
-                    match table.get(y.to_bytes()).map_err(Error::from)? {
+            {
+                // First collect current states
+                for y in ys {
+                    let current_state = match table.get(y.to_bytes()).map_err(Error::from)? {
                         Some(state) => {
-                            current_state =
-                                Some(serde_json::from_str(state.value()).map_err(Error::from)?)
+                            Some(serde_json::from_str(state.value()).map_err(Error::from)?)
                         }
-                        None => current_state = None,
-                    }
+                        None => None,
+                    };
+                    states.push(current_state);
                 }
-                states.push(current_state);
             }
+        }
+
+        // Check if any proofs are spent
+        if states.iter().any(|state| *state == Some(State::Spent)) {
+            write_txn.abort().map_err(Error::from)?;
+            return Err(database::Error::AttemptUpdateSpentProof);
+        }
 
-            for (y, current_state) in ys.iter().zip(&states) {
-                if current_state != &Some(State::Spent) {
+        {
+            let mut table = write_txn
+                .open_table(PROOFS_STATE_TABLE)
+                .map_err(Error::from)?;
+            {
+                // If no proofs are spent, proceed with update
+                let state_str = serde_json::to_string(&proofs_state).map_err(Error::from)?;
+                for y in ys {
                     table
                         .insert(y.to_bytes(), state_str.as_str())
                         .map_err(Error::from)?;
                 }
             }
         }
-
         write_txn.commit().map_err(Error::from)?;
 
         Ok(states)
@@ -924,3 +945,137 @@ impl MintDatabase for MintRedbDatabase {
         Err(Error::UnknownQuoteTTL.into())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use cdk_common::secret::Secret;
+    use cdk_common::{Amount, SecretKey};
+    use tempfile::tempdir;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_remove_spent_proofs() {
+        let tmp_dir = tempdir().unwrap();
+
+        let db = MintRedbDatabase::new(&tmp_dir.path().join("mint.redb")).unwrap();
+        // Create some test proofs
+        let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
+
+        let proofs = vec![
+            Proof {
+                amount: Amount::from(100),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+            Proof {
+                amount: Amount::from(200),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+        ];
+
+        // Add proofs to database
+        db.add_proofs(proofs.clone(), None).await.unwrap();
+
+        // Mark one proof as spent
+        db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
+            .await
+            .unwrap();
+
+        db.update_proofs_states(&[proofs[1].y().unwrap()], State::Unspent)
+            .await
+            .unwrap();
+
+        // Try to remove both proofs - should fail because one is spent
+        let result = db
+            .remove_proofs(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()], None)
+            .await;
+
+        assert!(result.is_err());
+        assert!(matches!(
+            result.unwrap_err(),
+            database::Error::AttemptRemoveSpentProof
+        ));
+
+        // Verify both proofs still exist
+        let states = db
+            .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
+            .await
+            .unwrap();
+
+        assert_eq!(states.len(), 2);
+        assert_eq!(states[0], Some(State::Spent));
+        assert_eq!(states[1], Some(State::Unspent));
+    }
+
+    #[tokio::test]
+    async fn test_update_spent_proofs() {
+        let tmp_dir = tempdir().unwrap();
+
+        let db = MintRedbDatabase::new(&tmp_dir.path().join("mint.redb")).unwrap();
+        // Create some test proofs
+        let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
+
+        let proofs = vec![
+            Proof {
+                amount: Amount::from(100),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+            Proof {
+                amount: Amount::from(200),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+        ];
+
+        // Add proofs to database
+        db.add_proofs(proofs.clone(), None).await.unwrap();
+
+        // Mark one proof as spent
+        db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
+            .await
+            .unwrap();
+
+        db.update_proofs_states(&[proofs[1].y().unwrap()], State::Unspent)
+            .await
+            .unwrap();
+
+        // Mark one proof as spent
+        let result = db
+            .update_proofs_states(
+                &[proofs[0].y().unwrap(), proofs[1].y().unwrap()],
+                State::Unspent,
+            )
+            .await;
+
+        assert!(result.is_err());
+        assert!(matches!(
+            result.unwrap_err(),
+            database::Error::AttemptUpdateSpentProof
+        ));
+
+        // Verify both proofs still exist
+        let states = db
+            .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
+            .await
+            .unwrap();
+
+        assert_eq!(states.len(), 2);
+        assert_eq!(states[0], Some(State::Spent));
+        assert_eq!(states[1], Some(State::Unspent));
+    }
+}

+ 173 - 7
crates/cdk-sqlite/src/mint/mod.rs

@@ -1,6 +1,6 @@
 //! SQLite Mint
 
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::path::Path;
 use std::str::FromStr;
 
@@ -36,6 +36,37 @@ pub struct MintSqliteDatabase {
 }
 
 impl MintSqliteDatabase {
+    /// Check if any proofs are spent
+    async fn check_for_spent_proofs(
+        &self,
+        transaction: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
+        ys: &[PublicKey],
+    ) -> Result<bool, database::Error> {
+        if ys.is_empty() {
+            return Ok(false);
+        }
+
+        let check_sql = format!(
+            "SELECT state FROM proof WHERE y IN ({}) AND state = 'SPENT'",
+            std::iter::repeat("?")
+                .take(ys.len())
+                .collect::<Vec<_>>()
+                .join(",")
+        );
+
+        let spent_count = ys
+            .iter()
+            .fold(sqlx::query(&check_sql), |query, y| {
+                query.bind(y.to_bytes().to_vec())
+            })
+            .fetch_all(&mut *transaction)
+            .await
+            .map_err(Error::from)?
+            .len();
+
+        Ok(spent_count > 0)
+    }
+
     /// Create new [`MintSqliteDatabase`]
     pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
         Ok(Self {
@@ -858,7 +889,13 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
     ) -> Result<(), Self::Err> {
         let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 
-        let sql = format!(
+        if self.check_for_spent_proofs(&mut transaction, ys).await? {
+            transaction.rollback().await.map_err(Error::from)?;
+            return Err(Self::Err::AttemptRemoveSpentProof);
+        }
+
+        // If no proofs are spent, proceed with deletion
+        let delete_sql = format!(
             "DELETE FROM proof WHERE y IN ({})",
             std::iter::repeat("?")
                 .take(ys.len())
@@ -867,7 +904,7 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
         );
 
         ys.iter()
-            .fold(sqlx::query(&sql), |query, y| {
+            .fold(sqlx::query(&delete_sql), |query, y| {
                 query.bind(y.to_bytes().to_vec())
             })
             .execute(&mut transaction)
@@ -1064,16 +1101,23 @@ WHERE keyset_id=?;
             })
             .collect::<Result<HashMap<_, _>, _>>()?;
 
+        let states = current_states.values().collect::<HashSet<_>>();
+
+        if states.contains(&State::Spent) {
+            transaction.rollback().await.map_err(Error::from)?;
+            tracing::warn!("Attempted to update state of spent proof");
+            return Err(database::Error::AttemptUpdateSpentProof);
+        }
+
+        // If no proofs are spent, proceed with update
         let update_sql = format!(
-            "UPDATE proof SET state = ? WHERE state != ? AND y IN ({})",
+            "UPDATE proof SET state = ? WHERE y IN ({})",
             "?,".repeat(ys.len()).trim_end_matches(',')
         );
 
         ys.iter()
             .fold(
-                sqlx::query(&update_sql)
-                    .bind(proofs_state.to_string())
-                    .bind(State::Spent.to_string()),
+                sqlx::query(&update_sql).bind(proofs_state.to_string()),
                 |query, y| query.bind(y.to_bytes().to_vec()),
             )
             .execute(&mut transaction)
@@ -1647,3 +1691,125 @@ fn sqlite_row_to_melt_request(row: SqliteRow) -> Result<(MeltBolt11Request<Uuid>
 
     Ok((melt_request, ln_key))
 }
+
+#[cfg(test)]
+mod tests {
+    use cdk_common::Amount;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_remove_spent_proofs() {
+        let db = memory::empty().await.unwrap();
+
+        // Create some test proofs
+        let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
+
+        let proofs = vec![
+            Proof {
+                amount: Amount::from(100),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+            Proof {
+                amount: Amount::from(200),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+        ];
+
+        // Add proofs to database
+        db.add_proofs(proofs.clone(), None).await.unwrap();
+
+        // Mark one proof as spent
+        db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
+            .await
+            .unwrap();
+
+        // Try to remove both proofs - should fail because one is spent
+        let result = db
+            .remove_proofs(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()], None)
+            .await;
+
+        assert!(result.is_err());
+        assert!(matches!(
+            result.unwrap_err(),
+            database::Error::AttemptRemoveSpentProof
+        ));
+
+        // Verify both proofs still exist
+        let states = db
+            .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
+            .await
+            .unwrap();
+
+        assert_eq!(states.len(), 2);
+        assert_eq!(states[0], Some(State::Spent));
+        assert_eq!(states[1], Some(State::Unspent));
+    }
+
+    #[tokio::test]
+    async fn test_update_spent_proofs() {
+        let db = memory::empty().await.unwrap();
+
+        // Create some test proofs
+        let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
+
+        let proofs = vec![
+            Proof {
+                amount: Amount::from(100),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+            Proof {
+                amount: Amount::from(200),
+                keyset_id: keyset_id.clone(),
+                secret: Secret::generate(),
+                c: SecretKey::generate().public_key(),
+                witness: None,
+                dleq: None,
+            },
+        ];
+
+        // Add proofs to database
+        db.add_proofs(proofs.clone(), None).await.unwrap();
+
+        // Mark one proof as spent
+        db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
+            .await
+            .unwrap();
+
+        // Try to update both proofs - should fail because one is spent
+        let result = db
+            .update_proofs_states(
+                &[proofs[0].y().unwrap(), proofs[1].y().unwrap()],
+                State::Reserved,
+            )
+            .await;
+
+        assert!(result.is_err());
+        assert!(matches!(
+            result.unwrap_err(),
+            database::Error::AttemptUpdateSpentProof
+        ));
+
+        // Verify states haven't changed
+        let states = db
+            .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
+            .await
+            .unwrap();
+
+        assert_eq!(states.len(), 2);
+        assert_eq!(states[0], Some(State::Spent));
+        assert_eq!(states[1], Some(State::Unspent));
+    }
+}

+ 10 - 5
crates/cdk/src/mint/check_spendable.rs

@@ -3,7 +3,7 @@ use std::collections::HashSet;
 use tracing::instrument;
 
 use super::{CheckStateRequest, CheckStateResponse, Mint, ProofState, PublicKey, State};
-use crate::Error;
+use crate::{cdk_database, Error};
 
 impl Mint {
     /// Check state
@@ -41,10 +41,15 @@ impl Mint {
         ys: &[PublicKey],
         proof_state: State,
     ) -> Result<(), Error> {
-        let original_proofs_state = self
-            .localstore
-            .update_proofs_states(ys, proof_state)
-            .await?;
+        let original_proofs_state =
+            match self.localstore.update_proofs_states(ys, proof_state).await {
+                Ok(states) => states,
+                Err(cdk_database::Error::AttemptUpdateSpentProof)
+                | Err(cdk_database::Error::AttemptRemoveSpentProof) => {
+                    return Err(Error::TokenAlreadySpent)
+                }
+                Err(err) => return Err(err.into()),
+            };
 
         let proofs_state = original_proofs_state
             .iter()