|  | @@ -1,6 +1,6 @@
 | 
	
		
			
				|  |  |  //! SQLite Mint
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -use std::collections::HashMap;
 | 
	
		
			
				|  |  | +use std::collections::{HashMap, HashSet};
 | 
	
		
			
				|  |  |  use std::path::Path;
 | 
	
		
			
				|  |  |  use std::str::FromStr;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -36,6 +36,37 @@ pub struct MintSqliteDatabase {
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  impl MintSqliteDatabase {
 | 
	
		
			
				|  |  | +    /// Check if any proofs are spent
 | 
	
		
			
				|  |  | +    async fn check_for_spent_proofs(
 | 
	
		
			
				|  |  | +        &self,
 | 
	
		
			
				|  |  | +        transaction: &mut sqlx::Transaction<'_, sqlx::Sqlite>,
 | 
	
		
			
				|  |  | +        ys: &[PublicKey],
 | 
	
		
			
				|  |  | +    ) -> Result<bool, database::Error> {
 | 
	
		
			
				|  |  | +        if ys.is_empty() {
 | 
	
		
			
				|  |  | +            return Ok(false);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        let check_sql = format!(
 | 
	
		
			
				|  |  | +            "SELECT state FROM proof WHERE y IN ({}) AND state = 'SPENT'",
 | 
	
		
			
				|  |  | +            std::iter::repeat("?")
 | 
	
		
			
				|  |  | +                .take(ys.len())
 | 
	
		
			
				|  |  | +                .collect::<Vec<_>>()
 | 
	
		
			
				|  |  | +                .join(",")
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        let spent_count = ys
 | 
	
		
			
				|  |  | +            .iter()
 | 
	
		
			
				|  |  | +            .fold(sqlx::query(&check_sql), |query, y| {
 | 
	
		
			
				|  |  | +                query.bind(y.to_bytes().to_vec())
 | 
	
		
			
				|  |  | +            })
 | 
	
		
			
				|  |  | +            .fetch_all(&mut *transaction)
 | 
	
		
			
				|  |  | +            .await
 | 
	
		
			
				|  |  | +            .map_err(Error::from)?
 | 
	
		
			
				|  |  | +            .len();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Ok(spent_count > 0)
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      /// Create new [`MintSqliteDatabase`]
 | 
	
		
			
				|  |  |      pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
 | 
	
		
			
				|  |  |          Ok(Self {
 | 
	
	
		
			
				|  | @@ -858,7 +889,13 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
 | 
	
		
			
				|  |  |      ) -> Result<(), Self::Err> {
 | 
	
		
			
				|  |  |          let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        let sql = format!(
 | 
	
		
			
				|  |  | +        if self.check_for_spent_proofs(&mut transaction, ys).await? {
 | 
	
		
			
				|  |  | +            transaction.rollback().await.map_err(Error::from)?;
 | 
	
		
			
				|  |  | +            return Err(Self::Err::AttemptRemoveSpentProof);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // If no proofs are spent, proceed with deletion
 | 
	
		
			
				|  |  | +        let delete_sql = format!(
 | 
	
		
			
				|  |  |              "DELETE FROM proof WHERE y IN ({})",
 | 
	
		
			
				|  |  |              std::iter::repeat("?")
 | 
	
		
			
				|  |  |                  .take(ys.len())
 | 
	
	
		
			
				|  | @@ -867,7 +904,7 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          ys.iter()
 | 
	
		
			
				|  |  | -            .fold(sqlx::query(&sql), |query, y| {
 | 
	
		
			
				|  |  | +            .fold(sqlx::query(&delete_sql), |query, y| {
 | 
	
		
			
				|  |  |                  query.bind(y.to_bytes().to_vec())
 | 
	
		
			
				|  |  |              })
 | 
	
		
			
				|  |  |              .execute(&mut transaction)
 | 
	
	
		
			
				|  | @@ -1064,16 +1101,23 @@ WHERE keyset_id=?;
 | 
	
		
			
				|  |  |              })
 | 
	
		
			
				|  |  |              .collect::<Result<HashMap<_, _>, _>>()?;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        let states = current_states.values().collect::<HashSet<_>>();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if states.contains(&State::Spent) {
 | 
	
		
			
				|  |  | +            transaction.rollback().await.map_err(Error::from)?;
 | 
	
		
			
				|  |  | +            tracing::warn!("Attempted to update state of spent proof");
 | 
	
		
			
				|  |  | +            return Err(database::Error::AttemptUpdateSpentProof);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // If no proofs are spent, proceed with update
 | 
	
		
			
				|  |  |          let update_sql = format!(
 | 
	
		
			
				|  |  | -            "UPDATE proof SET state = ? WHERE state != ? AND y IN ({})",
 | 
	
		
			
				|  |  | +            "UPDATE proof SET state = ? WHERE y IN ({})",
 | 
	
		
			
				|  |  |              "?,".repeat(ys.len()).trim_end_matches(',')
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          ys.iter()
 | 
	
		
			
				|  |  |              .fold(
 | 
	
		
			
				|  |  | -                sqlx::query(&update_sql)
 | 
	
		
			
				|  |  | -                    .bind(proofs_state.to_string())
 | 
	
		
			
				|  |  | -                    .bind(State::Spent.to_string()),
 | 
	
		
			
				|  |  | +                sqlx::query(&update_sql).bind(proofs_state.to_string()),
 | 
	
		
			
				|  |  |                  |query, y| query.bind(y.to_bytes().to_vec()),
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  |              .execute(&mut transaction)
 | 
	
	
		
			
				|  | @@ -1647,3 +1691,125 @@ fn sqlite_row_to_melt_request(row: SqliteRow) -> Result<(MeltBolt11Request<Uuid>
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      Ok((melt_request, ln_key))
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +#[cfg(test)]
 | 
	
		
			
				|  |  | +mod tests {
 | 
	
		
			
				|  |  | +    use cdk_common::Amount;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    use super::*;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    #[tokio::test]
 | 
	
		
			
				|  |  | +    async fn test_remove_spent_proofs() {
 | 
	
		
			
				|  |  | +        let db = memory::empty().await.unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Create some test proofs
 | 
	
		
			
				|  |  | +        let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        let proofs = vec![
 | 
	
		
			
				|  |  | +            Proof {
 | 
	
		
			
				|  |  | +                amount: Amount::from(100),
 | 
	
		
			
				|  |  | +                keyset_id: keyset_id.clone(),
 | 
	
		
			
				|  |  | +                secret: Secret::generate(),
 | 
	
		
			
				|  |  | +                c: SecretKey::generate().public_key(),
 | 
	
		
			
				|  |  | +                witness: None,
 | 
	
		
			
				|  |  | +                dleq: None,
 | 
	
		
			
				|  |  | +            },
 | 
	
		
			
				|  |  | +            Proof {
 | 
	
		
			
				|  |  | +                amount: Amount::from(200),
 | 
	
		
			
				|  |  | +                keyset_id: keyset_id.clone(),
 | 
	
		
			
				|  |  | +                secret: Secret::generate(),
 | 
	
		
			
				|  |  | +                c: SecretKey::generate().public_key(),
 | 
	
		
			
				|  |  | +                witness: None,
 | 
	
		
			
				|  |  | +                dleq: None,
 | 
	
		
			
				|  |  | +            },
 | 
	
		
			
				|  |  | +        ];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Add proofs to database
 | 
	
		
			
				|  |  | +        db.add_proofs(proofs.clone(), None).await.unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Mark one proof as spent
 | 
	
		
			
				|  |  | +        db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
 | 
	
		
			
				|  |  | +            .await
 | 
	
		
			
				|  |  | +            .unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Try to remove both proofs - should fail because one is spent
 | 
	
		
			
				|  |  | +        let result = db
 | 
	
		
			
				|  |  | +            .remove_proofs(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()], None)
 | 
	
		
			
				|  |  | +            .await;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        assert!(result.is_err());
 | 
	
		
			
				|  |  | +        assert!(matches!(
 | 
	
		
			
				|  |  | +            result.unwrap_err(),
 | 
	
		
			
				|  |  | +            database::Error::AttemptRemoveSpentProof
 | 
	
		
			
				|  |  | +        ));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Verify both proofs still exist
 | 
	
		
			
				|  |  | +        let states = db
 | 
	
		
			
				|  |  | +            .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
 | 
	
		
			
				|  |  | +            .await
 | 
	
		
			
				|  |  | +            .unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        assert_eq!(states.len(), 2);
 | 
	
		
			
				|  |  | +        assert_eq!(states[0], Some(State::Spent));
 | 
	
		
			
				|  |  | +        assert_eq!(states[1], Some(State::Unspent));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    #[tokio::test]
 | 
	
		
			
				|  |  | +    async fn test_update_spent_proofs() {
 | 
	
		
			
				|  |  | +        let db = memory::empty().await.unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Create some test proofs
 | 
	
		
			
				|  |  | +        let keyset_id = Id::from_str("00916bbf7ef91a36").unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        let proofs = vec![
 | 
	
		
			
				|  |  | +            Proof {
 | 
	
		
			
				|  |  | +                amount: Amount::from(100),
 | 
	
		
			
				|  |  | +                keyset_id: keyset_id.clone(),
 | 
	
		
			
				|  |  | +                secret: Secret::generate(),
 | 
	
		
			
				|  |  | +                c: SecretKey::generate().public_key(),
 | 
	
		
			
				|  |  | +                witness: None,
 | 
	
		
			
				|  |  | +                dleq: None,
 | 
	
		
			
				|  |  | +            },
 | 
	
		
			
				|  |  | +            Proof {
 | 
	
		
			
				|  |  | +                amount: Amount::from(200),
 | 
	
		
			
				|  |  | +                keyset_id: keyset_id.clone(),
 | 
	
		
			
				|  |  | +                secret: Secret::generate(),
 | 
	
		
			
				|  |  | +                c: SecretKey::generate().public_key(),
 | 
	
		
			
				|  |  | +                witness: None,
 | 
	
		
			
				|  |  | +                dleq: None,
 | 
	
		
			
				|  |  | +            },
 | 
	
		
			
				|  |  | +        ];
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Add proofs to database
 | 
	
		
			
				|  |  | +        db.add_proofs(proofs.clone(), None).await.unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Mark one proof as spent
 | 
	
		
			
				|  |  | +        db.update_proofs_states(&[proofs[0].y().unwrap()], State::Spent)
 | 
	
		
			
				|  |  | +            .await
 | 
	
		
			
				|  |  | +            .unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Try to update both proofs - should fail because one is spent
 | 
	
		
			
				|  |  | +        let result = db
 | 
	
		
			
				|  |  | +            .update_proofs_states(
 | 
	
		
			
				|  |  | +                &[proofs[0].y().unwrap(), proofs[1].y().unwrap()],
 | 
	
		
			
				|  |  | +                State::Reserved,
 | 
	
		
			
				|  |  | +            )
 | 
	
		
			
				|  |  | +            .await;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        assert!(result.is_err());
 | 
	
		
			
				|  |  | +        assert!(matches!(
 | 
	
		
			
				|  |  | +            result.unwrap_err(),
 | 
	
		
			
				|  |  | +            database::Error::AttemptUpdateSpentProof
 | 
	
		
			
				|  |  | +        ));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Verify states haven't changed
 | 
	
		
			
				|  |  | +        let states = db
 | 
	
		
			
				|  |  | +            .get_proofs_states(&[proofs[0].y().unwrap(), proofs[1].y().unwrap()])
 | 
	
		
			
				|  |  | +            .await
 | 
	
		
			
				|  |  | +            .unwrap();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        assert_eq!(states.len(), 2);
 | 
	
		
			
				|  |  | +        assert_eq!(states[0], Some(State::Spent));
 | 
	
		
			
				|  |  | +        assert_eq!(states[1], Some(State::Unspent));
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +}
 |