Quellcode durchsuchen

Introduce ProofsWithState for atomic proof state management

Replace individual proof state operations with a unified ProofsWithState type
that enforces the invariant that all proofs in a set share the same state. This
shifts responsibility for state consistency to the database layer and
simplifies state transition logic in the saga implementations.

Changes to ProofsTransaction trait:
- add_proofs now returns Acquired<ProofsWithState>
- update_proofs_states(ys, state) -> update_proofs(&mut
  Acquired<ProofsWithState>)
- get_proofs_states(ys) -> get_proofs(ys) returning Acquired<ProofsWithState>

Benefits:
- State transitions validated in memory before persisting
- Eliminates scattered state checking in saga code
- Database layer guarantees proof state consistency
- Cleaner API with state encapsulated in the type
Cesar Rodas vor 1 Monat
Ursprung
Commit
34ca52473d

+ 10 - 8
crates/cdk-common/src/database/mint/mod.rs

@@ -8,7 +8,9 @@ use cashu::Amount;
 
 use super::{DbTransactionFinalizer, Error};
 use crate::database::Acquired;
-use crate::mint::{self, MeltQuote, MintKeySetInfo, MintQuote as MintMintQuote, Operation};
+use crate::mint::{
+    self, MeltQuote, MintKeySetInfo, MintQuote as MintMintQuote, Operation, ProofsWithState,
+};
 use crate::nuts::{
     BlindSignature, BlindedMessage, CurrencyUnit, Id, MeltQuoteState, Proof, Proofs, PublicKey,
     State,
@@ -297,19 +299,19 @@ pub trait ProofsTransaction {
         proof: Proofs,
         quote_id: Option<QuoteId>,
         operation: &Operation,
-    ) -> Result<(), Self::Err>;
+    ) -> Result<Acquired<ProofsWithState>, Self::Err>;
+
     /// Updates the proofs to a given states and return the previous states
-    async fn update_proofs_states(
+    async fn update_proofs(
         &mut self,
-        ys: &[PublicKey],
-        proofs_state: State,
-    ) -> Result<Vec<Option<State>>, Self::Err>;
+        proofs: &mut Acquired<ProofsWithState>,
+    ) -> Result<(), Self::Err>;
 
     /// get proofs states
