Selaa lähdekoodia

Introduce `ProofWriter`

Commit database transaction before melting and minting and then let ProofWriter
to rollback states automatically
Cesar Rodas 5 kuukautta sitten
vanhempi
säilyke
b6b6505835

+ 1 - 1
crates/cdk-cli/src/sub_commands/check_pending.rs

@@ -26,7 +26,7 @@ pub async fn check_pending(multi_mint_wallet: &MultiMintWallet) -> Result<()> {
         // Try to reclaim any proofs that are no longer pending
         match wallet.reclaim_unspent(pending_proofs).await {
             Ok(()) => println!("Successfully reclaimed pending proofs"),
-            Err(e) => println!("Error reclaimed pending proofs: {}", e),
+            Err(e) => println!("Error reclaimed pending proofs: {e}"),
         }
     }
     Ok(())

+ 7 - 0
crates/cdk-common/src/database/mint/mod.rs

@@ -149,6 +149,13 @@ pub trait ProofsTransaction<'a> {
         ys: &[PublicKey],
         proofs_state: State,
     ) -> Result<Vec<Option<State>>, Self::Err>;
+
+    /// Remove [`Proofs`]
+    async fn remove_proofs(
+        &mut self,
+        ys: &[PublicKey],
+        quote_id: Option<Uuid>,
+    ) -> Result<(), Self::Err>;
 }
 
 /// Mint Proof Database trait

+ 9 - 9
crates/cdk-sqlite/src/mint/auth/mod.rs

@@ -70,7 +70,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
             "#,
         )
         .bind(":id", id.to_string())
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await?;
 
         Ok(())
@@ -106,7 +106,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
         .bind(":derivation_path", keyset.derivation_path.to_string())
         .bind(":max_order", keyset.max_order)
         .bind(":derivation_path_index", keyset.derivation_path_index)
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await?;
 
         Ok(())
@@ -126,7 +126,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
         .bind(":secret", proof.secret.to_string())
         .bind(":c", proof.c.to_bytes().to_vec())
         .bind(":state", "UNSPENT".to_string())
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await
         {
             tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
@@ -141,7 +141,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
     ) -> Result<Option<State>, Self::Err> {
         let current_state = query(r#"SELECT state FROM proof WHERE y = :y"#)
             .bind(":y", y.to_bytes().to_vec())
-            .pluck(&self.transaction)
+            .pluck(&self.inner)
             .await?
             .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
             .transpose()?;
@@ -153,7 +153,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
                 current_state.as_ref().map(|state| state.to_string()),
             )
             .bind(":new_state", proofs_state.to_string())
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
 
         Ok(current_state)
@@ -178,7 +178,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
             .bind(":amount", u64::from(signature.amount) as i64)
             .bind(":keyset_id", signature.keyset_id.to_string())
             .bind(":c", signature.c.to_bytes().to_vec())
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
         }
 
@@ -199,7 +199,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
             )
             .bind(":endpoint", serde_json::to_string(endpoint)?)
             .bind(":auth", serde_json::to_string(auth)?)
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await
             {
                 tracing::debug!(
@@ -223,7 +223,7 @@ impl MintAuthTransaction<database::Error> for SqliteTransaction<'_> {
                     .map(serde_json::to_string)
                     .collect::<Result<_, _>>()?,
             )
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
         Ok(())
     }
@@ -238,7 +238,7 @@ impl MintAuthDatabase for MintSqliteAuthDatabase {
     ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
     {
         Ok(Box::new(SqliteTransaction {
-            transaction: self.pool.begin().await?,
+            inner: self.pool.begin().await?,
         }))
     }
 

+ 55 - 38
crates/cdk-sqlite/src/mint/mod.rs

@@ -138,17 +138,17 @@ impl MintSqliteDatabase {
 
 /// Sqlite Writer
 pub struct SqliteTransaction<'a> {
-    transaction: Transaction<'a>,
+    inner: Transaction<'a>,
 }
 
 #[async_trait]
 impl<'a> database::MintTransaction<'a, database::Error> for SqliteTransaction<'a> {
     async fn set_mint_info(&mut self, mint_info: MintInfo) -> Result<(), database::Error> {
-        Ok(set_to_config(&self.transaction, "mint_info", &mint_info).await?)
+        Ok(set_to_config(&self.inner, "mint_info", &mint_info).await?)
     }
 
     async fn set_quote_ttl(&mut self, quote_ttl: QuoteTTL) -> Result<(), database::Error> {
-        Ok(set_to_config(&self.transaction, "quote_ttl", &quote_ttl).await?)
+        Ok(set_to_config(&self.inner, "quote_ttl", &quote_ttl).await?)
     }
 }
 
