Browse Source

Added get_keyset_by_id to Tx

Cesar Rodas 1 month ago
parent
commit
e082a7b719

+ 3 - 0
crates/cdk-common/src/database/wallet.rs

@@ -48,6 +48,9 @@ pub trait DatabaseTransaction<'a, Error>: DbTransactionFinalizer<Err = Error> {
         new_mint_url: MintUrl,
         new_mint_url: MintUrl,
     ) -> Result<(), Error>;
     ) -> Result<(), Error>;
 
 
+    /// Get mint keyset by id
+    async fn get_keyset_by_id(&mut self, keyset_id: &Id) -> Result<Option<KeySetInfo>, Error>;
+
     /// Add mint keyset to storage
     /// Add mint keyset to storage
     async fn add_mint_keysets(
     async fn add_mint_keysets(
         &mut self,
         &mut self,

+ 13 - 0
crates/cdk-ffi/src/database.rs

@@ -411,6 +411,19 @@ impl<'a> cdk::cdk_database::WalletDatabaseTransaction<'a, cdk::cdk_database::Err
             .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))
             .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))
     }
     }
 
 
+    async fn get_keyset_by_id(
+        &mut self,
+        keyset_id: &cdk::nuts::Id,
+    ) -> Result<Option<cdk::nuts::KeySetInfo>, cdk::cdk_database::Error> {
+        let ffi_id = (*keyset_id).into();
+        let result = self
+            .ffi_db
+            .get_keyset_by_id(ffi_id)
+            .await
+            .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))?;
+        Ok(result.map(Into::into))
+    }
+
     async fn increment_keyset_counter(
     async fn increment_keyset_counter(
         &mut self,
         &mut self,
         keyset_id: &cdk::nuts::Id,
         keyset_id: &cdk::nuts::Id,

+ 24 - 0
crates/cdk-redb/src/wallet/mod.rs

@@ -946,6 +946,30 @@ impl<'a> cdk_common::database::WalletDatabaseTransaction<'a, database::Error>
     }
     }
 
 
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
+    async fn get_keyset_by_id(
+        &mut self,
+        keyset_id: &Id,
+    ) -> Result<Option<KeySetInfo>, database::Error> {
+        let txn = self.txn().map_err(Into::<database::Error>::into)?;
+        let table = txn.open_table(KEYSETS_TABLE).map_err(Error::from)?;
+
+        let result = match table
+            .get(keyset_id.to_bytes().as_slice())
+            .map_err(Error::from)?
+        {
+            Some(keyset) => {
+                let keyset: KeySetInfo =
+                    serde_json::from_str(keyset.value()).map_err(Error::from)?;
+
+                Ok(Some(keyset))
+            }
+            None => Ok(None),
+        };
+
+        result
+    }
+
+    #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn increment_keyset_counter(
     async fn increment_keyset_counter(
         &mut self,
         &mut self,
         keyset_id: &Id,
         keyset_id: &Id,

+ 22 - 0
crates/cdk-sql-common/src/wallet/mod.rs

@@ -691,6 +691,28 @@ ON CONFLICT(mint_url) DO UPDATE SET
     }
     }
 
 
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
+    async fn get_keyset_by_id(&mut self, keyset_id: &Id) -> Result<Option<KeySetInfo>, Error> {
+        Ok(query(
+            r#"
+               SELECT
+                   id,
+                   unit,
+                   active,
+                   input_fee_ppk,
+                   final_expiry
+               FROM
+                   keyset
+               WHERE id = :id
+               "#,
+        )?
+        .bind("id", keyset_id.to_string())
+        .fetch_one(&self.inner)
+        .await?
+        .map(sql_row_to_keyset)
+        .transpose()?)
+    }
+
+    #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn increment_keyset_counter(&mut self, keyset_id: &Id, count: u32) -> Result<u32, Error> {
     async fn increment_keyset_counter(&mut self, keyset_id: &Id, count: u32) -> Result<u32, Error> {
         // Lock the row and get current counter
         // Lock the row and get current counter
         let current_counter = query(
         let current_counter = query(

+ 1 - 2
crates/cdk/src/wallet/melt/melt_bolt11.rs

@@ -147,8 +147,7 @@ impl Wallet {
 
 
         let active_keyset_id = self.fetch_active_keyset(Some(tx)).await?.id;
         let active_keyset_id = self.fetch_active_keyset(Some(tx)).await?.id;
 
 
-        let active_keys = self
-            .localstore
+        let active_keys = tx
             .get_keys(&active_keyset_id)
             .get_keys(&active_keyset_id)
             .await?
             .await?
             .ok_or(Error::NoActiveKeyset)?;
             .ok_or(Error::NoActiveKeyset)?;

+ 26 - 11
crates/cdk/src/wallet/mod.rs

@@ -206,24 +206,35 @@ impl Wallet {
 
 
     /// Fee required for proof set
     /// Fee required for proof set
     #[instrument(skip_all)]
     #[instrument(skip_all)]
-    pub async fn get_proofs_fee(&self, proofs: &Proofs) -> Result<Amount, Error> {
+    pub async fn get_proofs_fee(
+        &self,
+        tx: Option<&mut Tx<'_, '_>>,
+        proofs: &Proofs,
+    ) -> Result<Amount, Error> {
         let proofs_per_keyset = proofs.count_by_keyset();
         let proofs_per_keyset = proofs.count_by_keyset();
-        self.get_proofs_fee_by_count(proofs_per_keyset).await
+        self.get_proofs_fee_by_count(tx, proofs_per_keyset).await
     }
     }
 
 
     /// Fee required for proof set by count
     /// Fee required for proof set by count
     pub async fn get_proofs_fee_by_count(
     pub async fn get_proofs_fee_by_count(
         &self,
         &self,
+        tx: Option<&mut Tx<'_, '_>>,
         proofs_per_keyset: HashMap<Id, u64>,
         proofs_per_keyset: HashMap<Id, u64>,
     ) -> Result<Amount, Error> {
     ) -> Result<Amount, Error> {
         let mut fee_per_keyset = HashMap::new();
         let mut fee_per_keyset = HashMap::new();
+        let mut tx = tx;
 
 
         for keyset_id in proofs_per_keyset.keys() {
         for keyset_id in proofs_per_keyset.keys() {
-            let mint_keyset_info = self
-                .localstore
-                .get_keyset_by_id(keyset_id)
-                .await?
-                .ok_or(Error::UnknownKeySet)?;
+            let mint_keyset_info = if let Some(tx) = tx.as_mut() {
+                tx.get_keyset_by_id(keyset_id)
+                    .await?
+                    .ok_or(Error::UnknownKeySet)?
+            } else {
+                self.localstore
+                    .get_keyset_by_id(keyset_id)
+                    .await?
+                    .ok_or(Error::UnknownKeySet)?
+            };
             fee_per_keyset.insert(*keyset_id, mint_keyset_info.input_fee_ppk);
             fee_per_keyset.insert(*keyset_id, mint_keyset_info.input_fee_ppk);
         }
         }
 
 
@@ -343,12 +354,15 @@ impl Wallet {
     }
     }
 
 
     /// Get amounts needed to refill proof state
     /// Get amounts needed to refill proof state
-    #[instrument(skip(self))]
+    #[instrument(skip(self, tx))]
     pub async fn amounts_needed_for_state_target(
     pub async fn amounts_needed_for_state_target(
         &self,
         &self,
+        tx: Option<&mut Tx<'_, '_>>,
         fee_and_amounts: &FeeAndAmounts,
         fee_and_amounts: &FeeAndAmounts,
     ) -> Result<Vec<Amount>, Error> {
     ) -> Result<Vec<Amount>, Error> {
-        let unspent_proofs = self.get_unspent_proofs().await?;
+        let unspent_proofs = self
+            .get_proofs_with(Some(vec![State::Unspent]), None, tx)
+            .await?;
 
 
         let amounts_count: HashMap<u64, u64> =
         let amounts_count: HashMap<u64, u64> =
             unspent_proofs
             unspent_proofs
@@ -378,14 +392,15 @@ impl Wallet {
     }
     }
 
 
     /// Determine [`SplitTarget`] for amount based on state
     /// Determine [`SplitTarget`] for amount based on state
-    #[instrument(skip(self))]
+    #[instrument(skip(self, tx))]
     async fn determine_split_target_values(
     async fn determine_split_target_values(
         &self,
         &self,
+        tx: Option<&mut Tx<'_, '_>>,
         change_amount: Amount,
         change_amount: Amount,
         fee_and_amounts: &FeeAndAmounts,
         fee_and_amounts: &FeeAndAmounts,
     ) -> Result<SplitTarget, Error> {
     ) -> Result<SplitTarget, Error> {
         let mut amounts_needed_refill = self
         let mut amounts_needed_refill = self
-            .amounts_needed_for_state_target(fee_and_amounts)
+            .amounts_needed_for_state_target(tx, fee_and_amounts)
             .await?;
             .await?;
 
 
         amounts_needed_refill.sort();
         amounts_needed_refill.sort();

+ 6 - 4
crates/cdk/src/wallet/send.rs

@@ -94,7 +94,7 @@ impl Wallet {
 
 
         // Check if selected proofs are exact
         // Check if selected proofs are exact
         let send_fee = if opts.include_fee {
         let send_fee = if opts.include_fee {
-            self.get_proofs_fee(&selected_proofs).await?
+            self.get_proofs_fee(None, &selected_proofs).await?
         } else {
         } else {
             Amount::ZERO
             Amount::ZERO
         };
         };
@@ -140,6 +140,7 @@ impl Wallet {
             let send_split = amount.split_with_fee(&fee_and_amounts)?;
             let send_split = amount.split_with_fee(&fee_and_amounts)?;
             let send_fee = self
             let send_fee = self
                 .get_proofs_fee_by_count(
                 .get_proofs_fee_by_count(
+                    None,
                     vec![(active_keyset_id, send_split.len() as u64)]
                     vec![(active_keyset_id, send_split.len() as u64)]
                         .into_iter()
                         .into_iter()
                         .collect(),
                         .collect(),
@@ -159,7 +160,6 @@ impl Wallet {
         // Reserve proofs
         // Reserve proofs
         tx.update_proofs_state(proofs.ys()?, State::Reserved)
         tx.update_proofs_state(proofs.ys()?, State::Reserved)
             .await?;
             .await?;
-        tx.commit().await?;
 
 
         // Check if proofs are exact send amount (and does not exceed max_proofs)
         // Check if proofs are exact send amount (and does not exceed max_proofs)
         let mut exact_proofs = proofs.total_amount()? == amount + send_fee;
         let mut exact_proofs = proofs.total_amount()? == amount + send_fee;
@@ -190,7 +190,9 @@ impl Wallet {
         }
         }
 
 
         // Calculate swap fee
         // Calculate swap fee
-        let swap_fee = self.get_proofs_fee(&proofs_to_swap).await?;
+        let swap_fee = self.get_proofs_fee(Some(&mut tx), &proofs_to_swap).await?;
+
+        tx.commit().await?;
 
 
         // Return prepared send
         // Return prepared send
         Ok(PreparedSend {
         Ok(PreparedSend {
@@ -318,7 +320,7 @@ impl PreparedSend {
             .get_proofs_with(
             .get_proofs_with(
                 Some(vec![State::Reserved, State::Unspent]),
                 Some(vec![State::Reserved, State::Unspent]),
                 self.options.conditions.clone().map(|c| vec![c]),
                 self.options.conditions.clone().map(|c| vec![c]),
-                None,
+                Some(&mut tx),
             )
             )
             .await?
             .await?
             .ys()?;
             .ys()?;

+ 3 - 4
crates/cdk/src/wallet/swap.rs

@@ -46,8 +46,7 @@ impl Wallet {
             .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(tx))
             .get_keyset_fees_and_amounts_by_id(active_keyset_id, Some(tx))
             .await?;
             .await?;
 
 
-        let active_keys = self
-            .localstore
+        let active_keys = tx
             .get_keys(&active_keyset_id)
             .get_keys(&active_keyset_id)
             .await?
             .await?
             .ok_or(Error::NoActiveKeyset)?;
             .ok_or(Error::NoActiveKeyset)?;
@@ -263,7 +262,7 @@ impl Wallet {
         let ys: Vec<PublicKey> = proofs.ys()?;
         let ys: Vec<PublicKey> = proofs.ys()?;
         tx.update_proofs_state(ys, State::Reserved).await?;
         tx.update_proofs_state(ys, State::Reserved).await?;
 
 
-        let fee = self.get_proofs_fee(&proofs).await?;
+        let fee = self.get_proofs_fee(Some(tx), &proofs).await?;
 
 
         let total_to_subtract = amount
         let total_to_subtract = amount
             .unwrap_or(Amount::ZERO)
             .unwrap_or(Amount::ZERO)
@@ -306,7 +305,7 @@ impl Wallet {
         // else use state refill
         // else use state refill
         let change_split_target = match amount_split_target {
         let change_split_target = match amount_split_target {
             SplitTarget::None => {
             SplitTarget::None => {
-                self.determine_split_target_values(change_amount, &fee_and_amounts)
+                self.determine_split_target_values(Some(tx), change_amount, &fee_and_amounts)
                     .await?
                     .await?
             }
             }
             s => s,
             s => s,