-    async fn get_proofs_states(
+    async fn get_proofs(
         &mut self,
         ys: &[PublicKey],
-    ) -> Result<Vec<Option<State>>, Self::Err>;
+    ) -> Result<Acquired<ProofsWithState>, Self::Err>;
 
     /// Remove [`Proofs`]
     async fn remove_proofs(

+ 1 - 0
crates/cdk-common/src/database/mint/test/mod.rs

@@ -255,6 +255,7 @@ macro_rules! mint_db_test {
             get_blind_signatures_in_transaction,
             reject_duplicate_payment_ids,
             remove_spent_proofs_should_fail,
+            get_proofs_with_inconsistent_states_fails,
         );
     };
     ($make_db_fn:ident, $($name:ident),+ $(,)?) => {

+ 116 - 20
crates/cdk-common/src/database/mint/test/proofs.rs

@@ -217,7 +217,9 @@ where
 
     // Update to pending
     let mut tx = Database::begin_transaction(&db).await.unwrap();
-    let _old_states = tx.update_proofs_states(&ys, State::Pending).await.unwrap();
+    let mut proofs = tx.get_proofs(&ys).await.unwrap();
+    proofs.set_new_state(State::Pending).unwrap();
+    tx.update_proofs(&mut proofs).await.unwrap();
     tx.commit().await.unwrap();
 
     // Verify new state
@@ -227,8 +229,9 @@ where
 
     // Update to spent
     let mut tx = Database::begin_transaction(&db).await.unwrap();
-    let old_states = tx.update_proofs_states(&ys, State::Spent).await.unwrap();
-    assert_eq!(old_states, vec![Some(State::Pending), Some(State::Pending)]);
+    let mut proofs = tx.get_proofs(&ys).await.unwrap();
+    proofs.set_new_state(State::Spent).unwrap();
+    tx.update_proofs(&mut proofs).await.unwrap();
     tx.commit().await.unwrap();
 
     // Verify final state
@@ -346,16 +349,16 @@ where
 
     // First update to Pending (valid state transition)
     let mut tx = Database::begin_transaction(&db).await.unwrap();
-    tx.update_proofs_states(&[ys[0], ys[1]], State::Pending)
-        .await
-        .unwrap();
+    let mut proofs = tx.get_proofs(&[ys[0], ys[1]]).await.unwrap();
+    proofs.set_new_state(State::Pending).unwrap();
+    tx.update_proofs(&mut proofs).await.unwrap();
     tx.commit().await.unwrap();
 
     // Then mark some as spent
     let mut tx = Database::begin_transaction(&db).await.unwrap();
-    tx.update_proofs_states(&[ys[0], ys[1]], State::Spent)
-        .await
-        .unwrap();
+    let mut proofs = tx.get_proofs(&[ys[0], ys[1]]).await.unwrap();
+    proofs.set_new_state(State::Spent).unwrap();
+    tx.update_proofs(&mut proofs).await.unwrap();
     tx.commit().await.unwrap();
 
     // Get total redeemed
@@ -705,26 +708,26 @@ where
 
     // Transition proofs to Pending state
     let mut tx = Database::begin_transaction(&db).await.unwrap();
-    let _records = tx
-        .get_proof_ys_by_quote_id(&quote_id)
-        .await
-        .expect("valid records");
-    tx.update_proofs_states(&ys, State::Pending).await.unwrap();
+    let mut records = tx.get_proofs(&ys).await.expect("valid records");
+    records.set_new_state(State::Pending).unwrap();
+    tx.update_proofs(&mut records).await.unwrap();
     tx.commit().await.unwrap();
 
     // Removing Pending proofs should also succeed
     let mut tx = Database::begin_transaction(&db).await.unwrap();
     let result = tx.remove_proofs(&[ys[0]], Some(quote_id.clone())).await;
-    assert!(result.is_ok(), "Removing Pending proof should succeed");
+    assert!(
+        result.is_ok(),
+        "Removing Pending proof should succeed: {:?}",
+        result,
+    );
     tx.rollback().await.unwrap(); // Rollback to keep proofs for next test
 
     // Now transition proofs to Spent state
     let mut tx = Database::begin_transaction(&db).await.unwrap();
-    let _records = tx
-        .get_proof_ys_by_quote_id(&quote_id)
-        .await
-        .expect("valid records");
-    tx.update_proofs_states(&ys, State::Spent).await.unwrap();
+    let mut records = tx.get_proofs(&ys).await.expect("valid records");
+    records.set_new_state(State::Spent).unwrap();
+    tx.update_proofs(&mut records).await.unwrap();
     tx.commit().await.unwrap();
 
     // Verify proofs are now in Spent state
@@ -762,3 +765,96 @@ where
         "Second proof should still exist"
     );
 }
