Преглед изворни кода

Merge pull request #944 from vnprc/keyset_counter_atomic_increment_2

fix: implement atomic keyset counter
thesimplekid пре 2 месеци
родитељ
комит
3104a999e9

+ 2 - 4
crates/cdk-common/src/database/wallet.rs

@@ -99,10 +99,8 @@ pub trait Database: Debug {
     /// Update proofs state in storage
     async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err>;
 
-    /// Increment Keyset counter
-    async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err>;
-    /// Get current Keyset counter
-    async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err>;
+    /// Atomically increment Keyset counter and return new value
+    async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err>;
 
     /// Add transaction to storage
     async fn add_transaction(&self, transaction: Transaction) -> Result<(), Self::Err>;

+ 5 - 15
crates/cdk-redb/src/wallet/mod.rs

@@ -760,10 +760,11 @@ impl WalletDatabase for WalletRedbDatabase {
     }
 
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
-    async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err> {
+    async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err> {
         let write_txn = self.db.begin_write().map_err(Error::from)?;
 
         let current_counter;
+        let new_counter;
         {
             let table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
             let counter = table
@@ -774,11 +775,12 @@ impl WalletDatabase for WalletRedbDatabase {
                 Some(c) => c.value(),
                 None => 0,
             };
+
+            new_counter = current_counter + count;
         }
 
         {
             let mut table = write_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
-            let new_counter = current_counter + count;
 
             table
                 .insert(keyset_id.to_string().as_str(), new_counter)
@@ -786,19 +788,7 @@ impl WalletDatabase for WalletRedbDatabase {
         }
         write_txn.commit().map_err(Error::from)?;
 
-        Ok(())
-    }
-
-    #[instrument(skip(self), fields(keyset_id = %keyset_id))]
-    async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err> {
-        let read_txn = self.db.begin_read().map_err(Error::from)?;
-        let table = read_txn.open_table(KEYSET_COUNTER).map_err(Error::from)?;
-
-        let counter = table
-            .get(keyset_id.to_string().as_str())
-            .map_err(Error::from)?;
-
-        Ok(counter.map_or(0, |c| c.value()))
+        Ok(new_counter)
     }
 
     #[instrument(skip(self))]

+ 26 - 24
crates/cdk-sql-common/src/wallet/mod.rs

@@ -839,42 +839,44 @@ ON CONFLICT(id) DO UPDATE SET
     }
 
     #[instrument(skip(self), fields(keyset_id = %keyset_id))]
-    async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<(), Self::Err> {
+    async fn increment_keyset_counter(&self, keyset_id: &Id, count: u32) -> Result<u32, Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        query(
+        let tx = ConnectionWithTransaction::new(conn).await?;
+
+        // Lock the row and get current counter
+        let current_counter = query(
             r#"
-            UPDATE keyset
-            SET counter=counter+:count
+            SELECT counter
+            FROM keyset
             WHERE id=:id
+            FOR UPDATE
             "#,
         )?
-        .bind("count", count)
         .bind("id", keyset_id.to_string())
-        .execute(&*conn)
-        .await?;
+        .pluck(&tx)
+        .await?
+        .map(|n| Ok::<_, Error>(column_as_number!(n)))
+        .transpose()?
+        .unwrap_or(0);
 
-        Ok(())
-    }
+        let new_counter = current_counter + count;
 
-    #[instrument(skip(self), fields(keyset_id = %keyset_id))]
-    async fn get_keyset_counter(&self, keyset_id: &Id) -> Result<u32, Self::Err> {
-        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
-        Ok(query(
+        // Update with the new counter value
+        query(
             r#"
-            SELECT
-                counter
-            FROM
-                keyset
-            WHERE
-                id=:id
+            UPDATE keyset
+            SET counter=:new_counter
+            WHERE id=:id
             "#,
         )?
+        .bind("new_counter", new_counter)
         .bind("id", keyset_id.to_string())
-        .pluck(&*conn)
-        .await?
-        .map(|n| Ok::<_, Error>(column_as_number!(n)))
-        .transpose()?
-        .unwrap_or(0))
+        .execute(&tx)
+        .await?;
+
+        tx.commit().await?;
+
+        Ok(new_counter)
     }
 
     #[instrument(skip(self))]

+ 27 - 25
crates/cdk/src/wallet/issue/issue_bolt11.rs

@@ -229,11 +229,6 @@ impl Wallet {
 
         let active_keyset_id = self.fetch_active_keyset().await?.id;
 
-        let count = self
-            .localstore
-            .get_keyset_counter(&active_keyset_id)
-            .await?;
-
         let premint_secrets = match &spending_conditions {
             Some(spending_conditions) => PreMintSecrets::with_conditions(
                 active_keyset_id,
@@ -241,13 +236,33 @@ impl Wallet {
                 &amount_split_target,
                 spending_conditions,
             )?,
-            None => PreMintSecrets::from_seed(
-                active_keyset_id,
-                count,
-                &self.seed,
-                amount_mintable,
-                &amount_split_target,
-            )?,
+            None => {
+                // Calculate how many secrets we'll need
+                let amount_split = amount_mintable.split_targeted(&amount_split_target)?;
+                let num_secrets = amount_split.len() as u32;
+
+                tracing::debug!(
+                    "Incrementing keyset {} counter by {}",
+                    active_keyset_id,
+                    num_secrets
+                );
+
+                // Atomically get the counter range we need
+                let new_counter = self
+                    .localstore
+                    .increment_keyset_counter(&active_keyset_id, num_secrets)
+                    .await?;
+
+                let count = new_counter - num_secrets;
+
+                PreMintSecrets::from_seed(
+                    active_keyset_id,
+                    count,
+                    &self.seed,
+                    amount_mintable,
+                    &amount_split_target,
+                )?
+            }
         };
 
         let mut request = MintRequest {
@@ -286,19 +301,6 @@ impl Wallet {
         // Remove filled quote from store
         self.localstore.remove_mint_quote(&quote_info.id).await?;
 
-        if spending_conditions.is_none() {
-            tracing::debug!(
-                "Incrementing keyset {} counter by {}",
-                active_keyset_id,
-                proofs.len()
-            );
-
-            // Update counter for keyset
-            self.localstore
-                .increment_keyset_counter(&active_keyset_id, proofs.len() as u32)
-                .await?;
-        }
-
         let proof_infos = proofs
             .iter()
             .map(|proof| {

+ 27 - 19
crates/cdk/src/wallet/issue/issue_bolt12.rs

@@ -107,11 +107,6 @@ impl Wallet {
 
         let active_keyset_id = self.fetch_active_keyset().await?.id;
 
-        let count = self
-            .localstore
-            .get_keyset_counter(&active_keyset_id)
-            .await?;
-
         let amount = match amount {
             Some(amount) => amount,
             None => {
@@ -135,13 +130,33 @@ impl Wallet {
                 &amount_split_target,
                 spending_conditions,
             )?,
-            None => PreMintSecrets::from_seed(
-                active_keyset_id,
-                count,
-                &self.seed,
-                amount,
-                &amount_split_target,
-            )?,
+            None => {
+                // Calculate how many secrets we'll need without generating them
+                let amount_split = amount.split_targeted(&amount_split_target)?;
+                let num_secrets = amount_split.len() as u32;
+
+                tracing::debug!(
+                    "Incrementing keyset {} counter by {}",
+                    active_keyset_id,
+                    num_secrets
+                );
+
+                // Atomically get the counter range we need
+                let new_counter = self
+                    .localstore
+                    .increment_keyset_counter(&active_keyset_id, num_secrets)
+                    .await?;
+
+                let count = new_counter - num_secrets;
+
+                PreMintSecrets::from_seed(
+                    active_keyset_id,
+                    count,
+                    &self.seed,
+                    amount,
+                    &amount_split_target,
+                )?
+            }
         };
 
         let mut request = MintRequest {
@@ -190,13 +205,6 @@ impl Wallet {
 
         self.localstore.add_mint_quote(quote_info.clone()).await?;
 
-        if spending_conditions.is_none() {
-            // Update counter for keyset
-            self.localstore
-                .increment_keyset_counter(&active_keyset_id, proofs.len() as u32)
-                .await?;
-        }
-
         let proof_infos = proofs
             .iter()
             .map(|proof| {

+ 26 - 16
crates/cdk/src/wallet/melt/melt_bolt11.rs

@@ -15,7 +15,7 @@ use crate::nuts::{
 use crate::types::{Melted, ProofInfo};
 use crate::util::unix_time;
 use crate::wallet::MeltQuote;
-use crate::{ensure_cdk, Error, Wallet};
+use crate::{ensure_cdk, Amount, Error, Wallet};
 
 impl Wallet {
     /// Melt Quote
@@ -148,17 +148,32 @@ impl Wallet {
 
         let active_keyset_id = self.fetch_active_keyset().await?.id;
 
-        let count = self
-            .localstore
-            .get_keyset_counter(&active_keyset_id)
-            .await?;
+        let change_amount = proofs_total - quote_info.amount;
 
-        let premint_secrets = PreMintSecrets::from_seed_blank(
-            active_keyset_id,
-            count,
-            &self.seed,
-            proofs_total - quote_info.amount,
-        )?;
+        let premint_secrets = if change_amount <= Amount::ZERO {
+            PreMintSecrets::new(active_keyset_id)
+        } else {
+            // TODO: consolidate this calculation with from_seed_blank into a shared function
+            // Calculate how many secrets will be needed using the same logic as from_seed_blank
+            let num_secrets =
+                ((u64::from(change_amount) as f64).log2().ceil() as u64).max(1) as u32;
+
+            tracing::debug!(
+                "Incrementing keyset {} counter by {}",
+                active_keyset_id,
+                num_secrets
+            );
+
+            // Atomically get the counter range we need
+            let new_counter = self
+                .localstore
+                .increment_keyset_counter(&active_keyset_id, num_secrets)
+                .await?;
+
+            let count = new_counter - num_secrets;
+
+            PreMintSecrets::from_seed_blank(active_keyset_id, count, &self.seed, change_amount)?
+        };
 
         let request = MeltRequest::new(
             quote_id.to_string(),
@@ -226,11 +241,6 @@ impl Wallet {
                     change_proofs.total_amount()?
                 );
 
-                // Update counter for keyset
-                self.localstore
-                    .increment_keyset_counter(&active_keyset_id, change_proofs.len() as u32)
-                    .await?;
-
                 change_proofs
                     .into_iter()
                     .map(|proof| {

+ 36 - 8
crates/cdk/src/wallet/swap.rs

@@ -52,10 +52,6 @@ impl Wallet {
             &active_keys,
         )?;
 
-        self.localstore
-            .increment_keyset_counter(&active_keyset_id, pre_swap.derived_secret_count)
-            .await?;
-
         let mut added_proofs = Vec::new();
         let change_proofs;
         let send_proofs;
@@ -248,10 +244,42 @@ impl Wallet {
 
         let derived_secret_count;
 
-        let mut count = self
-            .localstore
-            .get_keyset_counter(&active_keyset_id)
-            .await?;
+        // Calculate total secrets needed and atomically reserve counter range
+        let total_secrets_needed = match spending_conditions {
+            Some(_) => {
+                // For spending conditions, we only need to count change secrets
+                change_amount.split_targeted(&change_split_target)?.len() as u32
+            }
+            None => {
+                // For no spending conditions, count both send and change secrets
+                let send_count = send_amount
+                    .unwrap_or(Amount::ZERO)
+                    .split_targeted(&SplitTarget::default())?
+                    .len() as u32;
+                let change_count = change_amount.split_targeted(&change_split_target)?.len() as u32;
+                send_count + change_count
+            }
+        };
+
+        // Atomically get the counter range we need
+        let starting_counter = if total_secrets_needed > 0 {
+            tracing::debug!(
+                "Incrementing keyset {} counter by {}",
+                active_keyset_id,
+                total_secrets_needed
+            );
+
+            let new_counter = self
+                .localstore
+                .increment_keyset_counter(&active_keyset_id, total_secrets_needed)
+                .await?;
+
+            new_counter - total_secrets_needed
+        } else {
+            0 // No secrets needed, don't increment the counter
+        };
+
+        let mut count = starting_counter;
 
         let (mut desired_messages, change_messages) = match spending_conditions {
             Some(conditions) => {