浏览代码

Simplify `process_swap_request` (#631)

* Simplify process_swap_request

* Fix occasional test_swap_to_send wallet errors
ok300 1 月之前
父节点
当前提交
5a7362c09f

+ 2 - 6
crates/cdk-integration-tests/src/init_pure_tests.rs

@@ -157,9 +157,7 @@ pub async fn create_and_start_test_mint() -> anyhow::Result<Arc<Mint>> {
 
     let mut mint_builder = MintBuilder::new();
 
-    let database = cdk_sqlite::mint::memory::empty()
-        .await
-        .expect("valid db instance");
+    let database = cdk_sqlite::mint::memory::empty().await?;
 
     let localstore = Arc::new(database);
     mint_builder = mint_builder.with_localstore(localstore.clone());
@@ -216,9 +214,7 @@ pub async fn create_test_wallet_for_mint(mint: Arc<Mint>) -> anyhow::Result<Arc<
     let seed = Mnemonic::generate(12)?.to_seed_normalized("");
     let mint_url = "http://aa".to_string();
     let unit = CurrencyUnit::Sat;
-    let localstore = cdk_sqlite::wallet::memory::empty()
-        .await
-        .expect("valid db instance");
+    let localstore = cdk_sqlite::wallet::memory::empty().await?;
     let mut wallet = Wallet::new(&mint_url, unit, Arc::new(localstore), &seed, None)?;
 
     wallet.set_client(connector);

+ 6 - 17
crates/cdk-integration-tests/tests/fake_wallet.rs

@@ -996,17 +996,10 @@ async fn test_fake_mint_swap_spend_after_fail() -> Result<()> {
 
     match response {
         Err(err) => match err {
-            cdk::Error::TokenAlreadySpent => (),
-            err => {
-                bail!(
-                    "Wrong mint error returned expected already spent: {}",
-                    err.to_string()
-                );
-            }
+            cdk::Error::TransactionUnbalanced(_, _, _) => (),
+            err => bail!("Wrong mint error returned expected TransactionUnbalanced, got: {err}"),
         },
-        Ok(_) => {
-            bail!("Should not have allowed swap with unbalanced");
-        }
+        Ok(_) => bail!("Should not have allowed swap with unbalanced"),
     }
 
     let pre_mint = PreMintSecrets::random(active_keyset_id, 100.into(), &SplitTarget::None)?;
@@ -1076,14 +1069,10 @@ async fn test_fake_mint_melt_spend_after_fail() -> Result<()> {
 
     match response {
         Err(err) => match err {
-            cdk::Error::TokenAlreadySpent => (),
-            err => {
-                bail!("Wrong mint error returned: {}", err.to_string());
-            }
+            cdk::Error::TransactionUnbalanced(_, _, _) => (),
+            err => bail!("Wrong mint error returned expected TransactionUnbalanced, got: {err}"),
         },
-        Ok(_) => {
-            bail!("Should not have allowed to mint with multiple units");
-        }
+        Ok(_) => bail!("Should not have allowed swap with unbalanced"),
     }
 
     let input_amount: u64 = proofs.total_amount()?.into();

+ 9 - 25
crates/cdk-integration-tests/tests/mint.rs

@@ -2,7 +2,6 @@
 
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
-use std::time::Duration;
 
 use anyhow::{bail, Result};
 use bip39::Mnemonic;
@@ -21,7 +20,6 @@ use cdk::util::unix_time;
 use cdk::Mint;
 use cdk_fake_wallet::FakeWallet;
 use cdk_sqlite::mint::memory;
-use tokio::time::sleep;
 
 pub const MINT_URL: &str = "http://127.0.0.1:8088";
 
@@ -215,20 +213,12 @@ pub async fn test_p2pk_swap() -> Result<()> {
 
     let swap_request = SwapRequest::new(proofs.clone(), pre_swap.blinded_messages());
 
+    // Listen for status updates on all input proof pks
     let public_keys_to_listen: Vec<_> = swap_request
         .inputs
-        .ys()
-        .expect("key")
-        .into_iter()
-        .enumerate()
-        .filter_map(|(key, pk)| {
-            if key % 2 == 0 {
-                // Only expect messages from every other key
-                Some(pk.to_string())
-            } else {
-                None
-            }
-        })
+        .ys()?
+        .iter()
+        .map(|pk| pk.to_string())
         .collect();
 
     let mut listener = mint
@@ -265,21 +255,14 @@ pub async fn test_p2pk_swap() -> Result<()> {
 
     assert!(attempt_swap.is_ok());
 
-    sleep(Duration::from_millis(10)).await;
-
     let mut msgs = HashMap::new();
     while let Ok((sub_id, msg)) = listener.try_recv() {
         assert_eq!(sub_id, "test".into());
         match msg {
             NotificationPayload::ProofState(ProofState { y, state, .. }) => {
-                let pk = y.to_string();
-                msgs.get_mut(&pk)
-                    .map(|x: &mut Vec<State>| {
-                        x.push(state);
-                    })
-                    .unwrap_or_else(|| {
-                        msgs.insert(pk, vec![state]);
-                    });
+                msgs.entry(y.to_string())
+                    .or_insert_with(Vec::new)
+                    .push(state);
             }
             _ => bail!("Wrong message received"),
         }
@@ -287,7 +270,8 @@ pub async fn test_p2pk_swap() -> Result<()> {
 
     for keys in public_keys_to_listen {
         let statuses = msgs.remove(&keys).expect("some events");
-        assert_eq!(statuses, vec![State::Pending, State::Pending, State::Spent]);
+        // Every input pk receives two state updates, as there are only two state transitions
+        assert_eq!(statuses, vec![State::Pending, State::Spent]);
     }
 
     assert!(listener.try_recv().is_err(), "no other event is happening");

+ 5 - 1
crates/cdk-sqlite/src/wallet/memory.rs

@@ -6,7 +6,11 @@ use super::WalletSqliteDatabase;
 
 /// Creates a new in-memory [`WalletSqliteDatabase`] instance
 pub async fn empty() -> Result<WalletSqliteDatabase, Error> {
-    let db = WalletSqliteDatabase::new(":memory:").await?;
+    let db = WalletSqliteDatabase {
+        pool: sqlx::sqlite::SqlitePool::connect(":memory:")
+            .await
+            .map_err(|e| Error::Database(Box::new(e)))?,
+    };
     db.migrate().await;
     Ok(db)
 }

+ 28 - 24
crates/cdk/src/mint/swap.rs

@@ -12,38 +12,22 @@ impl Mint {
         &self,
         swap_request: SwapRequest,
     ) -> Result<SwapResponse, Error> {
-        let input_ys = swap_request.inputs.ys()?;
-
-        self.localstore
-            .add_proofs(swap_request.inputs.clone(), None)
-            .await?;
-        self.check_ys_spendable(&input_ys, State::Pending).await?;
-
         if let Err(err) = self
             .verify_transaction_balanced(&swap_request.inputs, &swap_request.outputs)
             .await
         {
-            tracing::debug!("Attempt to swap unbalanced transaction: {}", err);
-            self.localstore.remove_proofs(&input_ys, None).await?;
+            tracing::debug!("Attempt to swap unbalanced transaction, aborting: {err}");
             return Err(err);
         };
 
-        let EnforceSigFlag {
-            sig_flag,
-            pubkeys,
-            sigs_required,
-        } = enforce_sig_flag(swap_request.inputs.clone());
+        self.validate_sig_flag(&swap_request).await?;
 
-        if sig_flag.eq(&SigFlag::SigAll) {
-            let pubkeys = pubkeys.into_iter().collect();
-            for blinded_message in &swap_request.outputs {
-                if let Err(err) = blinded_message.verify_p2pk(&pubkeys, sigs_required) {
-                    tracing::info!("Could not verify p2pk in swap request");
-                    self.localstore.remove_proofs(&input_ys, None).await?;
-                    return Err(err.into());
-                }
-            }
-        }
+        // After swap request is fully validated, add the new proofs to DB
+        let input_ys = swap_request.inputs.ys()?;
+        self.localstore
+            .add_proofs(swap_request.inputs.clone(), None)
+            .await?;
+        self.check_ys_spendable(&input_ys, State::Pending).await?;
 
         let mut promises = Vec::with_capacity(swap_request.outputs.len());
 
@@ -74,4 +58,24 @@ impl Mint {
 
         Ok(SwapResponse::new(promises))
     }
+
+    async fn validate_sig_flag(&self, swap_request: &SwapRequest) -> Result<(), Error> {
+        let EnforceSigFlag {
+            sig_flag,
+            pubkeys,
+            sigs_required,
+        } = enforce_sig_flag(swap_request.inputs.clone());
+
+        if sig_flag.eq(&SigFlag::SigAll) {
+            let pubkeys = pubkeys.into_iter().collect();
+            for blinded_message in &swap_request.outputs {
+                if let Err(err) = blinded_message.verify_p2pk(&pubkeys, sigs_required) {
+                    tracing::info!("Could not verify p2pk in swap request");
+                    return Err(err.into());
+                }
+            }
+        }
+
+        Ok(())
+    }
 }