+
+/// Test that get_proofs fails when proofs have inconsistent states
+///
+/// This validates the database layer's responsibility to ensure all proofs
+/// returned by get_proofs share the same state. The mint never needs proofs
+/// with different states, so this is an invariant the database must enforce.
+pub async fn get_proofs_with_inconsistent_states_fails<DB>(db: DB)
+where
+    DB: Database<Error> + KeysDatabase<Err = Error>,
+{
+    use cashu::State;
+
+    let keyset_id = setup_keyset(&db).await;
+    let quote_id = QuoteId::new_uuid();
+
+    // Create three proofs
+    let proofs = vec![
+        Proof {
+            amount: Amount::from(100),
+            keyset_id,
+            secret: Secret::generate(),
+            c: SecretKey::generate().public_key(),
+            witness: None,
+            dleq: None,
+        },
+        Proof {
+            amount: Amount::from(200),
+            keyset_id,
+            secret: Secret::generate(),
+            c: SecretKey::generate().public_key(),
+            witness: None,
+            dleq: None,
+        },
+        Proof {
+            amount: Amount::from(300),
+            keyset_id,
+            secret: Secret::generate(),
+            c: SecretKey::generate().public_key(),
+            witness: None,
+            dleq: None,
+        },
+    ];
+
+    let ys: Vec<_> = proofs.iter().map(|p| p.y().unwrap()).collect();
+
+    // Add all proofs (initial state is Unspent)
+    let mut tx = Database::begin_transaction(&db).await.unwrap();
+    tx.add_proofs(
+        proofs,
+        Some(quote_id),
+        &Operation::new_swap(Amount::ZERO, Amount::ZERO, Amount::ZERO),
+    )
+    .await
+    .unwrap();
+    tx.commit().await.unwrap();
+
+    // Transition only the first two proofs to Pending state
+    let mut tx = Database::begin_transaction(&db).await.unwrap();
+    let mut first_two_proofs = tx.get_proofs(&ys[0..2]).await.unwrap();
+    first_two_proofs.set_new_state(State::Pending).unwrap();
+    tx.update_proofs(&mut first_two_proofs).await.unwrap();
+    tx.commit().await.unwrap();
+
+    // Verify the states are now inconsistent via get_proofs_states
+    let states = db.get_proofs_states(&ys).await.unwrap();
+    assert_eq!(
+        states[0],
+        Some(State::Pending),
+        "First proof should be Pending"
+    );
+    assert_eq!(
+        states[1],
+        Some(State::Pending),
+        "Second proof should be Pending"
+    );
+    assert_eq!(
+        states[2],
+        Some(State::Unspent),
+        "Third proof should be Unspent"
+    );
+
+    // Now try to get all three proofs via get_proofs - this should fail
+    // because the proofs have inconsistent states
+    let mut tx = Database::begin_transaction(&db).await.unwrap();
+    let result = tx.get_proofs(&ys).await;
+
+    assert!(
+        result.is_err(),
+        "get_proofs should fail when proofs have inconsistent states"
+    );
+
+    tx.rollback().await.unwrap();
+}

+ 84 - 1
crates/cdk-common/src/mint.rs

@@ -1,6 +1,7 @@
 //! Mint types
 
 use std::fmt;
+use std::ops::Deref;
 use std::str::FromStr;
 
 use bitcoin::bip32::DerivationPath;
@@ -8,7 +9,7 @@ use cashu::quote_id::QuoteId;
 use cashu::util::unix_time;
 use cashu::{
     Bolt11Invoice, MeltOptions, MeltQuoteBolt11Response, MintQuoteBolt11Response,
-    MintQuoteBolt12Response, PaymentMethod,
+    MintQuoteBolt12Response, PaymentMethod, Proofs, State,
 };
 use lightning::offers::offer::Offer;
 use serde::{Deserialize, Serialize};
@@ -17,6 +18,7 @@ use uuid::Uuid;
 
 use crate::nuts::{MeltQuoteState, MintQuoteState};
 use crate::payment::PaymentIdentifier;
+use crate::state::check_state_transition;
 use crate::{Amount, CurrencyUnit, Error, Id, KeySetInfo, PublicKey};
 
 /// Operation kind for saga persistence
@@ -31,6 +33,87 @@ pub enum OperationKind {
     Melt,
 }
 
