Pārlūkot izejas kodu

Changes requested in code review

* Removed duplicate SQL statements into a single function
* Avoid external calls while having an open db transaction
Cesar Rodas 2 mēneši atpakaļ
vecāks
revīzija
21292ff9fe

+ 203 - 233
crates/cdk-sql-common/src/wallet/mod.rs

@@ -546,103 +546,22 @@ where
 
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn get_keyset_by_id(&mut self, keyset_id: &Id) -> Result<Option<KeySetInfo>, Error> {
-        Ok(query(
-            r#"
-            SELECT
-                id,
-                unit,
-                active,
-                input_fee_ppk,
-                final_expiry
-            FROM
-                keyset
-            WHERE id = :id
-            FOR UPDATE
-            "#,
-        )?
-        .bind("id", keyset_id.to_string())
-        .fetch_one(&self.inner)
-        .await?
-        .map(sql_row_to_keyset)
-        .transpose()?)
+        get_keyset_by_id_inner(&self.inner, keyset_id, true).await
     }
 
     #[instrument(skip(self), fields(keyset_id = %id))]
     async fn get_keys(&mut self, id: &Id) -> Result<Option<Keys>, Error> {
-        Ok(query(
-            r#"
-            SELECT
-                keys
-            FROM key
-            WHERE id = :id
-            "#,
-        )?
-        .bind("id", id.to_string())
-        .pluck(&self.inner)
-        .await?
-        .map(|keys| {
-            let keys = column_as_string!(keys);
-            serde_json::from_str(&keys).map_err(Error::from)
-        })
-        .transpose()?)
+        get_keys_inner(&self.inner, id).await
     }
 
     #[instrument(skip(self))]
     async fn get_mint_quote(&mut self, quote_id: &str) -> Result<Option<MintQuote>, Error> {
-        Ok(query(
-            r#"
-            SELECT
-                id,
-                mint_url,
-                amount,
-                unit,
-                request,
-                state,
-                expiry,
-                secret_key,
-                payment_method,
-                amount_issued,
-                amount_paid
-            FROM
-                mint_quote
-            WHERE
-                id = :id
-            FOR UPDATE
-            "#,
-        )?
-        .bind("id", quote_id.to_string())
-        .fetch_one(&self.inner)
-        .await?
-        .map(sql_row_to_mint_quote)
-        .transpose()?)
+        get_mint_quote_inner(&self.inner, quote_id, true).await
     }
 
     #[instrument(skip(self))]
     async fn get_melt_quote(&mut self, quote_id: &str) -> Result<Option<wallet::MeltQuote>, Error> {
-        Ok(query(
-            r#"
-               SELECT
-                   id,
-                   unit,
-                   amount,
-                   request,
-                   fee_reserve,
-                   state,
-                   expiry,
-                   payment_preimage,
-                   payment_method
-               FROM
-                   melt_quote
-               WHERE
-                   id=:id
-                FOR UPDATE
-               "#,
-        )?
-        .bind("id", quote_id.to_owned())
-        .fetch_one(&self.inner)
-        .await?
-        .map(sql_row_to_melt_quote)
-        .transpose()?)
+        get_melt_quote_inner(&self.inner, quote_id, true).await
     }
 
     #[instrument(skip(self, state, spending_conditions))]
@@ -653,40 +572,15 @@ where
         state: Option<Vec<State>>,
         spending_conditions: Option<Vec<SpendingConditions>>,
     ) -> Result<Vec<ProofInfo>, Error> {
-        Ok(query(
-            r#"
-            SELECT
-                amount,
-                unit,
-                keyset_id,
-                secret,
-                c,
-                witness,
-                dleq_e,
-                dleq_s,
-                dleq_r,
-                y,
-                mint_url,
-                state,
-                spending_condition
-            FROM proof
-            FOR UPDATE
-        "#,
-        )?
-        .fetch_all(&self.inner)
-        .await?
-        .into_iter()
-        .filter_map(|row| {
-            let row = sql_row_to_proof_info(row).ok()?;
-
-            // convert matches_conditions to SQL to lock only affected rows
-            if row.matches_conditions(&mint_url, &unit, &state, &spending_conditions) {
-                Some(row)
-            } else {
-                None
-            }
-        })
-        .collect::<Vec<_>>())
+        get_proofs_inner(
+            &self.inner,
+            mint_url,
+            unit,
+            state,
+            spending_conditions,
+            true,
+        )
+        .await
     }
 }
 
