Procházet zdrojové kódy

feat: optimize SQL balance calculation

replace proof-fetching approach with SUM aggregation

- add get_balance() method to Database trait
- implement SQL SUM aggregation in cdk-sql-common
- update total_balance() to use get_balance() instead of get_unspent_proofs()
- redb impl maintains existing behavior
vnprc před 4 měsíci
rodič
revize
909173d7fe

+ 7 - 0
crates/cdk-common/src/database/wallet.rs

@@ -96,6 +96,13 @@ pub trait Database: Debug {
         state: Option<Vec<State>>,
         spending_conditions: Option<Vec<SpendingConditions>>,
     ) -> Result<Vec<ProofInfo>, Self::Err>;
+    /// Get balance efficiently using SQL aggregation
+    async fn get_balance(
+        &self,
+        mint_url: Option<MintUrl>,
+        unit: Option<CurrencyUnit>,
+        state: Option<Vec<State>>,
+    ) -> Result<u64, Self::Err>;
     /// Update proofs state in storage
     async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err>;
 

+ 40 - 0
crates/cdk-ffi/src/database.rs

@@ -108,6 +108,14 @@ pub trait WalletDatabase: Send + Sync {
         spending_conditions: Option<Vec<SpendingConditions>>,
     ) -> Result<Vec<ProofInfo>, FfiError>;
 