+/// A collection of proofs that share a common state.
+///
+/// This type enforces the invariant that all proofs in the collection have the same state.
+/// The mint never needs to operate on a set of proofs with different states - proofs are
+/// always processed together as a unit (e.g., during swap, melt, or mint operations).
+///
+/// # Database Layer Responsibility
+///
+/// This design shifts the responsibility of ensuring state consistency to the database layer.
+/// When the database retrieves proofs via [`get_proofs`](crate::database::mint::ProofsTransaction::get_proofs),
+/// it must verify that all requested proofs share the same state and return an error if they don't.
+/// This prevents invalid proof sets from propagating through the system.
+///
+/// # State Transitions
+///
+/// State changes are performed atomically on the entire collection via [`set_new_state`](Self::set_new_state),
+/// which validates the transition before applying it. The database layer then persists
+/// the new state for all proofs in a single transaction.
+///
+/// # Example
+///
+/// ```ignore
+/// // Database layer ensures all proofs have the same state
+/// let mut proofs = tx.get_proofs(&ys).await?;
+///
+/// // Transition all proofs to a new state
+/// let old_state = proofs.set_new_state(State::Spent)?;
+///
+/// // Persist the state change
+/// tx.update_proofs(&mut proofs).await?;
+/// ```
+#[derive(Debug)]
+pub struct ProofsWithState {
+    proofs: Proofs,
+    state: State,
+}
+
+impl Deref for ProofsWithState {
+    type Target = Proofs;
+
+    fn deref(&self) -> &Self::Target {
+        &self.proofs
+    }
+}
+
+impl ProofsWithState {
+    /// Creates a new `ProofsWithState` with the given proofs and their shared state.
+    ///
+    /// # Note
+    ///
+    /// This constructor assumes all proofs share the given state. It is typically
+    /// called by the database layer after verifying state consistency.
+    pub fn new(proofs: Proofs, current_state: State) -> Self {
+        Self {
+            proofs,
+            state: current_state,
+        }
+    }
+
+    /// Returns the current state shared by all proofs in the collection.
+    pub fn get_state(&self) -> State {
+        self.state
+    }
+
+    /// Transitions all proofs to a new state.
+    ///
+    /// Validates that the state transition is allowed before applying it.
+    /// Returns the previous state on success.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`Error::UnexpectedProofState`] if the transition from the current
+    /// state to the new state is not permitted.
+    pub fn set_new_state(&mut self, new_state: State) -> Result<State, Error> {
+        check_state_transition(self.state, new_state).map_err(|_| Error::UnexpectedProofState)?;
+        let old_state = self.state;
+        self.state = new_state;
+        Ok(old_state)
+    }
+}
+
 impl fmt::Display for OperationKind {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {

+ 111 - 45
crates/cdk-sql-common/src/mint/proofs.rs

@@ -4,11 +4,12 @@ use std::collections::HashMap;
 use std::str::FromStr;
 
 use async_trait::async_trait;
-use cdk_common::database::{self, Error, MintProofsDatabase};
-use cdk_common::mint::Operation;
+use cdk_common::database::{self, Acquired, Error, MintProofsDatabase};
+use cdk_common::mint::{Operation, ProofsWithState};
 use cdk_common::nut00::ProofsMethods;
 use cdk_common::quote_id::QuoteId;
 use cdk_common::secret::Secret;
+use cdk_common::util::unix_time;
 use cdk_common::{Amount, Id, Proof, Proofs, PublicKey, State};
 
 use super::{SQLMintDatabase, SQLTransaction};
@@ -69,9 +70,7 @@ pub(super) fn sql_row_to_proof(row: Vec<Column>) -> Result<Proof, Error> {
     })
 }
 
