Sfoglia il codice sorgente

Address race condition and add private all functions that receive a transaction

Cesar Rodas 4 settimane fa
parent
commit
40a97ba785

+ 1 - 1
crates/cdk-cli/src/sub_commands/check_pending.rs

@@ -14,7 +14,7 @@ pub async fn check_pending(multi_mint_wallet: &MultiMintWallet) -> Result<()> {
         // Get all pending proofs
         //
 
-        let pending_proofs = wallet.get_pending_proofs(Some(&mut tx)).await?;
+        let pending_proofs = wallet.get_pending_proofs_with_tx(&mut tx).await?;
         if pending_proofs.is_empty() {
             println!("No pending proofs found");
             continue;

+ 1 - 1
crates/cdk-cli/src/sub_commands/list_mint_proofs.rs

@@ -35,7 +35,7 @@ async fn list_proofs(
         }
 
         // Pending proofs
-        let pending_proofs = wallet.get_pending_proofs(None).await?;
+        let pending_proofs = wallet.get_pending_proofs().await?;
         for proof in pending_proofs {
             println!(
                 "| {:8} | {:4} | {:8} | {:64} | {}",

+ 3 - 0
crates/cdk-common/src/database/wallet.rs

@@ -22,6 +22,9 @@ use crate::wallet::{
 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
 pub trait DatabaseTransaction<'a, Error>: DbTransactionFinalizer<Err = Error> {
+    /// Get mint from storage
+    async fn get_mint(&mut self, mint_url: MintUrl) -> Result<Option<MintInfo>, Error>;
+
     /// Get [`Keys`] from storage
     async fn get_keys(&mut self, id: &Id) -> Result<Option<Keys>, Error>;
 

+ 14 - 0
crates/cdk-ffi/src/database.rs

@@ -459,6 +459,20 @@ impl<'a> cdk::cdk_database::WalletDatabaseTransaction<'a, cdk::cdk_database::Err
             .await
             .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))
     }
+
+    async fn get_mint(
+        &mut self,
+        mint_url: cdk::mint_url::MintUrl,
+    ) -> Result<Option<cdk::nuts::MintInfo>, cdk::cdk_database::Error> {
+        let ffi_mint_url = mint_url.into();
+        let result = self
+            .ffi_db
+            .get_mint(ffi_mint_url)
+            .await
+            .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))?;
+
+        Ok(result.map(Into::into))
+    }
 }
 
 #[async_trait::async_trait]

+ 6 - 6
crates/cdk-ffi/src/wallet.rs