@@ -157,11 +157,11 @@ impl MintDbWriterFinalizer for SqliteTransaction<'_> {
     type Err = database::Error;
 
     async fn commit(self: Box<Self>) -> Result<(), database::Error> {
-        Ok(self.transaction.commit().await?)
+        Ok(self.inner.commit().await?)
     }
 
     async fn rollback(self: Box<Self>) -> Result<(), database::Error> {
-        Ok(self.transaction.rollback().await?)
+        Ok(self.inner.rollback().await?)
     }
 }
 
@@ -199,7 +199,7 @@ impl<'a> MintKeyDatabaseTransaction<'a, database::Error> for SqliteTransaction<'
         .bind(":max_order", keyset.max_order)
         .bind(":input_fee_ppk", keyset.input_fee_ppk as i64)
         .bind(":derivation_path_index", keyset.derivation_path_index)
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await?;
 
         Ok(())
@@ -212,13 +212,13 @@ impl<'a> MintKeyDatabaseTransaction<'a, database::Error> for SqliteTransaction<'
     ) -> Result<(), database::Error> {
         query(r#"UPDATE keyset SET active=FALSE WHERE unit IS :unit"#)
             .bind(":unit", unit.to_string())
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
 
         query(r#"UPDATE keyset SET active=TRUE WHERE unit IS :unit AND id IS :id"#)
             .bind(":unit", unit.to_string())
             .bind(":id", id.to_string())
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
 
         Ok(())
@@ -236,7 +236,7 @@ impl MintKeysDatabase for MintSqliteDatabase {
         database::Error,
     > {
         Ok(Box::new(SqliteTransaction {
-            transaction: self.pool.begin().await?,
+            inner: self.pool.begin().await?,
         }))
     }
 
@@ -344,22 +344,17 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
         .bind(":created_time", quote.created_time as i64)
         .bind(":paid_time", quote.paid_time.map(|t| t as i64))
         .bind(":issued_time", quote.issued_time.map(|t| t as i64))
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await?;
 
         Ok(())
     }
 
     async fn remove_mint_quote(&mut self, quote_id: &Uuid) -> Result<(), Self::Err> {
-        query(
-            r#"
-            DELETE FROM mint_quote
-            WHERE id=?
-            "#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
-        .execute(&self.transaction)
-        .await?;
+        query(r#"DELETE FROM mint_quote WHERE id=:id"#)
+            .bind(":id", quote_id.as_hyphenated().to_string())
+            .execute(&self.inner)
+            .await?;
         Ok(())
     }
 
@@ -395,7 +390,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
         )
         .bind(":created_time", quote.created_time as i64)
         .bind(":paid_time", quote.paid_time.map(|t| t as i64))
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await?;
 
         Ok(())
@@ -430,7 +425,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
         )
         .bind(":id", quote_id.as_hyphenated().to_string())
         .bind(":state", state.to_string())
-        .fetch_one(&self.transaction)
+        .fetch_one(&self.inner)
         .await?
         .map(sqlite_row_to_melt_quote)
         .transpose()?
@@ -442,13 +437,13 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
                 .bind(":state", state.to_string())
                 .bind(":paid_time", current_time as i64)
                 .bind(":id", quote_id.as_hyphenated().to_string())
-                .execute(&self.transaction)
+                .execute(&self.inner)
                 .await
         } else {
             query(r#"UPDATE melt_quote SET state = :state WHERE id = :id"#)
                 .bind(":state", state.to_string())
                 .bind(":id", quote_id.as_hyphenated().to_string())
-                .execute(&self.transaction)
+                .execute(&self.inner)
                 .await
         };
 
@@ -474,7 +469,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
             "#,
         )
         .bind(":id", quote_id.as_hyphenated().to_string())
-        .execute(&self.transaction)
+        .execute(&self.inner)
         .await?;
 
         Ok(())
@@ -504,7 +499,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
             WHERE id = :id"#,
         )
         .bind(":id", quote_id.as_hyphenated().to_string())
-        .fetch_one(&self.transaction)
+        .fetch_one(&self.inner)
         .await?
         .map(sqlite_row_to_mint_quote)
         .ok_or(Error::QuoteNotFound)??;
@@ -535,7 +530,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
                 .bind(":quote_id", quote_id.as_hyphenated().to_string()),
         };
 
-        match update.execute(&self.transaction).await {
+        match update.execute(&self.inner).await {
             Ok(_) => Ok(quote.state),
             Err(err) => {
                 tracing::error!("SQLite Could not update keyset: {:?}", err);
@@ -565,7 +560,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
             WHERE id = :id"#,
         )
         .bind(":id", quote_id.as_hyphenated().to_string())
-        .fetch_one(&self.transaction)
+        .fetch_one(&self.inner)
         .await?
         .map(sqlite_row_to_mint_quote)
         .transpose()?)
@@ -597,7 +592,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
             "#,
         )
         .bind(":id", quote_id.as_hyphenated().to_string())