-pub(super) fn sql_row_to_proof_with_state(
-    row: Vec<Column>,
-) -> Result<(Proof, Option<State>), Error> {
+pub(super) fn sql_row_to_proof_with_state(row: Vec<Column>) -> Result<(Proof, State), Error> {
     unpack_into!(
         let (
             keyset_id, amount, secret, c, witness, state
@@ -79,7 +78,9 @@ pub(super) fn sql_row_to_proof_with_state(
     );
 
     let amount: u64 = column_as_number!(amount);
-    let state = column_as_nullable_string!(state).and_then(|s| State::from_str(&s).ok());
+    let state = column_as_nullable_string!(state)
+        .and_then(|s| State::from_str(&s).ok())
+        .unwrap_or(State::Pending);
 
     Ok((
         Proof {
@@ -116,13 +117,22 @@ where
 {
     type Err = Error;
 
+    /// Adds proofs to the database with initial state `Unspent`.
+    ///
+    /// This method first checks if any of the proofs already exist in the database.
+    /// If a proof exists and is spent, returns [`Error::AttemptUpdateSpentProof`].
+    /// If a proof exists in any other state, returns [`Error::Duplicate`].
+    ///
+    /// On success, returns the proofs wrapped in [`Acquired<ProofsWithState>`] with
+    /// state set to `Unspent`, indicating the rows are locked for the duration of
+    /// the transaction.
     async fn add_proofs(
         &mut self,
         proofs: Proofs,
         quote_id: Option<QuoteId>,
         operation: &Operation,
-    ) -> Result<(), Self::Err> {
-        let current_time = cdk_common::util::unix_time();
+    ) -> Result<Acquired<ProofsWithState>, Self::Err> {
+        let current_time = unix_time();
 
         // Check any previous proof, this query should return None in order to proceed storing
         // Any result here would error
@@ -144,7 +154,7 @@ where
             None => Ok(()), // no previous record
         }?;
 
-        for proof in proofs {
+        for proof in &proofs {
             let y = proof.y()?;
 
             query(
@@ -162,7 +172,7 @@ where
             .bind("c", proof.c.to_bytes().to_vec())
             .bind(
                 "witness",
-                proof.witness.and_then(|w| serde_json::to_string(&w).inspect_err(|e| tracing::error!("Failed to serialize witness: {:?}", e)).ok()),
+                proof.witness.clone().and_then(|w| serde_json::to_string(&w).inspect_err(|e| tracing::error!("Failed to serialize witness: {:?}", e)).ok()),
             )
             .bind("state", "UNSPENT".to_string())
             .bind("quote_id", quote_id.clone().map(|q| q.to_string()))
@@ -173,24 +183,28 @@ where
             .await?;
         }
 
-        Ok(())
+        Ok(ProofsWithState::new(proofs, State::Unspent).into())
     }
 
-    async fn update_proofs_states(
+    /// Persists the current state of the proofs to the database.
+    ///
+    /// Reads the state from the [`ProofsWithState`] wrapper (previously set via
+    /// [`ProofsWithState::set_new_state`]) and updates all proofs in the database
+    /// to that state.
+    ///
+    /// When the new state is `Spent`, this method also updates the `keyset_amounts`
+    /// table to track the total redeemed amount per keyset for analytics purposes.
+    ///
+    /// # Prerequisites
+    ///
+    /// The proofs must have been previously acquired via `add_proofs`
+    /// or `get_proofs` to ensure they are locked within the current transaction.
+    async fn update_proofs(
         &mut self,
-        ys: &[PublicKey],
-        new_state: State,
-    ) -> Result<Vec<Option<State>>, Self::Err> {
-        let mut current_states = get_current_states(&self.inner, ys, true).await?;
-
-        if current_states.len() != ys.len() {
-            tracing::warn!(
-                "Attempted to update state of non-existent proof {} {}",
-                current_states.len(),
-                ys.len()
-            );
-            return Err(database::Error::ProofNotFound);
-        }
+        proofs: &mut Acquired<ProofsWithState>,
+    ) -> Result<(), Self::Err> {
+        let ys = proofs.ys()?;
+        let new_state = proofs.get_state();
 
         query(r#"UPDATE proof SET state = :new_state WHERE y IN (:ys)"#)?
             .bind("new_state", new_state.to_string())
@@ -200,22 +214,22 @@ where
 
         if new_state == State::Spent {
             query(
-                r#"
-                INSERT INTO keyset_amounts (keyset_id, total_issued, total_redeemed)
-                SELECT keyset_id, 0, COALESCE(SUM(amount), 0)
-                FROM proof
-                WHERE y IN (:ys)
-                GROUP BY keyset_id
-                ON CONFLICT (keyset_id)
-                DO UPDATE SET total_redeemed = keyset_amounts.total_redeemed + EXCLUDED.total_redeemed
-                "#,
-            )?
-            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
-            .execute(&self.inner)
-            .await?;
+                    r#"
+                    INSERT INTO keyset_amounts (keyset_id, total_issued, total_redeemed)
+                    SELECT keyset_id, 0, COALESCE(SUM(amount), 0)
+                    FROM proof
+                    WHERE y IN (:ys)
+                    GROUP BY keyset_id
+                    ON CONFLICT (keyset_id)
+                    DO UPDATE SET total_redeemed = keyset_amounts.total_redeemed + EXCLUDED.total_redeemed
+                    "#,
+                )?
+                .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+                .execute(&self.inner)
+                .await?;
         }
 
-        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
+        Ok(())
     }
 
     async fn remove_proofs(
@@ -332,13 +346,62 @@ where
         .collect::<Result<Vec<_>, _>>()?)
     }
 
-    async fn get_proofs_states(
+    async fn get_proofs(
         &mut self,
         ys: &[PublicKey],
-    ) -> Result<Vec<Option<State>>, Self::Err> {
-        let mut current_states = get_current_states(&self.inner, ys, true).await?;
+    ) -> Result<Acquired<ProofsWithState>, Self::Err> {
+        if ys.is_empty() {
+            return Ok(ProofsWithState::new(vec![], State::Unspent).into());
+        }
 
-        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
+        let rows = query(
+            r#"
+             SELECT
+                 keyset_id,
+                 amount,
+                 secret,
+                 c,
+                 witness,
+                 state
+             FROM
+                 proof
+             WHERE
+                 y IN (:ys)
+             FOR UPDATE
+             "#,
+        )?
+        .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        .fetch_all(&self.inner)
+        .await?;
+
+        if rows.is_empty() {
+            return Err(database::Error::ProofNotFound);
+        }
+
+        let results: Vec<(Proof, State)> = rows
+            .into_iter()
+            .map(sql_row_to_proof_with_state)
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let mut proofs = Vec::with_capacity(results.len());
+        let mut first_state: Option<State> = None;
+
+        for (proof, state) in results {
+            if let Some(first) = first_state {
+                if first != state {
+                    return Err(database::Error::Internal(
+                        "Proofs have inconsistent states".to_string(),
+                    ));
+                }
+            } else {
+                first_state = Some(state);
+            }
+
+            proofs.push(proof);
+        }
+
+        let state = first_state.unwrap_or(State::Unspent);
+        Ok(ProofsWithState::new(proofs, state).into())
     }
 }
 
@@ -425,7 +488,8 @@ where
         keyset_id: &Id,
     ) -> Result<(Proofs, Vec<Option<State>>), Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
+
+        let (proofs, states): (Vec<Proof>, Vec<State>) = query(
             r#"
             SELECT
                keyset_id,
@@ -447,7 +511,9 @@ where
         .map(sql_row_to_proof_with_state)
         .collect::<Result<Vec<_>, _>>()?
         .into_iter()
-        .unzip())
+        .unzip();
+
+        Ok((proofs, states.into_iter().map(Some).collect()))
     }
 
     /// Get total proofs redeemed by keyset id

+ 14 - 30
crates/cdk/src/mint/melt/melt_saga/mod.rs

@@ -7,7 +7,6 @@ use cdk_common::database::DynMintDatabase;
 use cdk_common::mint::{MeltSagaState, Operation, Saga, SagaStateEnum};
 use cdk_common::nut00::KnownMethod;
 use cdk_common::nuts::MeltQuoteState;
-use cdk_common::state::check_state_transition;
 use cdk_common::{Amount, Error, ProofsMethods, PublicKey, QuoteId, State};
 #[cfg(feature = "prometheus")]
 use cdk_prometheus::METRICS;
@@ -246,19 +245,23 @@ impl MeltSaga<Initial> {
 
         let input_ys = melt_request.inputs().ys()?;
 
-        for current_state in tx
-            .get_proofs_states(&input_ys)
-            .await?
-            .into_iter()
-            .collect::<Option<Vec<_>>>()
-            .ok_or(Error::UnexpectedProofState)?
-        {
-            check_state_transition(current_state, State::Pending)
-                .map_err(|_| Error::UnexpectedProofState)?;
+        let mut proofs = tx.get_proofs(&input_ys).await?;
+
+        let original_state = proofs.get_state();
+
+        if matches!(original_state, State::Pending | State::Spent) {
+            tx.rollback().await?;
+            return Err(if original_state == State::Pending {
+                Error::TokenPending
+            } else {
+                Error::TokenAlreadySpent
+            });
         }
 
+        proofs.set_new_state(State::Pending)?;
+
         // Update proof states to Pending
-        let original_states = match tx.update_proofs_states(&input_ys, State::Pending).await {
+        match tx.update_proofs(&mut proofs).await {
             Ok(states) => states,
             Err(cdk_common::database::Error::AttemptUpdateSpentProof)
             | Err(cdk_common::database::Error::AttemptRemoveSpentProof) => {
@@ -271,25 +274,6 @@ impl MeltSaga<Initial> {
             }
         };
 
-        // Check for forbidden states (Pending or Spent)
-        let has_forbidden_state = original_states
-            .iter()
-            .any(|state| matches!(state, Some(State::Pending) | Some(State::Spent)));
-
-        if has_forbidden_state {
-            tx.rollback().await?;
-            return Err(
-                if original_states
-                    .iter()
-                    .any(|s| matches!(s, Some(State::Pending)))
-                {
-                    Error::TokenPending
-                } else {
-                    Error::TokenAlreadySpent
-                },
-            );
-        }
-
         let previous_state = quote.state;
 
         // Publish proof state changes

+ 3 - 12
crates/cdk/src/mint/melt/shared.rs

@@ -8,7 +8,6 @@
 
 use cdk_common::database::{self, Acquired, DynMintDatabase};
 use cdk_common::nuts::{BlindSignature, BlindedMessage, MeltQuoteState, State};
-use cdk_common::state::check_state_transition;
 use cdk_common::{Amount, Error, PublicKey, QuoteId};
 use cdk_signatory::signatory::SignatoryKeySet;
 
@@ -393,19 +392,11 @@ pub async fn finalize_melt_core(
             .await?;
     }
 
-    for current_state in tx
-        .get_proofs_states(input_ys)
-        .await?
-        .into_iter()
-        .collect::<Option<Vec<_>>>()
-        .ok_or(Error::UnexpectedProofState)?
-    {
-        check_state_transition(current_state, State::Spent)
-            .map_err(|_| Error::UnexpectedProofState)?;
-    }
+    let mut proofs = tx.get_proofs(input_ys).await?;
+    proofs.set_new_state(State::Spent)?;
 
     // Mark input proofs as spent
-    match tx.update_proofs_states(input_ys, State::Spent).await {
+    match tx.update_proofs(&mut proofs).await {
         Ok(_) => {}
         Err(database::Error::AttemptUpdateSpentProof) => {
             tracing::info!("Proofs for quote {} already marked as spent", quote.id);

+ 26 - 37
crates/cdk/src/mint/swap/swap_saga/mod.rs

@@ -4,7 +4,6 @@ use std::sync::Arc;
 use cdk_common::database::DynMintDatabase;
 use cdk_common::mint::{Operation, Saga, SwapSagaState};
 use cdk_common::nuts::BlindedMessage;
-use cdk_common::state::check_state_transition;
 use cdk_common::{database, Amount, Error, Proofs, ProofsMethods, PublicKey, QuoteId, State};
 use tokio::sync::Mutex;
 use tracing::instrument;
@@ -177,17 +176,30 @@ impl<'a> SwapSaga<'a, Initial> {
         );
 
         // Add input proofs to DB
-        if let Err(err) = tx
+        let mut new_proofs = match tx
             .add_proofs(input_proofs.clone(), quote_id.clone(), &operation)
             .await
         {
-            tx.rollback().await?;
-            return Err(match err {
-                database::Error::Duplicate => Error::TokenPending,
-                database::Error::AttemptUpdateSpentProof => Error::TokenAlreadySpent,
-                _ => Error::Database(err),
-            });
-        }
+            Ok(proofs) => proofs,
+            Err(err) => {
+                tx.rollback().await?;
+                return Err(match err {
+                    database::Error::Duplicate => Error::TokenPending,
+                    database::Error::AttemptUpdateSpentProof => Error::TokenAlreadySpent,
+                    _ => Error::Database(err),
+                });
+            }
+        };
+
+        let original_state = new_proofs.get_state();
+
+        new_proofs.set_new_state(State::Pending).map_err(|_| {
+            if original_state == State::Pending {
+                Error::TokenPending
+            } else {
+                Error::TokenAlreadySpent
+            }
+        })?;
 
         let ys = match input_proofs.ys() {
             Ok(ys) => ys,
@@ -195,7 +207,7 @@ impl<'a> SwapSaga<'a, Initial> {
         };
 
         // Update input proof states to Pending
-        let original_proof_states = match tx.update_proofs_states(&ys, State::Pending).await {
+        match tx.update_proofs(&mut new_proofs).await {
             Ok(states) => states,
             Err(database::Error::AttemptUpdateSpentProof)
             | Err(database::Error::AttemptRemoveSpentProof) => {
@@ -209,24 +221,12 @@ impl<'a> SwapSaga<'a, Initial> {
         };
 
         // Verify proofs weren't already pending or spent
-        if ys.len() != original_proof_states.len() {
+        if ys.len() != new_proofs.len() {
             tracing::error!("Mismatched proof states");
             tx.rollback().await?;
             return Err(Error::Internal);
         }
 
-        let forbidden_states = [State::Pending, State::Spent];
-        for original_state in original_proof_states.iter().flatten() {
-            if forbidden_states.contains(original_state) {
-                tx.rollback().await?;
-                return Err(if *original_state == State::Pending {
-                    Error::TokenPending
-                } else {
-                    Error::TokenAlreadySpent
-                });
-            }
-        }
-
         // Add output blinded messages
         if let Err(err) = tx
             .add_blinded_messages(quote_id.as_ref(), blinded_messages, &operation)
@@ -423,21 +423,10 @@ impl SwapSaga<'_, Signed> {
             }
         }
 
-        for current_state in tx
-            .get_proofs_states(&self.state_data.ys)
-            .await?
-            .into_iter()
-            .collect::<Option<Vec<_>>>()
-            .ok_or(Error::UnexpectedProofState)?
-        {
-            check_state_transition(current_state, State::Spent)
-                .map_err(|_| Error::UnexpectedProofState)?;
-        }
+        let mut proofs = tx.get_proofs(&self.state_data.ys).await?;
+        proofs.set_new_state(State::Spent)?;
 
-        match tx
-            .update_proofs_states(&self.state_data.ys, State::Spent)
-            .await
-        {
+        match tx.update_proofs(&mut proofs).await {
             Ok(_) => {}
             Err(database::Error::AttemptUpdateSpentProof)
             | Err(database::Error::AttemptRemoveSpentProof) => {