@@ -706,6 +600,191 @@ where
     }
 }
 
+// Inline helper functions that work with both connections and transactions
+#[inline]
+async fn get_keyset_by_id_inner<T>(
+    executor: &T,
+    keyset_id: &Id,
+    for_update: bool,
+) -> Result<Option<KeySetInfo>, Error>
+where
+    T: DatabaseExecutor,
+{
+    let for_update_clause = if for_update { "FOR UPDATE" } else { "" };
+    let query_str = format!(
+        r#"
+        SELECT
+            id,
+            unit,
+            active,
+            input_fee_ppk,
+            final_expiry
+        FROM
+            keyset
+        WHERE id = :id
+        {for_update_clause}
+        "#
+    );
+
+    query(&query_str)?
+        .bind("id", keyset_id.to_string())
+        .fetch_one(executor)
+        .await?
+        .map(sql_row_to_keyset)
+        .transpose()
+}
+
+#[inline]
+async fn get_keys_inner<T>(executor: &T, id: &Id) -> Result<Option<Keys>, Error>
+where
+    T: DatabaseExecutor,
+{
+    query(
+        r#"
+        SELECT
+            keys
+        FROM key
+        WHERE id = :id
+        "#,
+    )?
+    .bind("id", id.to_string())
+    .pluck(executor)
+    .await?
+    .map(|keys| {
+        let keys = column_as_string!(keys);
+        serde_json::from_str(&keys).map_err(Error::from)
+    })
+    .transpose()
+}
+
+#[inline]
+async fn get_mint_quote_inner<T>(
+    executor: &T,
+    quote_id: &str,
+    for_update: bool,
+) -> Result<Option<MintQuote>, Error>
+where
+    T: DatabaseExecutor,
+{
+    let for_update_clause = if for_update { "FOR UPDATE" } else { "" };
+    let query_str = format!(
+        r#"
+        SELECT
+            id,
+            mint_url,
+            amount,
+            unit,
+            request,
+            state,
+            expiry,
+            secret_key,
+            payment_method,
+            amount_issued,
+            amount_paid
+        FROM
+            mint_quote
+        WHERE
+            id = :id
+        {for_update_clause}
+        "#
+    );
+
+    query(&query_str)?
+        .bind("id", quote_id.to_string())
+        .fetch_one(executor)
+        .await?
+        .map(sql_row_to_mint_quote)
+        .transpose()
+}
+
+#[inline]
+async fn get_melt_quote_inner<T>(
+    executor: &T,
+    quote_id: &str,
+    for_update: bool,
+) -> Result<Option<wallet::MeltQuote>, Error>
+where
+    T: DatabaseExecutor,
+{
+    let for_update_clause = if for_update { "FOR UPDATE" } else { "" };
+    let query_str = format!(
+        r#"
+        SELECT
+            id,
+            unit,
+            amount,
+            request,
+            fee_reserve,
+            state,
+            expiry,
+            payment_preimage,
+            payment_method
+        FROM
+            melt_quote
+        WHERE
+            id=:id
+        {for_update_clause}
+        "#
+    );
+
+    query(&query_str)?
+        .bind("id", quote_id.to_owned())
+        .fetch_one(executor)
+        .await?
+        .map(sql_row_to_melt_quote)
+        .transpose()
+}
+
+#[inline]
+async fn get_proofs_inner<T>(
+    executor: &T,
+    mint_url: Option<MintUrl>,
+    unit: Option<CurrencyUnit>,
+    state: Option<Vec<State>>,
+    spending_conditions: Option<Vec<SpendingConditions>>,
+    for_update: bool,
+) -> Result<Vec<ProofInfo>, Error>
+where
+    T: DatabaseExecutor,
+{
+    let for_update_clause = if for_update { "FOR UPDATE" } else { "" };
+    let query_str = format!(
+        r#"
+        SELECT
+            amount,
+            unit,
+            keyset_id,
+            secret,
+            c,
+            witness,
+            dleq_e,
+            dleq_s,
+            dleq_r,
+            y,
+            mint_url,
+            state,
+            spending_condition
+        FROM proof
+        {for_update_clause}
+        "#
+    );
+
+    Ok(query(&query_str)?
+        .fetch_all(executor)
+        .await?
+        .into_iter()
+        .filter_map(|row| {
+            let row = sql_row_to_proof_info(row).ok()?;
+
+            if row.matches_conditions(&mint_url, &unit, &state, &spending_conditions) {
+                Some(row)
+            } else {
+                None
+            }
+        })
+        .collect::<Vec<_>>())
+}
+
 impl<RM> SQLWalletDatabase<RM>
 where
     RM: DatabasePool + 'static,
