Эх сурвалжийг харах

Refactor mint tranactions checks (#585)

* refactor: consolidate validation checks

* refactor: melt verification checks

* refactor: mint verification

* chore: clippy

* chore: use error codes

* fix: order of verifications

* fix: p2pk test ws updates

We only expect the proof to be set to pending once. As a proof without
a signature failes before the spent check where the state is chaged.

* fix: mint_melt regtest frome wait
thesimplekid 2 сар өмнө
parent
commit
a8ec52612b

+ 50 - 4
crates/cdk-common/src/error.rs

@@ -103,11 +103,17 @@ pub enum Error {
     #[error("Inputs: `{0}`, Outputs: `{1}`, Expected Fee: `{2}`")]
     TransactionUnbalanced(u64, u64, u64),
     /// Duplicate proofs provided
-    #[error("Duplicate proofs")]
-    DuplicateProofs,
+    #[error("Duplicate Inputs")]
+    DuplicateInputs,
+    /// Duplicate output
+    #[error("Duplicate outputs")]
+    DuplicateOutputs,
     /// Multiple units provided
     #[error("Cannot have multiple units")]
     MultipleUnits,
+    /// Unit mismatch
+    #[error("Input unit must match output")]
+    UnitMismatch,
     /// Sig all cannot be used in melt
     #[error("Sig all cannot be used in melt")]
     SigAllUsedInMelt,
@@ -393,6 +399,26 @@ impl From<Error> for ErrorResponse {
                 error: Some(err.to_string()),
                 detail: None,
             },
+            Error::DuplicateInputs => ErrorResponse {
+                code: ErrorCode::DuplicateInputs,
+                error: Some(err.to_string()),
+                detail: None,
+            },
+            Error::DuplicateOutputs => ErrorResponse {
+                code: ErrorCode::DuplicateOutputs,
+                error: Some(err.to_string()),
+                detail: None,
+            },
+            Error::MultipleUnits => ErrorResponse {
+                code: ErrorCode::MultipleUnits,
+                error: Some(err.to_string()),
+                detail: None,
+            },
+            Error::UnitMismatch => ErrorResponse {
+                code: ErrorCode::UnitMismatch,
+                error: Some(err.to_string()),
+                detail: None,
+            },
             _ => ErrorResponse {
                 code: ErrorCode::Unknown(9999),
                 error: Some(err.to_string()),
@@ -423,6 +449,10 @@ impl From<ErrorResponse> for Error {
             }
             ErrorCode::TokenPending => Self::TokenPending,
             ErrorCode::WitnessMissingOrInvalid => Self::SignatureMissingOrInvalid,
+            ErrorCode::DuplicateInputs => Self::DuplicateInputs,
+            ErrorCode::DuplicateOutputs => Self::DuplicateOutputs,
+            ErrorCode::MultipleUnits => Self::MultipleUnits,
+            ErrorCode::UnitMismatch => Self::UnitMismatch,
             _ => Self::UnknownErrorResponse(err.to_string()),
         }
     }
@@ -466,6 +496,14 @@ pub enum ErrorCode {
     AmountOutofLimitRange,
     /// Witness missing or invalid
     WitnessMissingOrInvalid,
+    /// Duplicate Inputs
+    DuplicateInputs,
+    /// Duplicate Outputs
+    DuplicateOutputs,
+    /// Multiple Units
+    MultipleUnits,
+    /// Input unit does not match output
+    UnitMismatch,
     /// Unknown error code
     Unknown(u16),
 }
@@ -480,7 +518,11 @@ impl ErrorCode {
             11002 => Self::TransactionUnbalanced,
             11005 => Self::UnsupportedUnit,
             11006 => Self::AmountOutofLimitRange,
-            11007 => Self::TokenPending,
+            11007 => Self::DuplicateInputs,
+            11008 => Self::DuplicateOutputs,
+            11009 => Self::MultipleUnits,
+            11010 => Self::UnitMismatch,
+            11012 => Self::TokenPending,
             12001 => Self::KeysetNotFound,
             12002 => Self::KeysetInactive,
             20000 => Self::LightningError,
@@ -504,7 +546,11 @@ impl ErrorCode {
             Self::TransactionUnbalanced => 11002,
             Self::UnsupportedUnit => 11005,
             Self::AmountOutofLimitRange => 11006,
-            Self::TokenPending => 11007,
+            Self::DuplicateInputs => 11007,
+            Self::DuplicateOutputs => 11008,
+            Self::MultipleUnits => 11009,
+            Self::UnitMismatch => 11010,
+            Self::TokenPending => 11012,
             Self::KeysetNotFound => 12001,
             Self::KeysetInactive => 12002,
             Self::LightningError => 20000,

+ 13 - 0
crates/cdk-integration-tests/src/init_pure_tests.rs

@@ -23,6 +23,7 @@ use cdk::wallet::Wallet;
 use cdk::{Amount, Error, Mint};
 use cdk_fake_wallet::FakeWallet;
 use tokio::sync::Notify;
+use tracing_subscriber::EnvFilter;
 use uuid::Uuid;
 
 use crate::wait_for_mint_to_be_paid;
@@ -143,6 +144,18 @@ impl MintConnector for DirectMintConnection {
 }
 
 pub async fn create_and_start_test_mint() -> anyhow::Result<Arc<Mint>> {
+    let default_filter = "debug";
+
+    let sqlx_filter = "sqlx=warn";
+    let hyper_filter = "hyper=warn";
+
+    let env_filter = EnvFilter::new(format!(
+        "{},{},{}",
+        default_filter, sqlx_filter, hyper_filter
+    ));
+
+    tracing_subscriber::fmt().with_env_filter(env_filter).init();
+
     let mut mint_builder = MintBuilder::new();
 
     let database = MintMemoryDatabase::default();

+ 224 - 5
crates/cdk-integration-tests/tests/fake_wallet.rs

@@ -586,7 +586,7 @@ async fn test_fake_mint_multiple_units() -> Result<()> {
 
     match response {
         Err(err) => match err {
-            cdk::Error::UnsupportedUnit => (),
+            cdk::Error::MultipleUnits => (),
             err => {
                 bail!("Wrong mint error returned: {}", err.to_string());
             }
@@ -652,7 +652,7 @@ async fn test_fake_mint_multiple_unit_swap() -> Result<()> {
 
         match response {
             Err(err) => match err {
-                cdk::Error::UnsupportedUnit => (),
+                cdk::Error::MultipleUnits => (),
                 err => {
                     bail!("Wrong mint error returned: {}", err.to_string());
                 }
@@ -689,7 +689,7 @@ async fn test_fake_mint_multiple_unit_swap() -> Result<()> {
 
         match response {
             Err(err) => match err {
-                cdk::Error::UnsupportedUnit => (),
+                cdk::Error::MultipleUnits => (),
                 err => {
                     bail!("Wrong mint error returned: {}", err.to_string());
                 }
@@ -763,7 +763,7 @@ async fn test_fake_mint_multiple_unit_melt() -> Result<()> {
 
         match response {
             Err(err) => match err {
-                cdk::Error::UnsupportedUnit => (),
+                cdk::Error::MultipleUnits => (),
                 err => {
                     bail!("Wrong mint error returned: {}", err.to_string());
                 }
@@ -807,7 +807,7 @@ async fn test_fake_mint_multiple_unit_melt() -> Result<()> {
 
         match response {
             Err(err) => match err {
-                cdk::Error::UnsupportedUnit => (),
+                cdk::Error::MultipleUnits => (),
                 err => {
                     bail!("Wrong mint error returned: {}", err.to_string());
                 }
@@ -820,3 +820,222 @@ async fn test_fake_mint_multiple_unit_melt() -> Result<()> {
 
     Ok(())
 }
+
+/// Test swap where input unit != output unit
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+async fn test_fake_mint_input_output_mismatch() -> Result<()> {
+    let wallet = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Sat,
+        Arc::new(WalletMemoryDatabase::default()),
+        &Mnemonic::generate(12)?.to_seed_normalized(""),
+        None,
+    )?;
+
+    let mint_quote = wallet.mint_quote(100.into(), None).await?;
+
+    wait_for_mint_to_be_paid(&wallet, &mint_quote.id, 60).await?;
+
+    let proofs = wallet.mint(&mint_quote.id, SplitTarget::None, None).await?;
+
+    let wallet_usd = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Usd,
+        Arc::new(WalletMemoryDatabase::default()),
+        &Mnemonic::generate(12)?.to_seed_normalized(""),
+        None,
+    )?;
+
+    let usd_active_keyset_id = wallet_usd.get_active_mint_keyset().await?.id;
+
+    let inputs = proofs;
+
+    let pre_mint = PreMintSecrets::random(
+        usd_active_keyset_id,
+        inputs.total_amount()?,
+        &SplitTarget::None,
+    )?;
+
+    let swap_request = SwapRequest {
+        inputs,
+        outputs: pre_mint.blinded_messages(),
+    };
+
+    let http_client = HttpClient::new(MINT_URL.parse()?);
+    let response = http_client.post_swap(swap_request.clone()).await;
+
+    match response {
+        Err(err) => match err {
+            cdk::Error::UnsupportedUnit => (),
+            _ => {}
+        },
+        Ok(_) => {
+            bail!("Should not have allowed to mint with multiple units");
+        }
+    }
+
+    Ok(())
+}
+
+/// Test swap where input is less the output
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+async fn test_fake_mint_swap_inflated() -> Result<()> {
+    let wallet = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Sat,
+        Arc::new(WalletMemoryDatabase::default()),
+        &Mnemonic::generate(12)?.to_seed_normalized(""),
+        None,
+    )?;
+
+    let mint_quote = wallet.mint_quote(100.into(), None).await?;
+
+    wait_for_mint_to_be_paid(&wallet, &mint_quote.id, 60).await?;
+
+    let proofs = wallet.mint(&mint_quote.id, SplitTarget::None, None).await?;
+    let active_keyset_id = wallet.get_active_mint_keyset().await?.id;
+    let pre_mint = PreMintSecrets::random(active_keyset_id, 101.into(), &SplitTarget::None)?;
+
+    let swap_request = SwapRequest {
+        inputs: proofs,
+        outputs: pre_mint.blinded_messages(),
+    };
+
+    let http_client = HttpClient::new(MINT_URL.parse()?);
+    let response = http_client.post_swap(swap_request.clone()).await;
+
+    match response {
+        Err(err) => match err {
+            cdk::Error::TransactionUnbalanced(_, _, _) => (),
+            err => {
+                bail!("Wrong mint error returned: {}", err.to_string());
+            }
+        },
+        Ok(_) => {
+            bail!("Should not have allowed to mint with multiple units");
+        }
+    }
+
+    Ok(())
+}
+
+/// Test swap where input unit != output unit
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+async fn test_fake_mint_duplicate_proofs_swap() -> Result<()> {
+    let wallet = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Sat,
+        Arc::new(WalletMemoryDatabase::default()),
+        &Mnemonic::generate(12)?.to_seed_normalized(""),
+        None,
+    )?;
+
+    let mint_quote = wallet.mint_quote(100.into(), None).await?;
+
+    wait_for_mint_to_be_paid(&wallet, &mint_quote.id, 60).await?;
+
+    let proofs = wallet.mint(&mint_quote.id, SplitTarget::None, None).await?;
+
+    let active_keyset_id = wallet.get_active_mint_keyset().await?.id;
+
+    let inputs = vec![proofs[0].clone(), proofs[0].clone()];
+
+    let pre_mint =
+        PreMintSecrets::random(active_keyset_id, inputs.total_amount()?, &SplitTarget::None)?;
+
+    let swap_request = SwapRequest {
+        inputs: inputs.clone(),
+        outputs: pre_mint.blinded_messages(),
+    };
+
+    let http_client = HttpClient::new(MINT_URL.parse()?);
+    let response = http_client.post_swap(swap_request.clone()).await;
+
+    match response {
+        Err(err) => match err {
+            cdk::Error::DuplicateInputs => (),
+            err => {
+                bail!(
+                    "Wrong mint error returned, expected duplicate inputs: {}",
+                    err.to_string()
+                );
+            }
+        },
+        Ok(_) => {
+            bail!("Should not have allowed duplicate inputs");
+        }
+    }
+
+    let blinded_message = pre_mint.blinded_messages();
+
+    let outputs = vec![blinded_message[0].clone(), blinded_message[0].clone()];
+
+    let swap_request = SwapRequest { inputs, outputs };
+
+    let http_client = HttpClient::new(MINT_URL.parse()?);
+    let response = http_client.post_swap(swap_request.clone()).await;
+
+    match response {
+        Err(err) => match err {
+            cdk::Error::DuplicateOutputs => (),
+            err => {
+                bail!(
+                    "Wrong mint error returned, expected duplicate outputs: {}",
+                    err.to_string()
+                );
+            }
+        },
+        Ok(_) => {
+            bail!("Should not have allow duplicate inputs");
+        }
+    }
+
+    Ok(())
+}
+
+/// Test duplicate proofs in melt
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+async fn test_fake_mint_duplicate_proofs_melt() -> Result<()> {
+    let wallet = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Sat,
+        Arc::new(WalletMemoryDatabase::default()),
+        &Mnemonic::generate(12)?.to_seed_normalized(""),
+        None,
+    )?;
+
+    let mint_quote = wallet.mint_quote(100.into(), None).await?;
+
+    wait_for_mint_to_be_paid(&wallet, &mint_quote.id, 60).await?;
+
+    let proofs = wallet.mint(&mint_quote.id, SplitTarget::None, None).await?;
+
+    let inputs = vec![proofs[0].clone(), proofs[0].clone()];
+
+    let invoice = create_fake_invoice(7000, "".to_string());
+
+    let melt_quote = wallet.melt_quote(invoice.to_string(), None).await?;
+
+    let melt_request = MeltBolt11Request {
+        quote: melt_quote.id,
+        inputs,
+        outputs: None,
+    };
+
+    let http_client = HttpClient::new(MINT_URL.parse()?);
+    let response = http_client.post_melt(melt_request.clone()).await;
+
+    match response {
+        Err(err) => match err {
+            cdk::Error::DuplicateInputs => (),
+            err => {
+                bail!("Wrong mint error returned: {}", err.to_string());
+            }
+        },
+        Ok(_) => {
+            bail!("Should not have allow duplicate inputs");
+        }
+    }
+
+    Ok(())
+}

+ 3 - 1
crates/cdk-integration-tests/tests/mint.rs

@@ -169,6 +169,8 @@ async fn test_attempt_to_swap_by_overflowing() -> Result<()> {
         Ok(_) => bail!("Swap occurred with overflow"),
         Err(err) => match err {
             cdk::Error::NUT03(cdk::nuts::nut03::Error::Amount(_)) => (),
+            cdk::Error::AmountOverflow => (),
+            cdk::Error::AmountError(_) => (),
             _ => {
                 println!("{:?}", err);
                 bail!("Wrong error returned in swap overflow")
@@ -288,7 +290,7 @@ pub async fn test_p2pk_swap() -> Result<()> {
 
     for keys in public_keys_to_listen {
         let statuses = msgs.remove(&keys).expect("some events");
-        assert_eq!(statuses, vec![State::Pending, State::Pending, State::Spent]);
+        assert_eq!(statuses, vec![State::Pending, State::Spent]);
     }
 
     assert!(listener.try_recv().is_err(), "no other event is happening");

+ 0 - 2
crates/cdk-integration-tests/tests/regtest.rs

@@ -96,8 +96,6 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> {
 
     lnd_client.pay_invoice(mint_quote.request).await.unwrap();
 
-    wait_for_mint_to_be_paid(&wallet, &mint_quote.id, 60).await?;
-
     let proofs = wallet
         .mint(&mint_quote.id, SplitTarget::default(), None)
         .await?;

+ 1 - 0
crates/cdk/src/mint/check_spendable.rs

@@ -57,6 +57,7 @@ impl Mint {
         }
 
         for public_key in ys {
+            tracing::debug!("proof: {} set to {}", public_key.to_hex(), proof_state);
             self.pubsub_manager.proof_state((*public_key, proof_state));
         }
 

+ 21 - 67
crates/cdk/src/mint/melt.rs

@@ -1,4 +1,3 @@
-use std::collections::HashSet;
 use std::str::FromStr;
 
 use anyhow::bail;
@@ -14,9 +13,10 @@ use super::{
 };
 use crate::amount::to_unit;
 use crate::cdk_lightning::{MintLightning, PayInvoiceResponse};
+use crate::mint::verification::Verification;
 use crate::mint::SigFlag;
 use crate::nuts::nut11::{enforce_sig_flag, EnforceSigFlag};
-use crate::nuts::{Id, MeltQuoteState};
+use crate::nuts::MeltQuoteState;
 use crate::types::LnKey;
 use crate::util::unix_time;
 use crate::{cdk_lightning, Amount, Error};
@@ -231,15 +231,6 @@ impl Mint {
                 .msat_to_pay
                 .ok_or(Error::InvoiceAmountUndefined)?,
         };
-        /*
-        let invoice_amount_msats: Amount = match melt_quote.msat_to_pay {
-            Some(amount) => amount,
-            None => invoice
-                .amount_milli_satoshis()
-                .ok_or(Error::InvoiceAmountUndefined)?
-                .into(),
-        };
-        */
 
         let partial_amount = match invoice_amount_msats > quote_msats {
             true => {
@@ -298,29 +289,18 @@ impl Mint {
             MeltQuoteState::Unknown => Err(Error::UnknownPaymentState),
         }?;
 
-        let ys = melt_request.inputs.ys()?;
-
-        // Ensure proofs are unique and not being double spent
-        if melt_request.inputs.len() != ys.iter().collect::<HashSet<_>>().len() {
-            return Err(Error::DuplicateProofs);
-        }
-
-        self.localstore
-            .add_proofs(melt_request.inputs.clone(), Some(melt_request.quote))
-            .await?;
-        self.check_ys_spendable(&ys, State::Pending).await?;
-
-        for proof in &melt_request.inputs {
-            self.verify_proof(proof).await?;
-        }
-
         let quote = self
             .localstore
             .get_melt_quote(&melt_request.quote)
             .await?
             .ok_or(Error::UnknownQuote)?;
 
-        let proofs_total = melt_request.proofs_amount()?;
+        let Verification {
+            amount: input_amount,
+            unit: input_unit,
+        } = self.verify_inputs(&melt_request.inputs).await?;
+
+        let input_ys = melt_request.inputs.ys()?;
 
         let fee = self.get_proofs_fee(&melt_request.inputs).await?;
 
@@ -328,33 +308,25 @@ impl Mint {
 
         // Check that the inputs proofs are greater then total.
         // Transaction does not need to be balanced as wallet may not want change.
-        if proofs_total < required_total {
+        if input_amount < required_total {
             tracing::info!(
                 "Swap request unbalanced: {}, outputs {}, fee {}",
-                proofs_total,
+                input_amount,
                 quote.amount,
                 fee
             );
             return Err(Error::TransactionUnbalanced(
-                proofs_total.into(),
+                input_amount.into(),
                 quote.amount.into(),
                 (fee + quote.fee_reserve).into(),
             ));
         }
 
-        let input_keyset_ids: HashSet<Id> =
-            melt_request.inputs.iter().map(|p| p.keyset_id).collect();
-
-        let mut keyset_units = HashSet::with_capacity(input_keyset_ids.capacity());
+        self.localstore
+            .add_proofs(melt_request.inputs.clone(), None)
+            .await?;
 
-        for id in input_keyset_ids {
-            let keyset = self
-                .localstore
-                .get_keyset_info(&id)
-                .await?
-                .ok_or(Error::UnknownKeySet)?;
-            keyset_units.insert(keyset.unit);
-        }
+        self.check_ys_spendable(&input_ys, State::Pending).await?;
 
         let EnforceSigFlag { sig_flag, .. } = enforce_sig_flag(melt_request.inputs.clone());
 
@@ -363,34 +335,16 @@ impl Mint {
         }
 
         if let Some(outputs) = &melt_request.outputs {
-            let output_keysets_ids: HashSet<Id> = outputs.iter().map(|b| b.keyset_id).collect();
-            for id in output_keysets_ids {
-                let keyset = self
-                    .localstore
-                    .get_keyset_info(&id)
-                    .await?
-                    .ok_or(Error::UnknownKeySet)?;
+            let Verification {
+                amount: _,
+                unit: output_unit,
+            } = self.verify_outputs(outputs).await?;
 
-                // Get the active keyset for the unit
-                let active_keyset_id = self
-                    .localstore
-                    .get_active_keyset_id(&keyset.unit)
-                    .await?
-                    .ok_or(Error::InactiveKeyset)?;
-
-                // Check output is for current active keyset
-                if id.ne(&active_keyset_id) {
-                    return Err(Error::InactiveKeyset);
-                }
-                keyset_units.insert(keyset.unit);
+            if input_unit != output_unit {
+                return Err(Error::UnsupportedUnit);
             }
         }
 
-        // Check that all input and output proofs are the same unit
-        if keyset_units.len().gt(&1) {
-            return Err(Error::UnsupportedUnit);
-        }
-
         tracing::debug!("Verified melt quote: {}", melt_request.quote);
         Ok(quote)
     }

+ 15 - 49
crates/cdk/src/mint/mint_nut04.rs

@@ -1,9 +1,7 @@
-use std::collections::HashSet;
-
-use cdk_common::Id;
 use tracing::instrument;
 use uuid::Uuid;
 
+use super::verification::Verification;
 use super::{
     nut04, CurrencyUnit, Mint, MintQuote, MintQuoteBolt11Request, MintQuoteBolt11Response,
     NotificationPayload, PaymentMethod, PublicKey,
@@ -303,8 +301,20 @@ impl Mint {
             mint_request.verify_signature(pubkey)?;
         }
 
+        let Verification { amount, unit } = match self.verify_outputs(&mint_request.outputs).await {
+            Ok(verification) => verification,
+            Err(err) => {
+                tracing::debug!("Could not verify mint outputs");
+                self.localstore
+                    .update_mint_quote_state(&mint_request.quote, MintQuoteState::Paid)
+                    .await?;
+
+                return Err(err);
+            }
+        };
+
         // We check the the total value of blinded messages == mint quote
-        if mint_request.total_amount()? != mint_quote.amount {
+        if amount != mint_quote.amount {
             return Err(Error::TransactionUnbalanced(
                 mint_quote.amount.into(),
                 mint_request.total_amount()?.into(),
@@ -312,54 +322,10 @@ impl Mint {
             ));
         }
 
-        let keyset_ids: HashSet<Id> = mint_request.outputs.iter().map(|b| b.keyset_id).collect();
-
-        let mut keyset_units = HashSet::new();
-
-        for keyset_id in keyset_ids {
-            let keyset = self.keyset(&keyset_id).await?.ok_or(Error::UnknownKeySet)?;
-
-            keyset_units.insert(keyset.unit);
-        }
-
-        if keyset_units.len() != 1 {
-            tracing::debug!("Client attempted to mint with outputs of multiple units");
+        if unit != mint_quote.unit {
             return Err(Error::UnsupportedUnit);
         }
 
-        if keyset_units.iter().next().expect("Checked len above") != &mint_quote.unit {
-            tracing::debug!("Client attempted to mint with unit not in quote");
-            return Err(Error::UnsupportedUnit);
-        }
-
-        let blinded_messages: Vec<PublicKey> = mint_request
-            .outputs
-            .iter()
-            .map(|b| b.blinded_secret)
-            .collect();
-
-        if self
-            .localstore
-            .get_blind_signatures(&blinded_messages)
-            .await?
-            .iter()
-            .flatten()
-            .next()
-            .is_some()
-        {
-            tracing::info!("Output has already been signed",);
-            tracing::info!(
-                "Mint {} did not succeed returning quote to Paid state",
-                mint_request.quote
-            );
-
-            self.localstore
-                .update_mint_quote_state(&mint_request.quote, MintQuoteState::Paid)
-                .await?;
-
-            return Err(Error::BlindedMessageAlreadySigned);
-        }
-
         let mut blind_signatures = Vec::with_capacity(mint_request.outputs.len());
 
         for blinded_message in mint_request.outputs.iter() {

+ 1 - 0
crates/cdk/src/mint/mod.rs

@@ -32,6 +32,7 @@ mod mint_nut04;
 mod start_up_check;
 pub mod subscription;
 mod swap;
+mod verification;
 
 pub use builder::{MintBuilder, MintMeltLimits};
 pub use cdk_common::mint::{MeltQuote, MintQuote};

+ 10 - 106
crates/cdk/src/mint/swap.rs

@@ -1,9 +1,7 @@
-use std::collections::HashSet;
-
 use tracing::instrument;
 
 use super::nut11::{enforce_sig_flag, EnforceSigFlag};
-use super::{Id, Mint, PublicKey, SigFlag, State, SwapRequest, SwapResponse};
+use super::{Mint, PublicKey, SigFlag, State, SwapRequest, SwapResponse};
 use crate::nuts::nut00::ProofsMethods;
 use crate::Error;
 
@@ -14,116 +12,22 @@ impl Mint {
         &self,
         swap_request: SwapRequest,
     ) -> Result<SwapResponse, Error> {
-        let blinded_messages: Vec<PublicKey> = swap_request
-            .outputs
-            .iter()
-            .map(|b| b.blinded_secret)
-            .collect();
+        let input_ys = swap_request.inputs.ys()?;
 
-        if self
-            .localstore
-            .get_blind_signatures(&blinded_messages)
-            .await?
-            .iter()
-            .flatten()
-            .next()
-            .is_some()
+        if let Err(err) = self
+            .verify_transaction_balanced(&swap_request.inputs, &swap_request.outputs)
+            .await
         {
-            tracing::info!("Output has already been signed",);
-
-            return Err(Error::BlindedMessageAlreadySigned);
-        }
-
-        let proofs_total = swap_request.input_amount()?;
-
-        let output_total = swap_request.output_amount()?;
-
-        let fee = self.get_proofs_fee(&swap_request.inputs).await?;
-
-        let total_with_fee = output_total.checked_add(fee).ok_or(Error::AmountOverflow)?;
-
-        if proofs_total != total_with_fee {
-            tracing::info!(
-                "Swap request unbalanced: {}, outputs {}, fee {}",
-                proofs_total,
-                output_total,
-                fee
-            );
-            return Err(Error::TransactionUnbalanced(
-                proofs_total.into(),
-                output_total.into(),
-                fee.into(),
-            ));
-        }
-
-        let proof_count = swap_request.inputs.len();
-
-        let input_ys = swap_request.inputs.ys()?;
+            tracing::debug!("Attempt to swap unbalanced transaction: {}", err);
+            self.localstore.remove_proofs(&input_ys, None).await?;
+            return Err(err);
+        };
 
         self.localstore
             .add_proofs(swap_request.inputs.clone(), None)
             .await?;
-        self.check_ys_spendable(&input_ys, State::Pending).await?;
-
-        // Check that there are no duplicate proofs in request
-        if input_ys
-            .iter()
-            .collect::<HashSet<&PublicKey>>()
-            .len()
-            .ne(&proof_count)
-        {
-            self.localstore.remove_proofs(&input_ys, None).await?;
-            return Err(Error::DuplicateProofs);
-        }
-
-        for proof in &swap_request.inputs {
-            if let Err(err) = self.verify_proof(proof).await {
-                tracing::info!("Error verifying proof in swap");
-                self.localstore.remove_proofs(&input_ys, None).await?;
-                return Err(err);
-            }
-        }
-
-        let input_keyset_ids: HashSet<Id> =
-            swap_request.inputs.iter().map(|p| p.keyset_id).collect();
-
-        let mut keyset_units = HashSet::with_capacity(input_keyset_ids.capacity());
-
-        for id in input_keyset_ids {
-            match self.localstore.get_keyset_info(&id).await? {
-                Some(keyset) => {
-                    keyset_units.insert(keyset.unit);
-                }
-                None => {
-                    tracing::info!("Swap request with unknown keyset in inputs");
-                    self.localstore.remove_proofs(&input_ys, None).await?;
-                }
-            }
-        }
 
-        let output_keyset_ids: HashSet<Id> =
-            swap_request.outputs.iter().map(|p| p.keyset_id).collect();
-
-        for id in &output_keyset_ids {
-            match self.localstore.get_keyset_info(id).await? {
-                Some(keyset) => {
-                    keyset_units.insert(keyset.unit);
-                }
-                None => {
-                    tracing::info!("Swap request with unknown keyset in outputs");
-                    self.localstore.remove_proofs(&input_ys, None).await?;
-                }
-            }
-        }
-
-        // Check that all proofs are the same unit
-        // in the future it maybe possible to support multiple units but unsupported for
-        // now
-        if keyset_units.len().gt(&1) {
-            tracing::error!("Only one unit is allowed in request: {:?}", keyset_units);
-            self.localstore.remove_proofs(&input_ys, None).await?;
-            return Err(Error::UnsupportedUnit);
-        }
+        self.check_ys_spendable(&input_ys, State::Pending).await?;
 
         let EnforceSigFlag {
             sig_flag,

+ 215 - 0
crates/cdk/src/mint/verification.rs

@@ -0,0 +1,215 @@
+use std::collections::HashSet;
+
+use cdk_common::{Amount, BlindedMessage, CurrencyUnit, Id, Proofs, ProofsMethods, PublicKey};
+
+use super::{Error, Mint};
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+pub struct Verification {
+    pub amount: Amount,
+    pub unit: CurrencyUnit,
+}
+
+impl Mint {
+    /// Verify that the inputs to the transaction are unique
+    pub fn check_inputs_unique(inputs: &Proofs) -> Result<(), Error> {
+        let proof_count = inputs.len();
+
+        if inputs
+            .iter()
+            .map(|i| i.y())
+            .collect::<Result<HashSet<PublicKey>, _>>()?
+            .len()
+            .ne(&proof_count)
+        {
+            return Err(Error::DuplicateInputs);
+        }
+
+        Ok(())
+    }
+
+    /// Verify that the outputs to are unique
+    pub fn check_outputs_unique(outputs: &[BlindedMessage]) -> Result<(), Error> {
+        let output_count = outputs.len();
+
+        if outputs
+            .iter()
+            .map(|o| &o.blinded_secret)
+            .collect::<HashSet<&PublicKey>>()
+            .len()
+            .ne(&output_count)
+        {
+            return Err(Error::DuplicateOutputs);
+        }
+
+        Ok(())
+    }
+
+    /// Verify output keyset
+    ///
+    /// Checks that the outputs are all of the same unit and the keyset is active
+    pub async fn verify_outputs_keyset(
+        &self,
+        outputs: &[BlindedMessage],
+    ) -> Result<CurrencyUnit, Error> {
+        let mut keyset_units = HashSet::new();
+
+        let output_keyset_ids: HashSet<Id> = outputs.iter().map(|p| p.keyset_id).collect();
+
+        for id in &output_keyset_ids {
+            match self.localstore.get_keyset_info(id).await? {
+                Some(keyset) => {
+                    if !keyset.active {
+                        return Err(Error::InactiveKeyset);
+                    }
+                    keyset_units.insert(keyset.unit);
+                }
+                None => {
+                    tracing::info!("Swap request with unknown keyset in outputs");
+                    return Err(Error::UnknownKeySet);
+                }
+            }
+        }
+
+        // Check that all proofs are the same unit
+        // in the future it maybe possible to support multiple units but unsupported for
+        // now
+        if keyset_units.len() != 1 {
+            tracing::error!("Only one unit is allowed in request: {:?}", keyset_units);
+            return Err(Error::MultipleUnits);
+        }
+
+        Ok(keyset_units
+            .into_iter()
+            .next()
+            .expect("Length is check above"))
+    }
+
+    /// Verify input keyset
+    ///
+    /// Checks that the inputs are all of the same unit
+    pub async fn verify_inputs_keyset(&self, inputs: &Proofs) -> Result<CurrencyUnit, Error> {
+        let mut keyset_units = HashSet::new();
+
+        let inputs_keyset_ids: HashSet<Id> = inputs.iter().map(|p| p.keyset_id).collect();
+
+        for id in &inputs_keyset_ids {
+            match self.localstore.get_keyset_info(id).await? {
+                Some(keyset) => {
+                    keyset_units.insert(keyset.unit);
+                }
+                None => {
+                    tracing::info!("Swap request with unknown keyset in outputs");
+                    return Err(Error::UnknownKeySet);
+                }
+            }
+        }
+
+        // Check that all proofs are the same unit
+        // in the future it maybe possible to support multiple units but unsupported for
+        // now
+        if keyset_units.len() != 1 {
+            tracing::error!("Only one unit is allowed in request: {:?}", keyset_units);
+            return Err(Error::MultipleUnits);
+        }
+
+        Ok(keyset_units
+            .into_iter()
+            .next()
+            .expect("Length is check above"))
+    }
+
+    /// Verifies that the outputs have not already been signed
+    pub async fn check_output_already_signed(
+        &self,
+        outputs: &[BlindedMessage],
+    ) -> Result<(), Error> {
+        let blinded_messages: Vec<PublicKey> = outputs.iter().map(|o| o.blinded_secret).collect();
+
+        if self
+            .localstore
+            .get_blind_signatures(&blinded_messages)
+            .await?
+            .iter()
+            .flatten()
+            .next()
+            .is_some()
+        {
+            tracing::info!("Output has already been signed",);
+
+            return Err(Error::BlindedMessageAlreadySigned);
+        }
+
+        Ok(())
+    }
+
+    /// Verifies outputs
+    /// Checks outputs are unique, of the same unit and not signed before
+    pub async fn verify_outputs(&self, outputs: &[BlindedMessage]) -> Result<Verification, Error> {
+        Mint::check_outputs_unique(outputs)?;
+        self.check_output_already_signed(outputs).await?;
+
+        let unit = self.verify_outputs_keyset(outputs).await?;
+
+        let amount = Amount::try_sum(outputs.iter().map(|o| o.amount).collect::<Vec<Amount>>())?;
+
+        Ok(Verification { amount, unit })
+    }
+
+    /// Verifies inputs
+    /// Checks that inputs are unique and of the same unit
+    /// **NOTE: This does not check if inputs have been spent
+    pub async fn verify_inputs(&self, inputs: &Proofs) -> Result<Verification, Error> {
+        Mint::check_inputs_unique(inputs)?;
+        let unit = self.verify_inputs_keyset(inputs).await?;
+        let amount = inputs.total_amount()?;
+
+        for proof in inputs {
+            self.verify_proof(proof).await?;
+        }
+
+        Ok(Verification { amount, unit })
+    }
+
+    /// Verify that inputs and outputs are valid and balanced
+    pub async fn verify_transaction_balanced(
+        &self,
+        inputs: &Proofs,
+        outputs: &[BlindedMessage],
+    ) -> Result<(), Error> {
+        let output_verification = self.verify_outputs(outputs).await.map_err(|err| {
+            tracing::debug!("Output verification failed: {:?}", err);
+            err
+        })?;
+        let input_verification = self.verify_inputs(inputs).await.map_err(|err| {
+            tracing::debug!("Input verification failed: {:?}", err);
+            err
+        })?;
+
+        if output_verification.unit != input_verification.unit {
+            tracing::debug!(
+                "Output unit {} does not match input unit {}",
+                output_verification.unit,
+                input_verification.unit
+            );
+            return Err(Error::MultipleUnits);
+        }
+
+        let fees = self.get_proofs_fee(inputs).await?;
+
+        if output_verification.amount
+            != input_verification
+                .amount
+                .checked_sub(fees)
+                .ok_or(Error::AmountOverflow)?
+        {
+            return Err(Error::TransactionUnbalanced(
+                input_verification.amount.into(),
+                output_verification.amount.into(),
+                fees.into(),
+            ));
+        }
+
+        Ok(())
+    }
+}

+ 1 - 2
flake.nix

@@ -67,9 +67,8 @@
 
         # Nightly used for formatting
         nightly_toolchain = pkgs.rust-bin.selectLatestNightlyWith (toolchain: toolchain.default.override {
+          extensions = [ "rustfmt" "clippy" "rust-analyzer" "rust-src" ];
           targets = [ "wasm32-unknown-unknown" ]; # wasm
-          extensions = [ "rustfmt" "clippy" "rust-src" "rust-analyzer" ];
-
         });
 
         # Common inputs