fix: batch proof witness queries in check_state to prevent pool exhaustion (#1514)

* feat: add tests with more proofs

* fix: batch proof witness queries in check_state to prevent pool exhaustion

The check_state function was making N individual database queries (one per
proof) inside a try_join_all loop. With many proofs, this spawned concurrent
queries that exhausted PostgreSQL's connection pool, causing timeouts.

This worked fine with SQLite (in-process, serialized) but failed with
PostgreSQL (network connections, fixed pool size).

Changes:
- Replace N individual get_proofs_by_ys calls with a single batched query
- Use HashMap for O(1) witness lookup while preserving input order (see the sketch below)
- Add empty slice guard to prevent invalid "WHERE y IN ()" SQL
- Remove futures::try_join_all dependency
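
For context, here is a minimal, self-contained sketch of the batched, order-preserving lookup pattern this commit adopts. The `Record` type and `batched_lookup` helper are hypothetical stand-ins for illustration only, not the actual cdk types or database API; the real change is in the check_spendable.rs diff further down.

```rust
use std::collections::HashMap;

// Hypothetical record type standing in for a stored proof row.
#[derive(Clone, Debug)]
struct Record {
    y: u32,
    witness: Option<String>,
}

// Stand-in for a single batched `get_proofs_by_ys`-style query:
// one round trip returns every matching row.
fn batched_lookup(store: &[Record], ys: &[u32]) -> Vec<Record> {
    store
        .iter()
        .filter(|r| ys.contains(&r.y))
        .cloned()
        .collect()
}

fn main() {
    let store = vec![
        Record { y: 2, witness: Some("w2".into()) },
        Record { y: 1, witness: None },
    ];
    let requested = [1u32, 2, 3];

    // Guard: skip the query entirely when there is nothing to look up,
    // which avoids emitting an invalid `WHERE y IN ()` clause.
    let found = if requested.is_empty() {
        Vec::new()
    } else {
        batched_lookup(&store, &requested)
    };

    // Build an O(1) lookup map from the single query result...
    let by_y: HashMap<u32, Option<String>> =
        found.into_iter().map(|r| (r.y, r.witness)).collect();

    // ...then rebuild the answer in the caller's original order.
    let ordered: Vec<Option<String>> = requested
        .iter()
        .map(|y| by_y.get(y).cloned().flatten())
        .collect();

    assert_eq!(ordered, vec![None, Some("w2".to_string()), None]);
}
```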
tsk 2 weeks ago
parent
commit
2e00411447

+ 2 - 2
crates/cashu/src/nuts/nut13.rs

@@ -204,7 +204,7 @@ impl PreMintSecrets {
     ) -> Result<Self, Error> {
         let mut pre_mint_secrets = PreMintSecrets::new(keyset_id);
 
-        for i in start_count..=end_count {
+        for i in start_count..end_count {
             let secret = Secret::from_seed(seed, keyset_id, i)?;
             let blinding_factor = SecretKey::from_seed(seed, keyset_id, i)?;
 
@@ -528,7 +528,7 @@ mod tests {
 
         assert_eq!(
             pre_mint_secrets.secrets.len(),
-            (end_count - start_count + 1) as usize
+            (end_count - start_count) as usize
         );
 
         // Verify each secret in the batch

+ 2 - 1
crates/cdk-integration-tests/src/cli.rs

@@ -29,9 +29,10 @@ pub fn init_logging(enable_logging: bool, log_level: tracing::Level) {
         let rustls_filter = "rustls=warn";
         let reqwest_filter = "reqwest=warn";
         let tower_filter = "tower_http=warn";
+        let tokio_postgres_filter = "tokio_postgres=warn";
 
         let env_filter = EnvFilter::new(format!(
-            "{default_filter},{hyper_filter},{h2_filter},{rustls_filter},{reqwest_filter},{tower_filter}"
+            "{default_filter},{hyper_filter},{h2_filter},{rustls_filter},{reqwest_filter},{tower_filter},{tokio_postgres_filter}"
         ));
 
         // Ok if successful, Err if already initialized

+ 4 - 1
crates/cdk-integration-tests/src/init_pure_tests.rs

@@ -249,8 +249,11 @@ pub fn setup_tracing() {
 
     let h2_filter = "h2=warn";
     let hyper_filter = "hyper=warn";
+    let tokio_postgres = "tokio_postgres=warn";
 
-    let env_filter = EnvFilter::new(format!("{default_filter},{h2_filter},{hyper_filter}"));
+    let env_filter = EnvFilter::new(format!(
+        "{default_filter},{h2_filter},{hyper_filter},{tokio_postgres}"
+    ));
 
     // Ok if successful, Err if already initialized
     // Allows us to setup tracing at the start of several parallel tests

+ 106 - 0
crates/cdk-integration-tests/tests/happy_path_mint_wallet.rs

@@ -355,6 +355,112 @@ async fn test_restore() {
     }
 }
 
+/// Tests wallet restoration with a large number of proofs (3000)
+///
+/// This test verifies the restore process works correctly with many proofs,
+/// which is important for testing database performance (especially PostgreSQL)
+/// and ensuring the restore batching logic handles large proof sets:
+/// 1. Creates a wallet and mints 3000 sats as individual 1-sat proofs
+/// 2. Creates a new wallet instance with the same seed but empty storage
+/// 3. Restores the wallet state from the mint (requires ~30 restore batches)
+/// 4. Verifies all 3000 proofs are correctly restored
+/// 5. Swaps the proofs to ensure they're valid
+/// 6. Checks that the original proofs are now marked as spent
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+async fn test_restore_large_proof_count() {
+    let seed = Mnemonic::generate(12).unwrap().to_seed_normalized("");
+    let wallet = Wallet::new(
+        &get_mint_url_from_env(),
+        CurrencyUnit::Sat,
+        Arc::new(memory::empty().await.unwrap()),
+        seed,
+        None,
+    )
+    .expect("failed to create new wallet");
+
+    let mint_amount: u64 = 3000;
+    let batch_size: u64 = 999; // Keep under 1000 outputs per request
+
+    // Mint in batches to avoid exceeding the 1000 output limit per request
+    let mut total_proofs = 0usize;
+    let mut remaining = mint_amount;
+
+    while remaining > 0 {
+        let batch = remaining.min(batch_size);
+
+        let mint_quote = wallet.mint_quote(batch.into(), None).await.unwrap();
+
+        let invoice = Bolt11Invoice::from_str(&mint_quote.request).unwrap();
+        pay_if_regtest(&get_test_temp_dir(), &invoice)
+            .await
+            .unwrap();
+
+        // Mint with SplitTarget::Value(1) to create individual 1-sat proofs
+        let proofs = wallet
+            .wait_and_mint_quote(
+                mint_quote.clone(),
+                SplitTarget::Value(1.into()),
+                None,
+                tokio::time::Duration::from_secs(120),
+            )
+            .await
+            .expect("payment");
+
+        total_proofs += proofs.len();
+        remaining -= batch;
+    }
+
+    assert_eq!(total_proofs, mint_amount as usize);
+    assert_eq!(wallet.total_balance().await.unwrap(), mint_amount.into());
+
+    let wallet_2 = Wallet::new(
+        &get_mint_url_from_env(),
+        CurrencyUnit::Sat,
+        Arc::new(memory::empty().await.unwrap()),
+        seed,
+        None,
+    )
+    .expect("failed to create new wallet");
+
+    assert_eq!(wallet_2.total_balance().await.unwrap(), 0.into());
+
+    let restored = wallet_2.restore().await.unwrap();
+    let proofs = wallet_2.get_unspent_proofs().await.unwrap();
+
+    assert_eq!(proofs.len(), mint_amount as usize);
+    assert_eq!(restored, mint_amount.into());
+
+    // Swap in batches to avoid exceeding the 1000 input limit per request
+    let mut total_fee = Amount::ZERO;
+    for batch in proofs.chunks(batch_size as usize) {
+        let batch_vec = batch.to_vec();
+        let batch_fee = wallet_2.get_proofs_fee(&batch_vec).await.unwrap().total;
+        total_fee += batch_fee;
+        wallet_2
+            .swap(None, SplitTarget::default(), batch.to_vec(), None, false)
+            .await
+            .unwrap();
+    }
+
+    // Since we have to do a swap we expect to restore amount - fee
+    assert_eq!(
+        wallet_2.total_balance().await.unwrap(),
+        Amount::from(mint_amount) - total_fee
+    );
+
+    let proofs = wallet.get_unspent_proofs().await.unwrap();
+
+    // Check proofs in batches to avoid large queries
+    for batch in proofs.chunks(100) {
+        let states = wallet.check_proofs_spent(batch.to_vec()).await.unwrap();
+        for state in states {
+            if state.state != State::Spent {
+                panic!("All proofs should be spent");
+            }
+        }
+    }
+}
+
 /// Tests that wallet restore correctly handles non-sequential counter values
 ///
 /// This test verifies that after restoring a wallet where there were gaps in the

+ 2 - 1
crates/cdk-mintd/src/lib.rs

@@ -121,9 +121,10 @@ pub fn setup_tracing(
     let tower_http = "tower_http=warn";
     let rustls = "rustls=warn";
     let tungstenite = "tungstenite=warn";
+    let tokio_postgres = "tokio_postgres=warn";
 
     let env_filter = EnvFilter::new(format!(
-        "{default_filter},{hyper_filter},{h2_filter},{tower_filter},{tower_http},{rustls},{tungstenite}"
+        "{default_filter},{hyper_filter},{h2_filter},{tower_filter},{tower_http},{rustls},{tungstenite},{tokio_postgres}"
     ));
 
     use config::LoggingOutput;

+ 40 - 24
crates/cdk/src/mint/check_spendable.rs

@@ -1,4 +1,5 @@
-use futures::future::try_join_all;
+use std::collections::HashMap;
+
 use tracing::instrument;
 
 use super::{CheckStateRequest, CheckStateResponse, Mint, ProofState, State};
@@ -12,29 +13,44 @@ impl Mint {
         check_state: &CheckStateRequest,
     ) -> Result<CheckStateResponse, Error> {
         let states = self.localstore.get_proofs_states(&check_state.ys).await?;
-        assert_eq!(check_state.ys.len(), states.len());
-
-        let proof_states_futures =
-            check_state
-                .ys
-                .iter()
-                .zip(states.iter())
-                .map(|(y, state)| async move {
-                    let witness: Result<Option<cdk_common::Witness>, Error> = if state.is_some() {
-                        let proofs = self.localstore.get_proofs_by_ys(&[*y]).await?;
-                        Ok(proofs.first().cloned().flatten().and_then(|p| p.witness))
-                    } else {
-                        Ok(None)
-                    };
-
-                    witness.map(|w| ProofState {
-                        y: *y,
-                        state: state.unwrap_or(State::Unspent),
-                        witness: w,
-                    })
-                });
-
-        let proof_states = try_join_all(proof_states_futures).await?;
+
+        if check_state.ys.len() != states.len() {
+            tracing::error!("Database did not return states for all proofs");
+            return Err(Error::UnknownPaymentState);
+        }
+
+        // Collect ys that need witness fetching (where state.is_some())
+        let ys_needing_witness: Vec<_> = check_state
+            .ys
+            .iter()
+            .zip(states.iter())
+            .filter_map(|(y, state)| state.as_ref().map(|_| *y))
+            .collect();
+
+        // Build a lookup map for witnesses (only query if there are ys to fetch)
+        let witness_map: HashMap<_, _> = if ys_needing_witness.is_empty() {
+            HashMap::new()
+        } else {
+            self.localstore
+                .get_proofs_by_ys(&ys_needing_witness)
+                .await?
+                .into_iter()
+                .flatten()
+                .filter_map(|p| p.y().ok().map(|y| (y, p.witness)))
+                .collect()
+        };
+
+        // Construct response without additional queries
+        let proof_states = check_state
+            .ys
+            .iter()
+            .zip(states.iter())
+            .map(|(y, state)| ProofState {
+                y: *y,
+                state: state.unwrap_or(State::Unspent),
+                witness: witness_map.get(y).cloned().flatten(),
+            })
+            .collect();
 
         Ok(CheckStateResponse {
             states: proof_states,