@@ -951,54 +1030,13 @@ where
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn get_keyset_by_id(&self, keyset_id: &Id) -> Result<Option<KeySetInfo>, Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
-            r#"
-            SELECT
-                id,
-                unit,
-                active,
-                input_fee_ppk,
-                final_expiry
-            FROM
-                keyset
-            WHERE id = :id
-            "#,
-        )?
-        .bind("id", keyset_id.to_string())
-        .fetch_one(&*conn)
-        .await?
-        .map(sql_row_to_keyset)
-        .transpose()?)
+        get_keyset_by_id_inner(&*conn, keyset_id, false).await
     }
 
     #[instrument(skip(self))]
     async fn get_mint_quote(&self, quote_id: &str) -> Result<Option<MintQuote>, Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
-            r#"
-            SELECT
-                id,
-                mint_url,
-                amount,
-                unit,
-                request,
-                state,
-                expiry,
-                secret_key,
-                payment_method,
-                amount_issued,
-                amount_paid
-            FROM
-                mint_quote
-            WHERE
-                id = :id
-            "#,
-        )?
-        .bind("id", quote_id.to_string())
-        .fetch_one(&*conn)
-        .await?
-        .map(sql_row_to_mint_quote)
-        .transpose()?)
+        get_mint_quote_inner(&*conn, quote_id, false).await
     }
 
     #[instrument(skip(self))]
@@ -1032,50 +1070,13 @@ where
     #[instrument(skip(self))]
     async fn get_melt_quote(&self, quote_id: &str) -> Result<Option<wallet::MeltQuote>, Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
-            r#"
-            SELECT
-                id,
-                unit,
-                amount,
-                request,
-                fee_reserve,
-                state,
-                expiry,
-                payment_preimage,
-                payment_method
-            FROM
-                melt_quote
-            WHERE
-                id=:id
-            "#,
-        )?
-        .bind("id", quote_id.to_owned())
-        .fetch_one(&*conn)
-        .await?
-        .map(sql_row_to_melt_quote)
-        .transpose()?)
+        get_melt_quote_inner(&*conn, quote_id, false).await
     }
 
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
     async fn get_keys(&self, keyset_id: &Id) -> Result<Option<Keys>, Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
-            r#"
-            SELECT
-                keys
-            FROM key
-            WHERE id = :id
-            "#,
-        )?
-        .bind("id", keyset_id.to_string())
-        .pluck(&*conn)
-        .await?
-        .map(|keys| {
-            let keys = column_as_string!(keys);
-            serde_json::from_str(&keys).map_err(Error::from)
-        })
-        .transpose()?)
+        get_keys_inner(&*conn, keyset_id).await
     }
 
     #[instrument(skip(self, state, spending_conditions))]
@@ -1087,38 +1088,7 @@ where
         spending_conditions: Option<Vec<SpendingConditions>>,
     ) -> Result<Vec<ProofInfo>, Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