-        .fetch_one(&self.transaction)
+        .fetch_one(&self.inner)
         .await?
         .map(sqlite_row_to_melt_quote)
         .transpose()?)
@@ -626,7 +621,7 @@ impl<'a> MintQuotesTransaction<'a> for SqliteTransaction<'a> {
             WHERE request = :request"#,
         )
         .bind(":request", request.to_owned())
-        .fetch_one(&self.transaction)
+        .fetch_one(&self.inner)
         .await?
         .map(sqlite_row_to_mint_quote)
         .transpose()?)
@@ -857,7 +852,7 @@ impl<'a> MintProofsTransaction<'a> for SqliteTransaction<'a> {
                     .map(|y| y.y().map(|y| y.to_bytes().to_vec()))
                     .collect::<Result<_, _>>()?,
             )
-            .pluck(&self.transaction)
+            .pluck(&self.inner)
             .await?
             .map(|state| Ok::<_, Error>(column_as_string!(&state, State::from_str)))
             .transpose()?
@@ -888,7 +883,7 @@ impl<'a> MintProofsTransaction<'a> for SqliteTransaction<'a> {
             .bind(":state", "UNSPENT".to_string())
             .bind(":quote_id", quote_id.map(|q| q.hyphenated().to_string()))
             .bind(":created_time", current_time as i64)
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
         }
 
@@ -900,7 +895,7 @@ impl<'a> MintProofsTransaction<'a> for SqliteTransaction<'a> {
         ys: &[PublicKey],
         new_state: State,
     ) -> Result<Vec<Option<State>>, Self::Err> {
-        let mut current_states = get_current_states(&self.transaction, ys).await?;
+        let mut current_states = get_current_states(&self.inner, ys).await?;
 
         if current_states.len() != ys.len() {
             tracing::warn!(
@@ -918,11 +913,33 @@ impl<'a> MintProofsTransaction<'a> for SqliteTransaction<'a> {
         query(r#"UPDATE proof SET state = :new_state WHERE y IN (:ys)"#)
             .bind(":new_state", new_state.to_string())
             .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
 
         Ok(ys.iter().map(|y| current_states.remove(y)).collect())
     }
+
+    async fn remove_proofs(
+        &mut self,
+        ys: &[PublicKey],
+        _quote_id: Option<Uuid>,
+    ) -> Result<(), Self::Err> {
+        let total_deleted = query(
+            r#"
+            DELETE FROM proof WHERE y IN (:ys) AND state NOT IN (:exclude_state)
+            "#,
+        )
+        .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        .bind_vec(":exclude_state", vec![State::Spent.to_string()])
+        .execute(&self.inner)
+        .await?;
+
+        if total_deleted != ys.len() {
+            return Err(Self::Err::AttemptRemoveSpentProof);
+        }
+
+        Ok(())
+    }
 }
 
 #[async_trait]
@@ -1059,7 +1076,7 @@ impl<'a> MintSignatureTransaction<'a> for SqliteTransaction<'a> {
                 signature.dleq.as_ref().map(|dleq| dleq.s.to_secret_hex()),
             )
             .bind(":created_time", current_time as i64)
-            .execute(&self.transaction)
+            .execute(&self.inner)
             .await?;
         }
 
@@ -1077,10 +1094,10 @@ impl<'a> MintSignatureTransaction<'a> for SqliteTransaction<'a> {
                 c,
                 dleq_e,
                 dleq_s,
-                y
+                blinded_message
             FROM
                 blind_signature
-            WHERE y IN (:y)
+            WHERE blinded_message IN (:y)
             "#,
         )
         .bind_vec(
@@ -1090,7 +1107,7 @@ impl<'a> MintSignatureTransaction<'a> for SqliteTransaction<'a> {
                 .map(|y| y.to_bytes().to_vec())
                 .collect(),
         )