@@ -82,7 +82,7 @@ impl Wallet {
 
     /// Get mint info
     pub async fn get_mint_info(&self) -> Result<Option<MintInfo>, FfiError> {
-        let info = self.inner.fetch_mint_info(None).await?;
+        let info = self.inner.fetch_mint_info().await?;
         Ok(info.map(Into::into))
     }
 
@@ -273,9 +273,9 @@ impl Wallet {
         for state in states {
             let proofs = match state {
                 ProofState::Unspent => self.inner.get_unspent_proofs().await?,
-                ProofState::Pending => self.inner.get_pending_proofs(None).await?,
+                ProofState::Pending => self.inner.get_pending_proofs().await?,
                 ProofState::Reserved => self.inner.get_reserved_proofs().await?,
-                ProofState::PendingSpent => self.inner.get_pending_spent_proofs(None).await?,
+                ProofState::PendingSpent => self.inner.get_pending_spent_proofs().await?,
                 ProofState::Spent => {
                     // CDK doesn't have a method to get spent proofs directly
                     // They are removed from the database when spent
@@ -297,7 +297,7 @@ impl Wallet {
             proofs.into_iter().map(|p| p.try_into()).collect();
         let cdk_proofs = cdk_proofs?;
 
-        let proof_states = self.inner.check_proofs_spent(cdk_proofs, None).await?;
+        let proof_states = self.inner.check_proofs_spent(cdk_proofs).await?;
         // Convert ProofState to bool (spent = true, unspent = false)
         let spent_bools = proof_states
             .into_iter()
@@ -353,7 +353,7 @@ impl Wallet {
 
     /// Refresh keysets from the mint
     pub async fn refresh_keysets(&self) -> Result<Vec<KeySetInfo>, FfiError> {
-        let keysets = self.inner.refresh_keysets(None).await?;
+        let keysets = self.inner.refresh_keysets().await?;
         Ok(keysets.into_iter().map(Into::into).collect())
     }
 
@@ -369,7 +369,7 @@ impl Wallet {
             .map_err(|e| FfiError::Generic { msg: e.to_string() })?;
         Ok(self
             .inner
-            .get_keyset_fees_and_amounts_by_id(id, None)
+            .get_keyset_fees_and_amounts_by_id(id)
             .await?
             .fee())
     }

+ 3 - 3
crates/cdk-integration-tests/tests/integration_tests_pure.rs

@@ -99,7 +99,7 @@ async fn test_swap_to_send() {
         HashSet::<_, RandomState>::from_iter(token_proofs.ys().expect("Failed to get ys")),
         HashSet::from_iter(
             wallet_alice
-                .get_pending_spent_proofs(None)
+                .get_pending_spent_proofs()
                 .await
                 .expect("Failed to get pending spent proofs")
                 .ys()
@@ -204,7 +204,7 @@ async fn test_mint_nut06() {
 
     let initial_mint_url = wallet_alice.mint_url.clone();
     let mint_info_before = wallet_alice
-        .fetch_mint_info(None)
+        .fetch_mint_info()
         .await
         .expect("Failed to get mint info")
         .unwrap();
@@ -772,7 +772,7 @@ async fn test_mint_change_with_fee_melt() {
         .await
         .unwrap();
 
-    let mut tx = wallet_alice
+    let tx = wallet_alice
         .localstore
         .begin_db_transaction()
         .await

+ 15 - 0
crates/cdk-redb/src/wallet/mod.rs

@@ -487,6 +487,21 @@ impl WalletDatabase for WalletRedbDatabase {
 impl<'a> cdk_common::database::WalletDatabaseTransaction<'a, database::Error>
     for RedbWalletTransaction
 {
+    #[instrument(skip(self))]
+    async fn get_mint(&mut self, mint_url: MintUrl) -> Result<Option<MintInfo>, database::Error> {
+        let txn = self.txn().map_err(Into::<database::Error>::into)?;
+        let table = txn.open_table(MINTS_TABLE).map_err(Error::from)?;
+
+        if let Some(mint_info) = table
+            .get(mint_url.to_string().as_str())
+            .map_err(Error::from)?
+        {
+            return Ok(serde_json::from_str(mint_info.value()).map_err(Error::from)?);
+        }
+
+        Ok(None)
+    }
+
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn get_keys(&mut self, keyset_id: &Id) -> Result<Option<Keys>, database::Error> {
         let txn = self.txn().map_err(Into::<database::Error>::into)?;

+ 29 - 0
crates/cdk-sql-common/src/wallet/mod.rs

@@ -152,6 +152,35 @@ impl<'a, RM> WalletDatabaseTransaction<'a, Error> for SQLWalletTransaction<RM>
 where
     RM: DatabasePool + 'static,
 {
+    #[instrument(skip(self))]
+    async fn get_mint(&mut self, mint_url: MintUrl) -> Result<Option<MintInfo>, Error> {
+        Ok(query(
+            r#"
+            SELECT
+                name,
+                pubkey,
+                version,
+                description,
+                description_long,
+                contact,
+                nuts,
+                icon_url,
+                motd,
+                urls,
+                mint_time,
+                tos_url
+            FROM
+                mint
+            WHERE mint_url = :mint_url
+            "#,
+        )?
+        .bind("mint_url", mint_url.to_string())
+        .fetch_one(&self.inner)
+        .await?
+        .map(sql_row_to_mint_info)
+        .transpose()?)
+    }
+
     /// Read the key from a transaction but it will not lock it
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn get_keys(&mut self, keyset_id: &Id) -> Result<Option<Keys>, Error> {

+ 1 - 1
crates/cdk/src/wallet/balance.rs

@@ -23,7 +23,7 @@ impl Wallet {
     /// Total pending balance
     #[instrument(skip(self))]
     pub async fn total_pending_balance(&self) -> Result<Amount, Error> {
-        Ok(self.get_pending_proofs(None).await?.total_amount()?)
+        Ok(self.get_pending_proofs().await?.total_amount()?)
     }
 
     /// Total reserved balance

+ 9 - 8
crates/cdk/src/wallet/issue/issue_bolt11.rs

@@ -54,12 +54,11 @@ impl Wallet {
 
         let mut tx = self.localstore.begin_db_transaction().await?;
 
-        self.refresh_keysets(Some(&mut tx)).await?;
+        self.refresh_keysets_with_tx(&mut tx).await?;
 
         // If we have a description, we check that the mint supports it.
         if description.is_some() {
-            let settings = self
-                .localstore
+            let settings = tx
                 .get_mint(mint_url.clone())
                 .await?
                 .ok_or(Error::IncorrectMint)?
@@ -206,7 +205,7 @@ impl Wallet {
     ) -> Result<Proofs, Error> {
         let mut tx = self.localstore.begin_db_transaction().await?;
 
-        self.refresh_keysets(Some(&mut tx)).await?;
+        self.refresh_keysets_with_tx(&mut tx).await?;
 
         let quote_info = tx
             .get_mint_quote(quote_id)
@@ -230,9 +229,9 @@ impl Wallet {
             tracing::warn!("Attempting to mint with expired quote.");
         }
 
-        let active_keyset_id = self.fetch_active_keyset(Some(&mut tx)).await?.id;
+        let active_keyset_id = self.fetch_active_keyset_with_tx(&mut tx).await?.id;
         let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(&mut tx))
+            .get_keyset_fees_and_amounts_by_id_with_tx(&mut tx, active_keyset_id)
             .await?;
 
         let premint_secrets = match &spending_conditions {
@@ -286,13 +285,15 @@ impl Wallet {
         let mint_res = self.client.post_mint(request).await?;
 
         let keys = self
-            .load_keyset_keys(active_keyset_id, Some(&mut tx))
+            .load_keyset_keys_with_tx(&mut tx, active_keyset_id)
             .await?;
 
         // Verify the signature DLEQ is valid
         {
             for (sig, premint) in mint_res.signatures.iter().zip(&premint_secrets.secrets) {
-                let keys = self.load_keyset_keys(sig.keyset_id, Some(&mut tx)).await?;
+                let keys = self
+                    .load_keyset_keys_with_tx(&mut tx, sig.keyset_id)
+                    .await?;
                 let key = keys.amount_key(sig.amount).ok_or(Error::AmountKey)?;
                 match sig.verify_dleq(key, premint.blinded_message.blinded_secret) {
                     Ok(_) | Err(nut12::Error::MissingDleqProof) => (),

+ 8 - 6
crates/cdk/src/wallet/issue/issue_bolt12.rs

@@ -31,7 +31,7 @@ impl Wallet {
 
         let mut tx = self.localstore.begin_db_transaction().await?;
 
-        self.refresh_keysets(Some(&mut tx)).await?;
+        self.refresh_keysets_with_tx(&mut tx).await?;
 
         // If we have a description, we check that the mint supports it.
         if description.is_some() {
@@ -90,7 +90,7 @@ impl Wallet {
     ) -> Result<Proofs, Error> {
         let mut tx = self.localstore.begin_db_transaction().await?;
 
-        self.refresh_keysets(Some(&mut tx)).await?;
+        self.refresh_keysets_with_tx(&mut tx).await?;
 
         let quote_info = tx.get_mint_quote(quote_id).await?;
 
@@ -104,9 +104,9 @@ impl Wallet {
             return Err(Error::UnknownQuote);
         };
 
-        let active_keyset_id = self.fetch_active_keyset(Some(&mut tx)).await?.id;
+        let active_keyset_id = self.fetch_active_keyset_with_tx(&mut tx).await?.id;
         let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(&mut tx))
+            .get_keyset_fees_and_amounts_by_id_with_tx(&mut tx, active_keyset_id)
             .await?;
 
         let amount = match amount {
@@ -180,13 +180,15 @@ impl Wallet {
         let mint_res = self.client.post_mint(request).await?;
 
         let keys = self
-            .load_keyset_keys(active_keyset_id, Some(&mut tx))
+            .load_keyset_keys_with_tx(&mut tx, active_keyset_id)
             .await?;
 
         // Verify the signature DLEQ is valid
         {
             for (sig, premint) in mint_res.signatures.iter().zip(&premint_secrets.secrets) {
-                let keys = self.load_keyset_keys(sig.keyset_id, Some(&mut tx)).await?;
+                let keys = self
+                    .load_keyset_keys_with_tx(&mut tx, sig.keyset_id)
+                    .await?;
                 let key = keys.amount_key(sig.amount).ok_or(Error::AmountKey)?;
                 match sig.verify_dleq(key, premint.blinded_message.blinded_secret) {
                     Ok(_) | Err(nut12::Error::MissingDleqProof) => (),

+ 111 - 80
crates/cdk/src/wallet/keysets.rs

@@ -13,30 +13,36 @@ impl Wallet {
     ///
     /// Returns keys from local database if they are already stored.
     /// If keys are not found locally, goes online to query the mint for the keyset and stores the [`Keys`] in local database.
+    #[instrument(skip(self))]
+    pub async fn load_keyset_keys(&self, keyset_id: Id) -> Result<Keys, Error> {
+        if let Some(keys) = self.localstore.get_keys(&keyset_id).await? {
+            Ok(keys)
+        } else {
+            tracing::debug!(
+                "Keyset {} not in db fetching from mint {}",
+                keyset_id,
+                self.mint_url
+            );
+
+            let keys = self.client.get_mint_keyset(keyset_id).await?;
+
+            keys.verify_id()?;
+            let mut tx = self.localstore.begin_db_transaction().await?;
+            tx.add_keys(keys.clone()).await?;
+            tx.commit().await?;
+            Ok(keys.keys)
+        }
+    }
+
+    /// Load keyset keys with transaction
     #[instrument(skip(self, tx))]
-    pub async fn load_keyset_keys(
+    pub async fn load_keyset_keys_with_tx(
         &self,
+        tx: &mut Tx<'_, '_>,
         keyset_id: Id,
-        tx: Option<&mut Tx<'_, '_>>,
     ) -> Result<Keys, Error> {
-        let keys = if let Some(tx) = tx {
-            if let Some(keys) = tx.get_keys(&keyset_id).await? {
-                keys
-            } else {
-                tracing::debug!(
-                    "Keyset {} not in db fetching from mint {}",
-                    keyset_id,
-                    self.mint_url
-                );
-
-                let keys = self.client.get_mint_keyset(keyset_id).await?;
-
-                keys.verify_id()?;
-                tx.add_keys(keys.clone()).await?;
-                keys.keys
-            }
-        } else if let Some(keys) = self.localstore.get_keys(&keyset_id).await? {
-            keys
+        if let Some(keys) = tx.get_keys(&keyset_id).await? {
+            Ok(keys)
         } else {
             tracing::debug!(
                 "Keyset {} not in db fetching from mint {}",
@@ -47,13 +53,9 @@ impl Wallet {
             let keys = self.client.get_mint_keyset(keyset_id).await?;
 
             keys.verify_id()?;
-            let mut tx = self.localstore.begin_db_transaction().await?;
             tx.add_keys(keys.clone()).await?;
-            tx.commit().await?;
-            keys.keys
-        };
-
-        Ok(keys)
+            Ok(keys.keys)
+        }
     }
 
     /// Get keysets from local database or go online if missing
@@ -72,7 +74,7 @@ impl Wallet {
             Some(keysets_info) => Ok(keysets_info),
             None => {
                 // If we don't have any keysets, fetch them from the mint
-                let keysets = self.refresh_keysets(None).await?;
+                let keysets = self.refresh_keysets().await?;
                 Ok(keysets)
             }
         }
@@ -97,48 +99,41 @@ impl Wallet {
 
     /// Refresh keysets by fetching the latest from mint - always goes online
     ///
+    /// Refresh keysets from mint
+    ///
     /// This method always goes online to fetch the latest keyset information from the mint.
     /// It updates the local database with the fetched keysets and ensures we have keys
     /// for all active keysets. This is used when operations need the most up-to-date
     /// keyset information and are willing to go online.
+    #[instrument(skip(self))]
+    pub async fn refresh_keysets(&self) -> Result<KeySetInfos, Error> {
+        let mut tx = self.localstore.begin_db_transaction().await?;
+        let result = self.refresh_keysets_with_tx(&mut tx).await?;
+        tx.commit().await?;
+        Ok(result)
+    }
+
+    /// Refresh keysets from mint with transaction
     #[instrument(skip(self, tx))]
-    pub async fn refresh_keysets(&self, tx: Option<&mut Tx<'_, '_>>) -> Result<KeySetInfos, Error> {
+    pub async fn refresh_keysets_with_tx(&self, tx: &mut Tx<'_, '_>) -> Result<KeySetInfos, Error> {
         tracing::debug!("Refreshing keysets and ensuring we have keys");
 
-        let mut tx = tx;
-        let _ = self
-            .fetch_mint_info(if let Some(tx) = tx.as_mut() {
-                Some(*tx)
-            } else {
-                None
-            })
-            .await?;
+        let _ = self.fetch_mint_info_with_tx(tx).await?;
 
         // Fetch all current keysets from mint
         let keysets_response = self.client.get_mint_keysets().await?;
         let all_keysets = keysets_response.keysets;
 
         // Update local storage with keyset info
-        if let Some(tx) = tx.as_mut() {
-            tx.add_mint_keysets(self.mint_url.clone(), all_keysets.clone())
-                .await?;
-        } else {
-            let mut tx = self.localstore.begin_db_transaction().await?;
-            tx.add_mint_keysets(self.mint_url.clone(), all_keysets.clone())
-                .await?;
-            tx.commit().await?;
-        }
+        tx.add_mint_keysets(self.mint_url.clone(), all_keysets.clone())
+            .await?;
 
         // Filter for active keysets matching our unit
         let keysets: KeySetInfos = all_keysets.unit(self.unit.clone()).cloned().collect();
 
         // Ensure we have keys for all active keysets
         for keyset in &keysets {
-            if let Some(tx) = tx.as_mut() {
-                self.load_keyset_keys(keyset.id, Some(*tx)).await?;
-            } else {
-                self.load_keyset_keys(keyset.id, None).await?;
-            }
+            self.load_keyset_keys_with_tx(tx, keyset.id).await?;
         }
 
         Ok(keysets)
@@ -149,12 +144,23 @@ impl Wallet {
     /// This method always goes online to refresh keysets from the mint and then returns
     /// the active keyset with the minimum input fees. Use this when you need the most
     /// up-to-date keyset information for operations.
+    #[instrument(skip(self))]
+    pub async fn fetch_active_keyset(&self) -> Result<KeySetInfo, Error> {
+        self.refresh_keysets()
+            .await?
+            .active()
+            .min_by_key(|k| k.input_fee_ppk)
+            .cloned()
+            .ok_or(Error::NoActiveKeyset)
+    }
+
+    /// Get the active keyset with the lowest fees with transaction - always goes online
     #[instrument(skip(self, tx))]
-    pub async fn fetch_active_keyset(
+    pub async fn fetch_active_keyset_with_tx(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
+        tx: &mut Tx<'_, '_>,
     ) -> Result<KeySetInfo, Error> {
-        self.refresh_keysets(tx)
+        self.refresh_keysets_with_tx(tx)
             .await?
             .active()
             .min_by_key(|k| k.input_fee_ppk)
@@ -187,21 +193,41 @@ impl Wallet {
     /// Returns a HashMap of keyset IDs to their input fee rates (per-proof-per-thousand)
     /// from cached keysets in the local database. This is an offline operation that does
     /// not contact the mint. If no keysets are found locally, returns an error.
-    pub async fn get_keyset_fees_and_amounts(
+    pub async fn get_keyset_fees_and_amounts(&self) -> Result<KeysetFeeAndAmounts, Error> {
+        let keysets = self
+            .localstore
+            .get_mint_keysets(self.mint_url.clone())
+            .await?
+            .ok_or(Error::UnknownKeySet)?;
+
+        let mut fees = HashMap::new();
+        for keyset in keysets {
+            fees.insert(
+                keyset.id,
+                (
+                    keyset.input_fee_ppk,
+                    self.load_keyset_keys(keyset.id)
+                        .await?
+                        .iter()
+                        .map(|(amount, _)| amount.to_u64())
+                        .collect::<Vec<_>>(),
+                )
+                    .into(),
+            );
+        }
+
+        Ok(fees)
+    }
+
+    /// Get keyset fees and amounts for mint with transaction
+    pub async fn get_keyset_fees_and_amounts_with_tx(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
+        tx: &mut Tx<'_, '_>,
     ) -> Result<KeysetFeeAndAmounts, Error> {
-        let mut tx = tx;
-        let keysets = if let Some(tx) = tx.as_mut() {
-            tx.get_mint_keysets(self.mint_url.clone())
-                .await?
-                .ok_or(Error::UnknownKeySet)?
-        } else {
-            self.localstore
-                .get_mint_keysets(self.mint_url.clone())
-                .await?
-                .ok_or(Error::UnknownKeySet)?
-        };
+        let keysets = tx
+            .get_mint_keysets(self.mint_url.clone())
+            .await?
+            .ok_or(Error::UnknownKeySet)?;
 
         let mut fees = HashMap::new();
         for keyset in keysets {
@@ -209,18 +235,11 @@ impl Wallet {
                 keyset.id,
                 (
                     keyset.input_fee_ppk,
-                    self.load_keyset_keys(
-                        keyset.id,
-                        if let Some(tx) = tx.as_mut() {
-                            Some(*tx)
-                        } else {
-                            None
-                        },
-                    )
-                    .await?
-                    .iter()
-                    .map(|(amount, _)| amount.to_u64())
-                    .collect::<Vec<_>>(),
+                    self.load_keyset_keys_with_tx(tx, keyset.id)
+                        .await?
+                        .iter()
+                        .map(|(amount, _)| amount.to_u64())
+                        .collect::<Vec<_>>(),
                 )
                     .into(),
             );
@@ -237,9 +256,21 @@ impl Wallet {
     pub async fn get_keyset_fees_and_amounts_by_id(
         &self,
         keyset_id: Id,
-        tx: Option<&mut Tx<'_, '_>>,
     ) -> Result<FeeAndAmounts, Error> {
-        self.get_keyset_fees_and_amounts(tx)
+        self.get_keyset_fees_and_amounts()
+            .await?
+            .get(&keyset_id)
+            .cloned()
+            .ok_or(Error::UnknownKeySet)
+    }
+
+    /// Get keyset fees and amounts for mint by keyset id with transaction
+    pub async fn get_keyset_fees_and_amounts_by_id_with_tx(
+        &self,
+        tx: &mut Tx<'_, '_>,
+        keyset_id: Id,
+    ) -> Result<FeeAndAmounts, Error> {
+        self.get_keyset_fees_and_amounts_with_tx(tx)
             .await?
             .get(&keyset_id)
             .cloned()

+ 5 - 5
crates/cdk/src/wallet/melt/melt_bolt11.rs

@@ -49,7 +49,7 @@ impl Wallet {
         request: String,
         options: Option<MeltOptions>,
     ) -> Result<MeltQuote, Error> {
-        self.refresh_keysets(None).await?;
+        self.refresh_keysets().await?;
 
         let invoice = Bolt11Invoice::from_str(&request)?;
 
@@ -163,7 +163,7 @@ impl Wallet {
             .await?
             .ok_or(Error::UnknownQuote)?;
 
-        let active_keyset_id = self.fetch_active_keyset(Some(&mut tx)).await?.id;
+        let active_keyset_id = self.fetch_active_keyset_with_tx(&mut tx).await?.id;
 
         let active_keys = tx
             .get_keys(&active_keyset_id)
@@ -400,16 +400,16 @@ impl Wallet {
         let inputs_needed_amount = quote_info.amount + quote_info.fee_reserve;
 
         let available_proofs = self
-            .get_proofs_with(Some(vec![State::Unspent]), None, Some(&mut tx))
+            .get_proofs_with_tx(&mut tx, Some(vec![State::Unspent]), None)
             .await?;
 
         let active_keyset_ids = self
-            .refresh_keysets(Some(&mut tx))
+            .refresh_keysets_with_tx(&mut tx)
             .await?
             .into_iter()
             .map(|k| k.id)
             .collect();
-        let keyset_fees = self.get_keyset_fees_and_amounts(Some(&mut tx)).await?;
+        let keyset_fees = self.get_keyset_fees_and_amounts_with_tx(&mut tx).await?;
         let (mut input_proofs, mut exchange) = Wallet::select_exact_proofs(
             inputs_needed_amount,
             available_proofs,

+ 1 - 1
crates/cdk/src/wallet/melt/mod.rs

@@ -59,7 +59,7 @@ impl Wallet {
                 response.state
             );
             if response.state == MeltQuoteState::Paid {
-                let pending_proofs = self.get_pending_proofs(Some(tx)).await?;
+                let pending_proofs = self.get_pending_proofs_with_tx(tx).await?;
                 let proofs_total = pending_proofs.total_amount().unwrap_or_default();
                 let change_total = response.change_amount().unwrap_or_default();
 

+ 96 - 35
crates/cdk/src/wallet/mod.rs

@@ -206,35 +206,57 @@ impl Wallet {
 
     /// Fee required for proof set
     #[instrument(skip_all)]
-    pub async fn get_proofs_fee(
+    pub async fn get_proofs_fee(&self, proofs: &Proofs) -> Result<Amount, Error> {
+        let proofs_per_keyset = proofs.count_by_keyset();
+        self.get_proofs_fee_by_count(proofs_per_keyset).await
+    }
+
+    /// Fee required for proof set with transaction
+    #[instrument(skip_all)]
+    pub async fn get_proofs_fee_with_tx(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
+        tx: &mut Tx<'_, '_>,
         proofs: &Proofs,
     ) -> Result<Amount, Error> {
         let proofs_per_keyset = proofs.count_by_keyset();
-        self.get_proofs_fee_by_count(tx, proofs_per_keyset).await
+        self.get_proofs_fee_by_count_with_tx(tx, proofs_per_keyset)
+            .await
     }
 
     /// Fee required for proof set by count
     pub async fn get_proofs_fee_by_count(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
         proofs_per_keyset: HashMap<Id, u64>,
     ) -> Result<Amount, Error> {
         let mut fee_per_keyset = HashMap::new();
-        let mut tx = tx;
 
         for keyset_id in proofs_per_keyset.keys() {
-            let mint_keyset_info = if let Some(tx) = tx.as_mut() {
-                tx.get_keyset_by_id(keyset_id)
-                    .await?
-                    .ok_or(Error::UnknownKeySet)?
-            } else {
-                self.localstore
-                    .get_keyset_by_id(keyset_id)
-                    .await?
-                    .ok_or(Error::UnknownKeySet)?
-            };
+            let mint_keyset_info = self
+                .localstore
+                .get_keyset_by_id(keyset_id)
+                .await?
+                .ok_or(Error::UnknownKeySet)?;
+            fee_per_keyset.insert(*keyset_id, mint_keyset_info.input_fee_ppk);
+        }
+
+        let fee = calculate_fee(&proofs_per_keyset, &fee_per_keyset)?;
+
+        Ok(fee)
+    }
+
+    /// Fee required for proof set by count with transaction
+    pub async fn get_proofs_fee_by_count_with_tx(
+        &self,
+        tx: &mut Tx<'_, '_>,
+        proofs_per_keyset: HashMap<Id, u64>,
+    ) -> Result<Amount, Error> {
+        let mut fee_per_keyset = HashMap::new();
+
+        for keyset_id in proofs_per_keyset.keys() {
+            let mint_keyset_info = tx
+                .get_keyset_by_id(keyset_id)
+                .await?
+                .ok_or(Error::UnknownKeySet)?;
             fee_per_keyset.insert(*keyset_id, mint_keyset_info.input_fee_ppk);
         }
 
@@ -275,10 +297,19 @@ impl Wallet {
     }
 
     /// Query mint for current mint information
+    #[instrument(skip(self))]
+    pub async fn fetch_mint_info(&self) -> Result<Option<MintInfo>, Error> {
+        let mut tx = self.localstore.begin_db_transaction().await?;
+        let result = self.fetch_mint_info_with_tx(&mut tx).await?;
+        tx.commit().await?;
+        Ok(result)
+    }
+
+    /// Query mint for current mint information with transaction
     #[instrument(skip(self, tx))]
-    pub async fn fetch_mint_info(
+    pub async fn fetch_mint_info_with_tx(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
+        tx: &mut Tx<'_, '_>,
     ) -> Result<Option<MintInfo>, Error> {
         match self.client.get_mint_info().await {
             Ok(mint_info) => {
@@ -332,15 +363,8 @@ impl Wallet {
                     }
                 }
 
-                if let Some(tx) = tx {
-                    tx.add_mint(self.mint_url.clone(), Some(mint_info.clone()))
-                        .await?;
-                } else {
-                    let mut tx = self.localstore.begin_db_transaction().await?;
-                    tx.add_mint(self.mint_url.clone(), Some(mint_info.clone()))
-                        .await?;
-                    tx.commit().await?;
-                };
+                tx.add_mint(self.mint_url.clone(), Some(mint_info.clone()))
+                    .await?;
 
                 tracing::trace!("Mint info updated for {}", self.mint_url);
 
@@ -354,14 +378,51 @@ impl Wallet {
     }
 
     /// Get amounts needed to refill proof state
-    #[instrument(skip(self, tx))]
+    #[instrument(skip(self))]
     pub async fn amounts_needed_for_state_target(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
         fee_and_amounts: &FeeAndAmounts,
     ) -> Result<Vec<Amount>, Error> {
         let unspent_proofs = self
-            .get_proofs_with(Some(vec![State::Unspent]), None, tx)
+            .get_proofs_with(Some(vec![State::Unspent]), None)
+            .await?;
+
+        let amounts_count: HashMap<u64, u64> =
+            unspent_proofs
+                .iter()
+                .fold(HashMap::new(), |mut acc, proof| {
+                    let amount = proof.amount;
+                    let counter = acc.entry(u64::from(amount)).or_insert(0);
+                    *counter += 1;
+                    acc
+                });
+
+        let needed_amounts =
+            fee_and_amounts
+                .amounts()
+                .iter()
+                .fold(Vec::new(), |mut acc, amount| {
+                    let count_needed = (self.target_proof_count as u64)
+                        .saturating_sub(*amounts_count.get(amount).unwrap_or(&0));
+
+                    for _i in 0..count_needed {
+                        acc.push(Amount::from(*amount));
+                    }
+
+                    acc
+                });
+        Ok(needed_amounts)
+    }
+
+    /// Get amounts needed to refill proof state with transaction
+    #[instrument(skip(self, tx))]
+    pub async fn amounts_needed_for_state_target_with_tx(
+        &self,
+        tx: &mut Tx<'_, '_>,
+        fee_and_amounts: &FeeAndAmounts,
+    ) -> Result<Vec<Amount>, Error> {
+        let unspent_proofs = self
+            .get_proofs_with_tx(tx, Some(vec![State::Unspent]), None)
             .await?;
 
         let amounts_count: HashMap<u64, u64> =
@@ -395,12 +456,12 @@ impl Wallet {
     #[instrument(skip(self, tx))]
     async fn determine_split_target_values(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
+        tx: &mut Tx<'_, '_>,
         change_amount: Amount,
         fee_and_amounts: &FeeAndAmounts,
     ) -> Result<SplitTarget, Error> {
         let mut amounts_needed_refill = self
-            .amounts_needed_for_state_target(tx, fee_and_amounts)
+            .amounts_needed_for_state_target_with_tx(tx, fee_and_amounts)
             .await?;
 
         amounts_needed_refill.sort();
@@ -427,7 +488,7 @@ impl Wallet {
             .await?
             .is_none()
         {
-            self.fetch_mint_info(None).await?;
+            self.fetch_mint_info().await?;
         }
 
         let keysets = self.load_mint_keysets().await?;
@@ -436,7 +497,7 @@ impl Wallet {
 
         for keyset in keysets {
             let mut tx = self.localstore.begin_db_transaction().await?;
-            let keys = self.load_keyset_keys(keyset.id, Some(&mut tx)).await?;
+            let keys = self.load_keyset_keys_with_tx(&mut tx, keyset.id).await?;
             let mut empty_batch = 0;
             let mut start_counter = 0;
 
@@ -491,7 +552,7 @@ impl Wallet {
                     .await?;
 
                 let states = self
-                    .check_proofs_spent(proofs.clone(), Some(&mut tx))
+                    .check_proofs_spent_with_tx(&mut tx, proofs.clone())
                     .await?;
 
                 let unspent_proofs: Vec<Proof> = proofs
@@ -686,7 +747,7 @@ impl Wallet {
             let mint_pubkey = match keys_cache.get(&proof.keyset_id) {
                 Some(keys) => keys.amount_key(proof.amount),
                 None => {
-                    let keys = self.load_keyset_keys(proof.keyset_id, None).await?;
+                    let keys = self.load_keyset_keys(proof.keyset_id).await?;
 
                     let key = keys.amount_key(proof.amount);
                     keys_cache.insert(proof.keyset_id, keys);

+ 1 - 1
crates/cdk/src/wallet/multi_mint_wallet.rs

@@ -1689,7 +1689,7 @@ impl MultiMintWallet {
             mint_url: mint_url.to_string(),
         })?;
 
-        wallet.fetch_mint_info(None).await
+        wallet.fetch_mint_info().await
     }
 }
 

+ 58 - 37
crates/cdk/src/wallet/proofs.rs

@@ -19,30 +19,41 @@ impl Wallet {
     /// Get unspent proofs for mint
     #[instrument(skip(self))]
     pub async fn get_unspent_proofs(&self) -> Result<Proofs, Error> {
-        self.get_proofs_with(Some(vec![State::Unspent]), None, None)
-            .await
+        self.get_proofs_with(Some(vec![State::Unspent]), None).await
     }
 
     /// Get pending [`Proofs`]
+    #[instrument(skip(self))]
+    pub async fn get_pending_proofs(&self) -> Result<Proofs, Error> {
+        self.get_proofs_with(Some(vec![State::Pending]), None).await
+    }
+
+    /// Get pending [`Proofs`] with transaction
     #[instrument(skip(self, tx))]
-    pub async fn get_pending_proofs(&self, tx: Option<&mut Tx<'_, '_>>) -> Result<Proofs, Error> {
-        self.get_proofs_with(Some(vec![State::Pending]), None, tx)
+    pub async fn get_pending_proofs_with_tx(&self, tx: &mut Tx<'_, '_>) -> Result<Proofs, Error> {
+        self.get_proofs_with_tx(tx, Some(vec![State::Pending]), None)
             .await
     }
 
     /// Get reserved [`Proofs`]
     #[instrument(skip(self))]
     pub async fn get_reserved_proofs(&self) -> Result<Proofs, Error> {
-        self.get_proofs_with(Some(vec![State::Reserved]), None, None)
+        self.get_proofs_with(Some(vec![State::Reserved]), None)
             .await
     }
 
     /// Get pending spent [`Proofs`]
-    pub async fn get_pending_spent_proofs(
+    pub async fn get_pending_spent_proofs(&self) -> Result<Proofs, Error> {
+        self.get_proofs_with(Some(vec![State::PendingSpent]), None)
+            .await
+    }
+
+    /// Get pending spent [`Proofs`] with transaction
+    pub async fn get_pending_spent_proofs_with_tx(
         &self,
-        tx: Option<&mut Tx<'_, '_>>,
+        tx: &mut Tx<'_, '_>,
     ) -> Result<Proofs, Error> {
-        self.get_proofs_with(Some(vec![State::PendingSpent]), None, tx)
+        self.get_proofs_with_tx(tx, Some(vec![State::PendingSpent]), None)
             .await
     }
 
@@ -51,10 +62,30 @@ impl Wallet {
         &self,
         state: Option<Vec<State>>,
         spending_conditions: Option<Vec<SpendingConditions>>,
-        tx: Option<&mut Tx<'_, '_>>,
     ) -> Result<Proofs, Error> {
-        Ok(if let Some(tx) = tx {
-            tx.get_proofs(
+        Ok(self
+            .localstore
+            .get_proofs(
+                Some(self.mint_url.clone()),
+                Some(self.unit.clone()),
+                state,
+                spending_conditions,
+            )
+            .await?
+            .into_iter()
+            .map(|p| p.proof)
+            .collect())
+    }
+
+    /// Get this wallet's [Proofs] that match the args with transaction
+    pub async fn get_proofs_with_tx(
+        &self,
+        tx: &mut Tx<'_, '_>,
+        state: Option<Vec<State>>,
+        spending_conditions: Option<Vec<SpendingConditions>>,
+    ) -> Result<Proofs, Error> {
+        Ok(tx
+            .get_proofs(
                 Some(self.mint_url.clone()),
                 Some(self.unit.clone()),
                 state,
@@ -63,20 +94,7 @@ impl Wallet {
             .await?
             .into_iter()
             .map(|p| p.proof)
-            .collect()
-        } else {
-            self.localstore
-                .get_proofs(
-                    Some(self.mint_url.clone()),
-                    Some(self.unit.clone()),
-                    state,
-                    spending_conditions,
-                )
-                .await?
-                .into_iter()
-                .map(|p| p.proof)
-                .collect()
-        })
+            .collect())
     }
 
     /// Return proofs to unspent allowing them to be selected and spent
@@ -138,11 +156,20 @@ impl Wallet {
     }
 
     /// NUT-07 Check the state of a [`Proof`] with the mint
-    #[instrument(skip(self, proofs, tx))]
-    pub async fn check_proofs_spent(
+    #[instrument(skip(self, proofs))]
+    pub async fn check_proofs_spent(&self, proofs: Proofs) -> Result<Vec<ProofState>, Error> {
+        let mut tx = self.localstore.begin_db_transaction().await?;
+        let result = self.check_proofs_spent_with_tx(&mut tx, proofs).await?;
+        tx.commit().await?;
+        Ok(result)
+    }
+
+    /// NUT-07 Check the state of a [`Proof`] with the mint with transaction
+    #[instrument(skip(self, tx, proofs))]
+    pub async fn check_proofs_spent_with_tx(
         &self,
+        tx: &mut Tx<'_, '_>,
         proofs: Proofs,
-        tx: Option<&mut Tx<'_, '_>>,
     ) -> Result<Vec<ProofState>, Error> {
         let spendable = self
             .client
@@ -158,13 +185,7 @@ impl Wallet {
             })
             .collect();
 
-        if let Some(tx) = tx {
-            tx.update_proofs(vec![], spent_ys).await?;
-        } else {
-            let mut tx = self.localstore.begin_db_transaction().await?;
-            tx.update_proofs(vec![], spent_ys).await?;
-            tx.commit().await?;
-        }
+        tx.update_proofs(vec![], spent_ys).await?;
 
         Ok(spendable.states)
     }
@@ -190,9 +211,9 @@ impl Wallet {
         }
 
         let states = self
-            .check_proofs_spent(
+            .check_proofs_spent_with_tx(
+                &mut tx,
                 proofs.clone().into_iter().map(|p| p.proof).collect(),
-                Some(&mut tx),
             )
             .await?;
 

+ 4 - 4
crates/cdk/src/wallet/receive.rs

@@ -45,11 +45,11 @@ impl Wallet {
     ) -> Result<Amount, Error> {
         let mint_url = &self.mint_url;
 
-        self.refresh_keysets(Some(tx)).await?;
+        self.refresh_keysets_with_tx(tx).await?;
 
-        let active_keyset_id = self.fetch_active_keyset(Some(tx)).await?.id;
+        let active_keyset_id = self.fetch_active_keyset_with_tx(tx).await?.id;
 
-        let keys = self.load_keyset_keys(active_keyset_id, Some(tx)).await?;
+        let keys = self.load_keyset_keys_with_tx(tx, active_keyset_id).await?;
 
         let mut proofs = proofs;
 
@@ -77,7 +77,7 @@ impl Wallet {
         for proof in &mut proofs {
             // Verify that proof DLEQ is valid
             if proof.dleq.is_some() {
-                let keys = self.load_keyset_keys(proof.keyset_id, Some(tx)).await?;
+                let keys = self.load_keyset_keys_with_tx(tx, proof.keyset_id).await?;
                 let key = keys.amount_key(proof.amount).ok_or(Error::AmountKey)?;
                 proof.verify_dleq(key)?;
             }

+ 11 - 11
crates/cdk/src/wallet/send.rs

@@ -33,20 +33,19 @@ impl Wallet {
 
         // If online send check mint for current keysets fees
         if opts.send_kind.is_online() {
-            if let Err(e) = self.refresh_keysets(None).await {
+            if let Err(e) = self.refresh_keysets().await {
                 tracing::error!("Error refreshing keysets: {:?}. Using stored keysets", e);
             }
         }
 
         // Get keyset fees from localstore
-        let keyset_fees = self.get_keyset_fees_and_amounts(None).await?;
+        let keyset_fees = self.get_keyset_fees_and_amounts().await?;
 
         // Get available proofs matching conditions
         let mut available_proofs = self
             .get_proofs_with(
                 Some(vec![State::Unspent]),
                 opts.conditions.clone().map(|c| vec![c]),
-                None,
             )
             .await?;
 
@@ -94,7 +93,7 @@ impl Wallet {
 
         // Check if selected proofs are exact
         let send_fee = if opts.include_fee {
-            self.get_proofs_fee(None, &selected_proofs).await?
+            self.get_proofs_fee(&selected_proofs).await?
         } else {
             Amount::ZERO
         };
@@ -133,14 +132,13 @@ impl Wallet {
         // Split amount with fee if necessary
         let active_keyset_id = self.get_active_keyset().await?.id;
         let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id, None)
+            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
             .await?;
         let (send_amounts, send_fee) = if opts.include_fee {
             tracing::debug!("Keyset fee per proof: {:?}", fee_and_amounts.fee());
             let send_split = amount.split_with_fee(&fee_and_amounts)?;
             let send_fee = self
                 .get_proofs_fee_by_count(
-                    None,
                     vec![(active_keyset_id, send_split.len() as u64)]
                         .into_iter()
                         .collect(),
@@ -190,7 +188,9 @@ impl Wallet {
         }
 
         // Calculate swap fee
-        let swap_fee = self.get_proofs_fee(Some(&mut tx), &proofs_to_swap).await?;
+        let swap_fee = self
+            .get_proofs_fee_with_tx(&mut tx, &proofs_to_swap)
+            .await?;
 
         tx.commit().await?;
 
@@ -271,13 +271,13 @@ impl PreparedSend {
         let mut tx = self.wallet.localstore.begin_db_transaction().await?;
 
         // Get active keyset ID
-        let active_keyset_id = self.wallet.fetch_active_keyset(Some(&mut tx)).await?.id;
+        let active_keyset_id = self.wallet.fetch_active_keyset_with_tx(&mut tx).await?.id;
         tracing::debug!("Active keyset ID: {:?}", active_keyset_id);
 
         // Get keyset fees
         let keyset_fee_ppk = self
             .wallet
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(&mut tx))
+            .get_keyset_fees_and_amounts_by_id_with_tx(&mut tx, active_keyset_id)
             .await?;
         tracing::debug!("Keyset fees: {:?}", keyset_fee_ppk);
 
@@ -317,10 +317,10 @@ impl PreparedSend {
         // Check if proofs are reserved or unspent
         let sendable_proof_ys = self
             .wallet
-            .get_proofs_with(
+            .get_proofs_with_tx(
+                &mut tx,
                 Some(vec![State::Reserved, State::Unspent]),
                 self.options.conditions.clone().map(|c| vec![c]),
-                Some(&mut tx),
             )
             .await?
             .ys()?;

+ 8 - 8
crates/cdk/src/wallet/swap.rs

@@ -21,7 +21,7 @@ impl Wallet {
         spending_conditions: Option<SpendingConditions>,
         include_fees: bool,
     ) -> Result<Option<Proofs>, Error> {
-        self.refresh_keysets(Some(tx)).await?;
+        self.refresh_keysets_with_tx(tx).await?;
 
         tracing::info!("Swapping");
         let mint_url = &self.mint_url;
@@ -42,7 +42,7 @@ impl Wallet {
 
         let active_keyset_id = pre_swap.pre_mint_secrets.keyset_id;
         let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(tx))
+            .get_keyset_fees_and_amounts_by_id_with_tx(tx, active_keyset_id)
             .await?;
 
         let active_keys = tx
@@ -217,13 +217,13 @@ impl Wallet {
         ensure_cdk!(proofs_sum >= amount, Error::InsufficientFunds);
 
         let active_keyset_ids = self
-            .refresh_keysets(Some(&mut tx))
+            .refresh_keysets_with_tx(&mut tx)
             .await?
             .active()
             .map(|k| k.id)
             .collect();
 
-        let keyset_fees = self.get_keyset_fees_and_amounts(Some(&mut tx)).await?;
+        let keyset_fees = self.get_keyset_fees_and_amounts_with_tx(&mut tx).await?;
         let proofs = Wallet::select_proofs(
             amount,
             available_proofs,
@@ -285,7 +285,7 @@ impl Wallet {
         include_fees: bool,
     ) -> Result<PreSwap, Error> {
         tracing::info!("Creating swap");
-        let active_keyset_id = self.fetch_active_keyset(Some(tx)).await?.id;
+        let active_keyset_id = self.fetch_active_keyset_with_tx(tx).await?.id;
 
         // Desired amount is either amount passed or value of all proof
         let proofs_total = proofs.total_amount()?;
@@ -293,7 +293,7 @@ impl Wallet {
         let ys: Vec<PublicKey> = proofs.ys()?;
         tx.update_proofs_state(ys, State::Reserved).await?;
 
-        let fee = self.get_proofs_fee(Some(tx), &proofs).await?;
+        let fee = self.get_proofs_fee_with_tx(tx, &proofs).await?;
 
         let total_to_subtract = amount
             .unwrap_or(Amount::ZERO)
@@ -305,7 +305,7 @@ impl Wallet {
             .ok_or(Error::InsufficientFunds)?;
 
         let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(tx))
+            .get_keyset_fees_and_amounts_by_id_with_tx(tx, active_keyset_id)
             .await?;
 
         let (send_amount, change_amount) = match include_fees {
@@ -336,7 +336,7 @@ impl Wallet {
         // else use state refill
         let change_split_target = match amount_split_target {
             SplitTarget::None => {
-                self.determine_split_target_values(Some(tx), change_amount, &fee_and_amounts)
+                self.determine_split_target_values(tx, change_amount, &fee_and_amounts)
                     .await?
             }
             s => s,

+ 1 - 1
crates/cdk/src/wallet/transactions.rs

@@ -44,7 +44,7 @@ impl Wallet {
         let mut db_tx = self.localstore.begin_db_transaction().await?;
 
         let pending_spent_proofs = self
-            .get_pending_spent_proofs(Some(&mut db_tx))
+            .get_pending_spent_proofs_with_tx(&mut db_tx)
             .await?
             .into_iter()
             .filter(|p| match p.y() {