-            r#"
-            SELECT
-                amount,
-                unit,
-                keyset_id,
-                secret,
-                c,
-                witness,
-                dleq_e,
-                dleq_s,
-                dleq_r,
-                y,
-                mint_url,
-                state,
-                spending_condition
-            FROM proof
-        "#,
-        )?
-        .fetch_all(&*conn)
-        .await?
-        .into_iter()
-        .filter_map(|row| {
-            let row = sql_row_to_proof_info(row).ok()?;
-
-            if row.matches_conditions(&mint_url, &unit, &state, &spending_conditions) {
-                Some(row)
-            } else {
-                None
-            }
-        })
-        .collect::<Vec<_>>())
+        get_proofs_inner(&*conn, mint_url, unit, state, spending_conditions, false).await
     }
 
     async fn get_balance(

+ 5 - 5
crates/cdk/src/wallet/issue/issue_bolt11.rs

@@ -202,6 +202,11 @@ impl Wallet {
         amount_split_target: SplitTarget,
         spending_conditions: Option<SpendingConditions>,
     ) -> Result<Proofs, Error> {
+        let active_keyset_id = self.fetch_active_keyset().await?.id;
+        let fee_and_amounts = self
+            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
+            .await?;
+
         let mut tx = self.localstore.begin_db_transaction().await?;
         let quote_info = tx
             .get_mint_quote(quote_id)
@@ -225,11 +230,6 @@ impl Wallet {
             tracing::warn!("Attempting to mint with expired quote.");
         }
 
-        let active_keyset_id = self.fetch_active_keyset().await?.id;
-        let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
-            .await?;
-
         let premint_secrets = match &spending_conditions {
             Some(spending_conditions) => PreMintSecrets::with_conditions(
                 active_keyset_id,

+ 5 - 5
crates/cdk/src/wallet/issue/issue_bolt12.rs

@@ -89,6 +89,11 @@ impl Wallet {
         amount_split_target: SplitTarget,
         spending_conditions: Option<SpendingConditions>,
     ) -> Result<Proofs, Error> {
+        let active_keyset_id = self.fetch_active_keyset().await?.id;
+        let fee_and_amounts = self
+            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
+            .await?;
+
         let mut tx = self.localstore.begin_db_transaction().await?;
         let quote_info = tx.get_mint_quote(quote_id).await?;
 
@@ -102,11 +107,6 @@ impl Wallet {
             return Err(Error::UnknownQuote);
         };
 
-        let active_keyset_id = self.fetch_active_keyset().await?.id;
-        let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
-            .await?;
-
         let (mut tx, quote_info, amount) = match amount {
             Some(amount) => (tx, quote_info, amount),
             None => {

+ 13 - 1
crates/cdk/src/wallet/receive.rs

@@ -29,6 +29,9 @@ impl Wallet {
         let mint_url = &self.mint_url;
 
         let active_keyset_id = self.fetch_active_keyset().await?.id;
+        let fee_and_amounts = self
+            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
+            .await?;
 
         let keys = self.load_keyset_keys(active_keyset_id).await?;
 
@@ -115,7 +118,16 @@ impl Wallet {
         tx.update_proofs(proofs_info.clone(), vec![]).await?;
 
         let mut pre_swap = self
-            .create_swap(tx, None, opts.amount_split_target, proofs, None, false)
+            .create_swap(
+                tx,
+                active_keyset_id,
+                &fee_and_amounts,
+                None,
+                opts.amount_split_target,
+                proofs,
+                None,
+                false,
+            )
             .await?;
 
         if sig_flag.eq(&SigFlag::SigAll) {

+ 20 - 16
crates/cdk/src/wallet/swap.rs

@@ -1,5 +1,7 @@
+use cdk_common::amount::FeeAndAmounts;
 use cdk_common::database::DynWalletDatabaseTransaction;
 use cdk_common::nut02::KeySetInfosMethods;
+use cdk_common::Id;
 use tracing::instrument;
 
 use crate::amount::SplitTarget;
@@ -25,10 +27,16 @@ impl Wallet {
         tracing::info!("Swapping");
         let mint_url = &self.mint_url;
         let unit = &self.unit;
+        let active_keyset_id = self.fetch_active_keyset().await?.id;
+        let fee_and_amounts = self
+            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
+            .await?;
 
         let pre_swap = self
             .create_swap(
                 self.localstore.begin_db_transaction().await?,
+                active_keyset_id,
+                &fee_and_amounts,
                 amount,
                 amount_split_target.clone(),
                 input_proofs.clone(),
@@ -45,9 +53,6 @@ impl Wallet {
             .await?;
 
         let active_keyset_id = pre_swap.pre_mint_secrets.keyset_id;
-        let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
-            .await?;
 
         let active_keys = self.load_keyset_keys(active_keyset_id).await?;
 
@@ -201,9 +206,12 @@ impl Wallet {
 
     /// Create Swap Payload
     #[instrument(skip(self, proofs, tx))]
+    #[allow(clippy::too_many_arguments)]
     pub async fn create_swap(
         &self,
         mut tx: DynWalletDatabaseTransaction<'_>,
+        active_keyset_id: Id,
+        fee_and_amounts: &FeeAndAmounts,
         amount: Option<Amount>,
         amount_split_target: SplitTarget,
         proofs: Proofs,
@@ -211,14 +219,10 @@ impl Wallet {
         include_fees: bool,
     ) -> Result<PreSwap, Error> {
         tracing::info!("Creating swap");
-        let active_keyset_id = self.fetch_active_keyset().await?.id;
 
         // Desired amount is either amount passed or value of all proof
         let proofs_total = proofs.total_amount()?;
         let fee = self.get_proofs_fee(&proofs).await?;
-        let fee_and_amounts = self
-            .get_keyset_fees_and_amounts_by_id(active_keyset_id)
-            .await?;
 
         let ys: Vec<PublicKey> = proofs.ys()?;
         tx.update_proofs_state(ys, State::Reserved).await?;
@@ -236,7 +240,7 @@ impl Wallet {
             true => {
                 let split_count = amount
                     .unwrap_or(Amount::ZERO)
-                    .split_targeted(&SplitTarget::default(), &fee_and_amounts)
+                    .split_targeted(&SplitTarget::default(), fee_and_amounts)
                     .unwrap()
                     .len();
 
@@ -260,7 +264,7 @@ impl Wallet {
         // else use state refill
         let change_split_target = match amount_split_target {
             SplitTarget::None => {
-                self.determine_split_target_values(&mut tx, change_amount, &fee_and_amounts)
+                self.determine_split_target_values(&mut tx, change_amount, fee_and_amounts)
                     .await?
             }
             s => s,
@@ -273,17 +277,17 @@ impl Wallet {
             Some(_) => {
                 // For spending conditions, we only need to count change secrets
                 change_amount
-                    .split_targeted(&change_split_target, &fee_and_amounts)?
+                    .split_targeted(&change_split_target, fee_and_amounts)?
                     .len() as u32
             }
             None => {
                 // For no spending conditions, count both send and change secrets
                 let send_count = send_amount
                     .unwrap_or(Amount::ZERO)
-                    .split_targeted(&SplitTarget::default(), &fee_and_amounts)?
+                    .split_targeted(&SplitTarget::default(), fee_and_amounts)?
                     .len() as u32;
                 let change_count = change_amount
-                    .split_targeted(&change_split_target, &fee_and_amounts)?
+                    .split_targeted(&change_split_target, fee_and_amounts)?
                     .len() as u32;
                 send_count + change_count
             }
@@ -316,7 +320,7 @@ impl Wallet {
                     &self.seed,
                     change_amount,
                     &change_split_target,
-                    &fee_and_amounts,
+                    fee_and_amounts,
                 )?;
 
                 derived_secret_count = change_premint_secrets.len();
@@ -327,7 +331,7 @@ impl Wallet {
                         send_amount.unwrap_or(Amount::ZERO),
                         &SplitTarget::default(),
                         &conditions,
-                        &fee_and_amounts,
+                        fee_and_amounts,
                     )?,
                     change_premint_secrets,
                 )
@@ -339,7 +343,7 @@ impl Wallet {
                     &self.seed,
                     send_amount.unwrap_or(Amount::ZERO),
                     &SplitTarget::default(),
-                    &fee_and_amounts,
+                    fee_and_amounts,
                 )?;
 
                 count += premint_secrets.len() as u32;
@@ -350,7 +354,7 @@ impl Wallet {
                     &self.seed,
                     change_amount,
                     &change_split_target,
-                    &fee_and_amounts,
+                    fee_and_amounts,
                 )?;
 
                 derived_secret_count = change_premint_secrets.len() + premint_secrets.len();