-        .fetch_all(&self.transaction)
+        .fetch_all(&self.inner)
         .await?
         .into_iter()
         .map(|mut row| {
@@ -1222,7 +1239,7 @@ impl MintDatabase<database::Error> for MintSqliteDatabase {
         database::Error,
     > {
         Ok(Box::new(SqliteTransaction {
-            transaction: self.pool.begin().await?,
+            inner: self.pool.begin().await?,
         }))
     }
 

+ 2 - 79
crates/cdk/src/mint/check_spendable.rs

@@ -1,41 +1,10 @@
-use std::collections::{HashMap, HashSet};
-
 use futures::future::try_join_all;
 use tracing::instrument;
 
-use super::{CheckStateRequest, CheckStateResponse, Mint, ProofState, PublicKey, State};
-use crate::{cdk_database, Error};
+use super::{CheckStateRequest, CheckStateResponse, Mint, ProofState, State};
+use crate::Error;
 
 impl Mint {
-    /// Helper function to reset proofs to their original state, skipping spent proofs
-    async fn reset_proofs_to_original_state(
-        &self,
-        ys: &[PublicKey],
-        original_states: Vec<Option<State>>,
-    ) -> Result<(), Error> {
-        let mut ys_by_state = HashMap::new();
-        let mut unknown_proofs = Vec::new();
-        for (y, state) in ys.iter().zip(original_states) {
-            if let Some(state) = state {
-                // Skip attempting to update proofs that were originally spent
-                if state != State::Spent {
-                    ys_by_state.entry(state).or_insert_with(Vec::new).push(*y);
-                }
-            } else {
-                unknown_proofs.push(*y);
-            }
-        }
-
-        let mut tx = self.localstore.begin_transaction().await?;
-        for (state, ys) in ys_by_state {
-            tx.update_proofs_states(&ys, state).await?;
-        }
-
-        tx.commit().await?;
-
-        Ok(())
-    }
-
     /// Check state
     #[instrument(skip_all)]
     pub async fn check_state(
@@ -71,50 +40,4 @@ impl Mint {
             states: proof_states,
         })
     }
-
-    /// Check Tokens are not spent or pending
-    #[instrument(skip_all)]
-    pub async fn check_ys_spendable(
-        &self,
-        tx: &mut Box<dyn cdk_database::MintTransaction<'_, cdk_database::Error> + Send + Sync + '_>,
-        ys: &[PublicKey],
-        proof_state: State,
-    ) -> Result<(), Error> {
-        let original_proofs_state = match tx.update_proofs_states(ys, proof_state).await {
-            Ok(states) => states,
-            Err(cdk_database::Error::AttemptUpdateSpentProof)
-            | Err(cdk_database::Error::AttemptRemoveSpentProof) => {
-                return Err(Error::TokenAlreadySpent)
-            }
-            Err(err) => return Err(err.into()),
-        };
-
-        assert!(ys.len() == original_proofs_state.len());
-
-        let proofs_state = original_proofs_state
-            .iter()
-            .flatten()
-            .collect::<HashSet<&State>>();
-
-        if proofs_state.contains(&State::Pending) {
-            // Reset states before returning error
-            self.reset_proofs_to_original_state(ys, original_proofs_state)
-                .await?;
-            return Err(Error::TokenPending);
-        }
-
-        if proofs_state.contains(&State::Spent) {
-            // Reset states before returning error
-            self.reset_proofs_to_original_state(ys, original_proofs_state)
-                .await?;
-            return Err(Error::TokenAlreadySpent);
-        }
-
-        for public_key in ys {
-            tracing::trace!("proof: {} set to {}", public_key.to_hex(), proof_state);
-            self.pubsub_manager.proof_state((*public_key, proof_state));
-        }
-
-        Ok(())
-    }
 }

+ 9 - 5
crates/cdk/src/mint/issue/issue_nut04.rs

@@ -133,8 +133,10 @@ impl Mint {
         // response. In practice the wallet should not be checking the state of
         // a quote while waiting for the mint response.
         if mint_quote.state == MintQuoteState::Unpaid {
-            self.check_mint_quote_paid(&mut tx, &mut mint_quote).await?;
-            tx.commit().await?;
+            self.check_mint_quote_paid(tx, &mut mint_quote)
+                .await?
+                .commit()
+                .await?;
         }
 
         Ok(MintQuoteBolt11Response {
@@ -242,9 +244,11 @@ impl Mint {
             .await?
             .ok_or(Error::UnknownQuote)?;
 
-        if mint_quote.state == MintQuoteState::Unpaid {
-            self.check_mint_quote_paid(&mut tx, &mut mint_quote).await?
-        }
+        let mut tx = if mint_quote.state == MintQuoteState::Unpaid {
+            self.check_mint_quote_paid(tx, &mut mint_quote).await?
+        } else {
+            tx
+        };
 
         match mint_quote.state {
             MintQuoteState::Unpaid => {

+ 7 - 3
crates/cdk/src/mint/ln.rs

@@ -10,9 +10,9 @@ impl Mint {
     /// Check the status of an ln payment for a quote
     pub async fn check_mint_quote_paid(
         &self,
-        tx: &mut Box<dyn MintTransaction<'_, database::Error> + Send + Sync + '_>,
+        tx: Box<dyn MintTransaction<'_, database::Error> + Send + Sync + '_>,
         quote: &mut MintQuote,
-    ) -> Result<(), Error> {
+    ) -> Result<Box<dyn MintTransaction<'_, database::Error> + Send + Sync + '_>, Error> {
         let ln = match self.ln.get(&PaymentProcessorKey::new(
             quote.unit.clone(),
             cdk_common::PaymentMethod::Bolt11,
@@ -25,10 +25,14 @@ impl Mint {
             }
         };
 
+        tx.commit().await?;
+
         let ln_status = ln
             .check_incoming_payment_status(&quote.request_lookup_id)
             .await?;
 
+        let mut tx = self.localstore.begin_transaction().await?;
+
         if ln_status != quote.state && quote.state != MintQuoteState::Issued {
             tx.update_mint_quote_state(&quote.id, ln_status).await?;
 
@@ -38,6 +42,6 @@ impl Mint {
                 .mint_quote_bolt11_status(quote.clone(), ln_status);
         }
 
-        Ok(())
+        Ok(tx)
     }
 }

+ 34 - 41
crates/cdk/src/mint/melt.rs

@@ -15,6 +15,7 @@ use super::{
 };
 use crate::amount::to_unit;
 use crate::cdk_payment::{MakePaymentResponse, MintPayment};
+use crate::mint::proof_writer::ProofWriter;
 use crate::mint::verification::Verification;
 use crate::mint::SigFlag;
 use crate::nuts::nut11::{enforce_sig_flag, EnforceSigFlag};
@@ -290,7 +291,7 @@ impl Mint {
         &self,
         tx: &mut Box<dyn MintTransaction<'_, database::Error> + Send + Sync + '_>,
         melt_request: &MeltRequest<Uuid>,
-    ) -> Result<MeltQuote, Error> {
+    ) -> Result<(ProofWriter, MeltQuote), Error> {
         let (state, quote) = tx
             .update_melt_quote_state(melt_request.quote(), MeltQuoteState::Pending)
             .await?;
@@ -312,8 +313,6 @@ impl Mint {
 
         ensure_cdk!(input_unit.is_some(), Error::UnsupportedUnit);
 
-        let input_ys = melt_request.inputs().ys()?;
-
         let fee = self.get_proofs_fee(melt_request.inputs()).await?;
 
         let required_total = quote.amount + quote.fee_reserve + fee;
@@ -334,27 +333,10 @@ impl Mint {
             ));
         }
 
-        if let Some(err) = tx
-            .add_proofs(melt_request.inputs().clone(), None)
-            .await
-            .err()
-        {
-            return match err {
-                cdk_common::database::Error::Duplicate => Err(Error::TokenPending),
-                cdk_common::database::Error::AttemptUpdateSpentProof => {
-                    Err(Error::TokenAlreadySpent)
-                }
-                err => Err(Error::Database(err)),
-            };
-        }
+        let mut proof_writer =
+            ProofWriter::new(self.localstore.clone(), self.pubsub_manager.clone());
 
-        self.check_ys_spendable(tx, &input_ys, State::Pending)
-            .await?;
-
-        for proof in melt_request.inputs() {
-            self.pubsub_manager
-                .proof_state((proof.y()?, State::Pending));
-        }
+        proof_writer.add_proofs(tx, melt_request.inputs()).await?;
 
         let EnforceSigFlag { sig_flag, .. } = enforce_sig_flag(melt_request.inputs().clone());
 
@@ -372,7 +354,7 @@ impl Mint {
         }
 
         tracing::debug!("Verified melt quote: {}", melt_request.quote());
-        Ok(quote)
+        Ok((proof_writer, quote))
     }
 
     /// Melt Bolt11
@@ -405,7 +387,7 @@ impl Mint {
 
         let mut tx = self.localstore.begin_transaction().await?;
 
-        let quote = self
+        let (proof_writer, quote) = self
             .verify_melt_request(&mut tx, melt_request)
             .await
             .map_err(|err| {
@@ -421,8 +403,9 @@ impl Mint {
                 err
             })?;
 
-        let (preimage, amount_spent_quote_unit, quote) = match settled_internally_amount {
-            Some(amount_spent) => (None, amount_spent, quote),
+        let (tx, preimage, amount_spent_quote_unit, quote) = match settled_internally_amount {
+            Some(amount_spent) => (tx, None, amount_spent, quote),
+
             None => {
                 // If the quote unit is SAT or MSAT we can check that the expected fees are
                 // provided. We also check if the quote is less then the invoice
@@ -456,6 +439,9 @@ impl Mint {
                     }
                 };
 
+                // Commit before talking to the external call
+                tx.commit().await?;
+
                 let pre = match ln
                     .make_payment(quote.clone(), partial_amount, Some(quote.fee_reserve))
                     .await
@@ -468,13 +454,13 @@ impl Mint {
                             if let Ok(ok) = check_payment_state(Arc::clone(ln), &quote).await {
                                 ok
                             } else {
-                                tx.commit().await?;
                                 return Err(Error::Internal);
                             };
 
                         if check_response.status == MeltQuoteState::Paid {
                             tracing::warn!("Pay invoice returned {} but check returned {}. Proofs stuck as pending", pay.status.to_string(), check_response.status.to_string());
-                            tx.commit().await?;
+
+                            proof_writer.commit();
 
                             return Err(Error::Internal);
                         }
@@ -496,14 +482,13 @@ impl Mint {
                             if let Ok(ok) = check_payment_state(Arc::clone(ln), &quote).await {
                                 ok
                             } else {
-                                tx.commit().await?;
+                                proof_writer.commit();
                                 return Err(Error::Internal);
                             };
                         // If there error is something else we want to check the status of the payment ensure it is not pending or has been made.
                         if check_response.status == MeltQuoteState::Paid {
                             tracing::warn!("Pay invoice returned an error but check returned {}. Proofs stuck as pending", check_response.status.to_string());
-                            tx.commit().await?;
-
+                            proof_writer.commit();
                             return Err(Error::Internal);
                         }
                         check_response
@@ -524,7 +509,7 @@ impl Mint {
                             "LN payment pending, proofs are stuck as pending for quote: {}",
                             melt_request.quote()
                         );
-                        tx.commit().await?;
+                        proof_writer.commit();
                         return Err(Error::PendingQuote);
                     }
                 }
@@ -536,6 +521,7 @@ impl Mint {
                     to_unit(pre.total_spent, &pre.unit, &quote.unit).unwrap_or_default();
 
                 let payment_lookup_id = pre.payment_lookup_id;
+                let mut tx = self.localstore.begin_transaction().await?;
 
                 if payment_lookup_id != quote.request_lookup_id {
                     tracing::info!(
@@ -550,9 +536,9 @@ impl Mint {
                     if let Err(err) = tx.add_melt_quote(melt_quote.clone()).await {
                         tracing::warn!("Could not update payment lookup id: {}", err);
                     }
-                    (pre.payment_proof, amount_spent, melt_quote)
+                    (tx, pre.payment_proof, amount_spent, melt_quote)
                 } else {
-                    (pre.payment_proof, amount_spent, quote)
+                    (tx, pre.payment_proof, amount_spent, quote)
                 }
             }
         };
@@ -560,7 +546,14 @@ impl Mint {
         // If we made it here the payment has been made.
         // We process the melt burning the inputs and returning change
         let res = self
-            .process_melt_request(tx, quote, melt_request, preimage, amount_spent_quote_unit)
+            .process_melt_request(
+                tx,
+                proof_writer,
+                quote,
+                melt_request,
+                preimage,
+                amount_spent_quote_unit,
+            )
             .await
             .map_err(|err| {
                 tracing::error!("Could not process melt request: {}", err);
@@ -576,6 +569,7 @@ impl Mint {
     pub async fn process_melt_request(
         &self,
         mut tx: Box<dyn MintTransaction<'_, database::Error> + Send + Sync + '_>,
+        mut proof_writer: ProofWriter,
         quote: MeltQuote,
         melt_request: &MeltRequest<Uuid>,
         payment_preimage: Option<String>,
@@ -585,7 +579,9 @@ impl Mint {
 
         let input_ys = melt_request.inputs().ys()?;
 
-        tx.update_proofs_states(&input_ys, State::Spent).await?;
+        proof_writer
+            .update_proofs_states(&mut tx, &input_ys, State::Spent)
+            .await?;
 
         tx.update_melt_quote_state(melt_request.quote(), MeltQuoteState::Paid)
             .await?;
@@ -597,10 +593,6 @@ impl Mint {
             MeltQuoteState::Paid,
         );
 
-        for public_key in input_ys {
-            self.pubsub_manager.proof_state((public_key, State::Spent));
-        }
-
         let mut change = None;
 
         // Check if there is change to return
@@ -666,6 +658,7 @@ impl Mint {
             }
         }
 
+        proof_writer.commit();
         tx.commit().await?;
 
         Ok(MeltQuoteBolt11Response {

+ 1 - 0
crates/cdk/src/mint/mod.rs

@@ -36,6 +36,7 @@ mod issue;
 mod keysets;
 mod ln;
 mod melt;
+mod proof_writer;
 mod start_up_check;
 pub mod subscription;
 mod swap;

+ 214 - 0
crates/cdk/src/mint/proof_writer.rs

@@ -0,0 +1,214 @@
+//! Proof writer
+use std::collections::{HashMap, HashSet};
+use std::sync::Arc;
+
+use cdk_common::database::{self, MintDatabase, MintTransaction};
+use cdk_common::{Error, Proofs, ProofsMethods, PublicKey, State};
+
+use super::subscription::PubSubManager;
+
+type Db = Arc<dyn MintDatabase<database::Error> + Send + Sync>;
+type Tx<'a, 'b> = Box<dyn MintTransaction<'a, database::Error> + Send + Sync + 'b>;
+
+/// Proof writer
+///
+/// This is a proof writer that emulates a database transaction but without holding the transaction
+/// alive while waiting for external events to be fully committed to the database; instead, it
+/// maintains a `pending` state.
+///
+/// This struct allows for premature exit on error, enabling it to remove proofs or reset their
+/// status.
+///
+/// This struct is not fully ACID. If the process exits due to a panic, and the `Drop` function
+/// cannot be run, the reset process should reset the state.
+pub struct ProofWriter {
+    db: Option<Db>,
+    pubsub_manager: Arc<PubSubManager>,
+    proof_original_states: Option<HashMap<PublicKey, Option<State>>>,
+}
+
+impl ProofWriter {
+    /// Creates a new ProofWriter on top of the database
+    pub fn new(db: Db, pubsub_manager: Arc<PubSubManager>) -> Self {
+        Self {
+            db: Some(db),
+            pubsub_manager,
+            proof_original_states: Some(Default::default()),
+        }
+    }
+
+    /// The changes are permanent, consume the struct removing the database, so the Drop does
+    /// nothing
+    pub fn commit(mut self) {
+        self.db.take();
+        self.proof_original_states.take();
+    }
+
+    /// Add proofs
+    pub async fn add_proofs(
+        &mut self,
+        tx: &mut Tx<'_, '_>,
+        proofs: &Proofs,
+    ) -> Result<Vec<PublicKey>, Error> {
+        let proof_states = if let Some(proofs) = self.proof_original_states.as_mut() {
+            proofs
+        } else {
+            return Err(Error::Internal);
+        };
+
+        if let Some(err) = tx.add_proofs(proofs.clone(), None).await.err() {
+            return match err {
+                cdk_common::database::Error::Duplicate => Err(Error::TokenPending),
+                cdk_common::database::Error::AttemptUpdateSpentProof => {
+                    Err(Error::TokenAlreadySpent)
+                }
+                err => Err(Error::Database(err)),
+            };
+        }
+
+        let ys = proofs.ys()?;
+
+        for pk in ys.iter() {
+            proof_states.insert(*pk, None);
+        }
+
+        self.update_proofs_states(tx, &ys, State::Pending).await?;
+
+        Ok(ys)
+    }
+
+    /// Update proof status
+    pub async fn update_proofs_states(
+        &mut self,
+        tx: &mut Tx<'_, '_>,
+        ys: &[PublicKey],
+        new_proof_state: State,
+    ) -> Result<(), Error> {
+        let proof_states = if let Some(proofs) = self.proof_original_states.as_mut() {
+            proofs
+        } else {
+            return Err(Error::Internal);
+        };
+
+        let original_proofs_state = match tx.update_proofs_states(ys, new_proof_state).await {
+            Ok(states) => states,
+            Err(database::Error::AttemptUpdateSpentProof)
+            | Err(database::Error::AttemptRemoveSpentProof) => {
+                return Err(Error::TokenAlreadySpent)
+            }
+            Err(err) => return Err(err.into()),
+        };
+
+        if ys.len() != original_proofs_state.len() {
+            return Err(Error::Internal);
+        }
+
+        let proofs_state = original_proofs_state
+            .iter()
+            .flatten()
+            .map(|x| x.to_owned())
+            .collect::<HashSet<State>>();
+
+        let forbidden_states = if new_proof_state == State::Pending {
+            // If the new state is `State::Pending` it cannot be pending already
+            vec![State::Pending, State::Spent]
+        } else {
+            // For other state it cannot be spent
+            vec![State::Spent]
+        };
+
+        for forbidden_state in forbidden_states.iter() {
+            if proofs_state.contains(forbidden_state) {
+                reset_proofs_to_original_state(tx, ys, original_proofs_state).await?;
+
+                return Err(if proofs_state.contains(&State::Pending) {
+                    Error::TokenPending
+                } else {
+                    Error::TokenAlreadySpent
+                });
+            }
+        }
+
+        for (idx, ys) in ys.iter().enumerate() {
+            proof_states
+                .entry(*ys)
+                .or_insert(original_proofs_state[idx]);
+        }
+
+        for pk in ys {
+            self.pubsub_manager.proof_state((*pk, new_proof_state));
+        }
+
+        Ok(())
+    }
+
+    /// Rollback all changes in this ProofWriter consuming it.
+    pub async fn rollback(mut self, tx: &mut Tx<'_, '_>) -> Result<(), Error> {
+        let (ys, original_states) = if let Some(proofs) = self.proof_original_states.take() {
+            proofs.into_iter().unzip::<_, _, Vec<_>, Vec<_>>()
+        } else {
+            return Ok(());
+        };
+        reset_proofs_to_original_state(tx, &ys, original_states).await?;
+        Ok(())
+    }
+}
+
+/// Resets proofs to their original states or removes them
+#[inline(always)]
+async fn reset_proofs_to_original_state(
+    tx: &mut Tx<'_, '_>,
+    ys: &[PublicKey],
+    original_states: Vec<Option<State>>,
+) -> Result<(), Error> {
+    let mut ys_by_state = HashMap::new();
+    let mut unknown_proofs = Vec::new();
+    for (y, state) in ys.iter().zip(original_states) {
+        if let Some(state) = state {
+            // Skip attempting to update proofs that were originally spent
+            if state != State::Spent {
+                ys_by_state.entry(state).or_insert_with(Vec::new).push(*y);
+            }
+        } else {
+            unknown_proofs.push(*y);
+        }
+    }
+
+    for (state, ys) in ys_by_state {
+        tx.update_proofs_states(&ys, state).await?;
+    }
+
+    tx.remove_proofs(&unknown_proofs, None).await?;
+
+    Ok(())
+}
+
+#[inline(always)]
+async fn rollback(
+    db: Arc<dyn MintDatabase<database::Error> + Send + Sync>,
+    ys: Vec<PublicKey>,
+    original_states: Vec<Option<State>>,
+) -> Result<(), Error> {
+    let mut tx = db.begin_transaction().await?;
+    reset_proofs_to_original_state(&mut tx, &ys, original_states).await?;
+    tx.commit().await?;
+
+    Ok(())
+}
+
+impl Drop for ProofWriter {
+    fn drop(&mut self) {
+        let db = if let Some(db) = self.db.take() {
+            db
+        } else {
+            return;
+        };
+        let (ys, states) = if let Some(proofs) = self.proof_original_states.take() {
+            proofs.into_iter().unzip()
+        } else {
+            return;
+        };
+
+        tokio::spawn(rollback(db, ys, states));
+    }
+}

+ 6 - 4
crates/cdk/src/mint/start_up_check.rs

@@ -21,14 +21,16 @@ impl Mint {
             "There are {} pending and unpaid mint quotes.",
             all_quotes.len()
         );
-        let mut tx = self.localstore.begin_transaction().await?;
         for mut quote in all_quotes.into_iter() {
             tracing::debug!("Checking status of mint quote: {}", quote.id);
-            if let Err(err) = self.check_mint_quote_paid(&mut tx, &mut quote).await {
-                tracing::error!("Could not check status of {}, {}", quote.id, err);
+            match self
+                .check_mint_quote_paid(self.localstore.begin_transaction().await?, &mut quote)
+                .await
+            {
+                Ok(tx) => tx.commit().await?,
+                Err(err) => tracing::error!("Could not check status of {}, {}", quote.id, err),
             }
         }
-        tx.commit().await?;
         Ok(())
     }
 

+ 10 - 28
crates/cdk/src/mint/swap.rs

@@ -1,9 +1,9 @@
 use tracing::instrument;
 
 use super::nut11::{enforce_sig_flag, EnforceSigFlag};
+use super::proof_writer::ProofWriter;
 use super::{Mint, PublicKey, SigFlag, State, SwapRequest, SwapResponse};
-use crate::nuts::nut00::ProofsMethods;
-use crate::{cdk_database, Error};
+use crate::Error;
 
 impl Mint {
     /// Process Swap
@@ -24,22 +24,10 @@ impl Mint {
 
         self.validate_sig_flag(&swap_request).await?;
 
-        // After swap request is fully validated, add the new proofs to DB
-        let input_ys = swap_request.inputs().ys()?;
-        if let Some(err) = tx
-            .add_proofs(swap_request.inputs().clone(), None)
-            .await
-            .err()
-        {
-            return match err {
-                cdk_common::database::Error::Duplicate => Err(Error::TokenPending),
-                cdk_common::database::Error::AttemptUpdateSpentProof => {
-                    Err(Error::TokenAlreadySpent)
-                }
-                err => Err(Error::Database(err)),
-            };
-        }
-        self.check_ys_spendable(&mut tx, &input_ys, State::Pending)
+        let mut proof_writer =
+            ProofWriter::new(self.localstore.clone(), self.pubsub_manager.clone());
+        let input_ys = proof_writer
+            .add_proofs(&mut tx, swap_request.inputs())
             .await?;
 
         let mut promises = Vec::with_capacity(swap_request.outputs().len());
@@ -49,16 +37,9 @@ impl Mint {
             promises.push(blinded_signature);
         }
 
-        tx.update_proofs_states(&input_ys, State::Spent)
-            .await
-            .map_err(|e| match e {
-                cdk_database::Error::AttemptUpdateSpentProof => Error::TokenAlreadySpent,
-                e => e.into(),
-            })?;
-
-        for pub_key in input_ys {
-            self.pubsub_manager.proof_state((pub_key, State::Spent));
-        }
+        proof_writer
+            .update_proofs_states(&mut tx, &input_ys, State::Spent)
+            .await?;
 
         tx.add_blind_signatures(
             &swap_request
@@ -71,6 +52,7 @@ impl Mint {
         )
         .await?;
 
+        proof_writer.commit();
         tx.commit().await?;
 
         Ok(SwapResponse::new(promises))