+    /// Get balance efficiently using SQL aggregation
+    async fn get_balance(
+        &self,
+        mint_url: Option<MintUrl>,
+        unit: Option<CurrencyUnit>,
+        state: Option<Vec<ProofState>>,
+    ) -> Result<u64, FfiError>;
+
     /// Update proofs state in storage
     async fn update_proofs_state(
         &self,
@@ -465,6 +473,22 @@ impl CdkWalletDatabase for WalletDatabaseBridge {
         cdk_result
     }
 
+    async fn get_balance(
+        &self,
+        mint_url: Option<cdk::mint_url::MintUrl>,
+        unit: Option<cdk::nuts::CurrencyUnit>,
+        state: Option<Vec<cdk::nuts::State>>,
+    ) -> Result<u64, Self::Err> {
+        let ffi_mint_url = mint_url.map(Into::into);
+        let ffi_unit = unit.map(Into::into);
+        let ffi_state = state.map(|s| s.into_iter().map(Into::into).collect());
+
+        self.ffi_db
+            .get_balance(ffi_mint_url, ffi_unit, ffi_state)
+            .await
+            .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))
+    }
+
     async fn update_proofs_state(
         &self,
         ys: Vec<cdk::nuts::PublicKey>,
@@ -870,6 +894,22 @@ impl WalletDatabase for WalletSqliteDatabase {
         Ok(result.into_iter().map(Into::into).collect())
     }
 
+    async fn get_balance(
+        &self,
+        mint_url: Option<MintUrl>,
+        unit: Option<CurrencyUnit>,
+        state: Option<Vec<ProofState>>,
+    ) -> Result<u64, FfiError> {
+        let cdk_mint_url = mint_url.map(|u| u.try_into()).transpose()?;
+        let cdk_unit = unit.map(Into::into);
+        let cdk_state = state.map(|s| s.into_iter().map(Into::into).collect());
+
+        self.inner
+            .get_balance(cdk_mint_url, cdk_unit, cdk_state)
+            .await
+            .map_err(|e| FfiError::Database { msg: e.to_string() })
+    }
+
     async fn update_proofs_state(
         &self,
         ys: Vec<PublicKey>,

+ 12 - 0
crates/cdk-redb/src/wallet/mod.rs

@@ -721,6 +721,18 @@ impl WalletDatabase for WalletRedbDatabase {
         Ok(proofs)
     }
 
+    async fn get_balance(
+        &self,
+        mint_url: Option<MintUrl>,
+        unit: Option<CurrencyUnit>,
+        state: Option<Vec<State>>,
+    ) -> Result<u64, database::Error> {
+        // For redb, we still need to fetch all proofs and sum them
+        // since redb doesn't have SQL aggregation
+        let proofs = self.get_proofs(mint_url, unit, state, None).await?;
+        Ok(proofs.iter().map(|p| u64::from(p.proof.amount)).sum())
+    }
+
     async fn update_proofs_state(
         &self,
         ys: Vec<PublicKey>,

+ 69 - 0
crates/cdk-sql-common/src/wallet/mod.rs

@@ -836,6 +836,75 @@ ON CONFLICT(id) DO UPDATE SET
         .collect::<Vec<_>>())
     }
 
+    async fn get_balance(
+        &self,
+        mint_url: Option<MintUrl>,
+        unit: Option<CurrencyUnit>,
+        state: Option<Vec<State>>,
+    ) -> Result<u64, Self::Err> {
+        let start_time = std::time::Instant::now();
+        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
+
+        let mut query_str = "SELECT COALESCE(SUM(amount), 0) as total FROM proof".to_string();
+        let mut where_clauses = Vec::new();
+
+        if mint_url.is_some() {
+            where_clauses.push("mint_url = :mint_url");
+        }
+        if unit.is_some() {
+            where_clauses.push("unit = :unit");
+        }
+        if let Some(ref states) = state {
+            if !states.is_empty() {
+                where_clauses.push("state IN (:states)");
+            }
+        }
+
+        if !where_clauses.is_empty() {
+            query_str.push_str(" WHERE ");
+            query_str.push_str(&where_clauses.join(" AND "));
+        }
+
+        let mut q = query(&query_str)?;
+
+        if let Some(ref mint_url) = mint_url {
+            q = q.bind("mint_url", mint_url.to_string());
+        }
+        if let Some(ref unit) = unit {
+            q = q.bind("unit", unit.to_string());
+        }
+        if let Some(ref states) = state {
+            if !states.is_empty() {
+                q = q.bind_vec("states", states.iter().map(|s| s.to_string()).collect());
+            }
+        }
+
+        let balance = q
+            .pluck(&*conn)
+            .await?
+            .map(|n| {
+                // SQLite SUM returns INTEGER which we need to convert to u64
+                match n {
+                    crate::stmt::Column::Integer(i) => Ok(i as u64),
+                    crate::stmt::Column::Real(f) => Ok(f as u64),
+                    _ => Err(Error::Database(Box::new(std::io::Error::new(
+                        std::io::ErrorKind::InvalidData,
+                        "Invalid balance type",
+                    )))),
+                }
+            })
+            .transpose()?
+            .unwrap_or(0);
+
+        let elapsed = start_time.elapsed();
+        tracing::debug!(
+            "Balance query completed in {:.2}ms: {}",
+            elapsed.as_secs_f64() * 1000.0,
+            balance
+        );
+        Ok(balance)
+    }
+
     async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err> {
         let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
         query("UPDATE proof SET state = :state WHERE y IN (:ys)")?

+ 11 - 2
crates/cdk/src/wallet/balance.rs

@@ -1,13 +1,22 @@
 use tracing::instrument;
 
-use crate::nuts::nut00::ProofsMethods;
+use crate::nuts::{nut00::ProofsMethods, State};
 use crate::{Amount, Error, Wallet};
 
 impl Wallet {
     /// Total unspent balance of wallet
     #[instrument(skip(self))]
     pub async fn total_balance(&self) -> Result<Amount, Error> {
-        Ok(self.get_unspent_proofs().await?.total_amount()?)
+        // Use the efficient balance query instead of fetching all proofs
+        let balance = self
+            .localstore
+            .get_balance(
+                Some(self.mint_url.clone()),
+                Some(self.unit.clone()),
+                Some(vec![State::Unspent]),
+            )
+            .await?;
+        Ok(Amount::from(balance))
     }
 
     /// Total pending balance