Browse Source

refactor: calc fee

thesimplekid 7 months ago
parent
commit
9b78305722

+ 2 - 2
crates/cdk-axum/src/router_handlers.rs

@@ -317,7 +317,7 @@ pub async fn post_melt_bolt11(
                         .map_err(|_| into_response(Error::UnitUnsupported))?,
                 };
 
-                if amount_to_pay + quote.fee_reserve > inputs_amount_quote_unit {
+                if amount_to_pay + quote.fee_reserve != inputs_amount_quote_unit {
                     tracing::debug!(
                         "Not enough inuts provided: {} msats needed {} msats",
                         inputs_amount_quote_unit,
@@ -327,7 +327,7 @@ pub async fn post_melt_bolt11(
                     if let Err(err) = state.mint.process_unpaid_melt(&payload).await {
                         tracing::error!("Could not reset melt quote state: {}", err);
                     }
-                    return Err(into_response(Error::InsufficientInputs(
+                    return Err(into_response(Error::TransactionUnbalanced(
                         inputs_amount_quote_unit.into(),
                         amount_to_pay.into(),
                         quote.fee_reserve.into(),

+ 8 - 5
crates/cdk-integration-tests/src/lib.rs

@@ -53,6 +53,7 @@ pub async fn start_mint(
         LnKey,
         Arc<dyn MintLightning<Err = cdk::cdk_lightning::Error> + Sync + Send>,
     >,
+    supported_units: HashMap<CurrencyUnit, (u64, u8)>,
 ) -> Result<()> {
     let nuts = Nuts::new()
         .nut07(true)
@@ -67,9 +68,6 @@ pub async fn start_mint(
 
     let mnemonic = Mnemonic::generate(12)?;
 
-    let mut supported_units = HashMap::new();
-    supported_units.insert(CurrencyUnit::Sat, (0, 64));
-
     let mint = Mint::new(
         MINT_URL,
         &mnemonic.to_seed_normalized(""),
@@ -150,7 +148,11 @@ async fn handle_paid_invoice(mint: Arc<Mint>, request_lookup_id: &str) -> Result
     Ok(())
 }
 
-pub async fn wallet_mint(wallet: Arc<Wallet>, amount: Amount) -> Result<()> {
+pub async fn wallet_mint(
+    wallet: Arc<Wallet>,
+    amount: Amount,
+    split_target: SplitTarget,
+) -> Result<()> {
     let quote = wallet.mint_quote(amount).await?;
 
     loop {
@@ -163,7 +165,8 @@ pub async fn wallet_mint(wallet: Arc<Wallet>, amount: Amount) -> Result<()> {
 
         sleep(Duration::from_secs(2)).await;
     }
-    let receive_amount = wallet.mint(&quote.id, SplitTarget::default(), None).await?;
+
+    let receive_amount = wallet.mint(&quote.id, split_target, None).await?;
 
     println!("Minted: {}", receive_amount);
 

+ 106 - 0
crates/cdk-integration-tests/tests/fees.rs

@@ -0,0 +1,106 @@
+//! Test calc fee
+
+use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::Duration;
+
+use anyhow::Result;
+use bip39::Mnemonic;
+use cdk::amount::SplitTarget;
+use cdk::cdk_database::WalletMemoryDatabase;
+use cdk::mint_url::MintUrl;
+use cdk::nuts::CurrencyUnit;
+use cdk::Wallet;
+use cdk_integration_tests::{create_backends_fake_wallet, start_mint, wallet_mint, MINT_URL};
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+pub async fn test_mint_fee() -> Result<()> {
+    tokio::spawn(async move {
+        let ln_backends = create_backends_fake_wallet();
+
+        let mut supported_units = HashMap::new();
+        supported_units.insert(CurrencyUnit::Sat, (1, 32));
+
+        start_mint(ln_backends, supported_units)
+            .await
+            .expect("Could not start mint")
+    });
+
+    tokio::time::sleep(Duration::from_millis(500)).await;
+
+    let mnemonic = Mnemonic::generate(12)?;
+
+    let wallet = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Sat,
+        Arc::new(WalletMemoryDatabase::default()),
+        &mnemonic.to_seed_normalized(""),
+        None,
+    )?;
+
+    let wallet = Arc::new(wallet);
+
+    wallet_mint(
+        Arc::clone(&wallet),
+        10000.into(),
+        SplitTarget::Value(1.into()),
+    )
+    .await
+    .unwrap();
+    println!("Minted");
+
+    let proofs = wallet
+        .localstore
+        .get_proofs(Some(MintUrl::from_str(MINT_URL)?), None, None, None)
+        .await?;
+
+    let proofs: Vec<cdk::nuts::Proof> = proofs.into_iter().map(|p| p.proof).collect();
+
+    let five_proofs = proofs[..5].to_vec();
+
+    let fee = wallet.get_proofs_fee(&five_proofs).await?;
+
+    // Check wallet gets fee calc correct
+    assert_eq!(fee, 1.into());
+
+    let _swap = wallet
+        .swap(None, SplitTarget::Value(1.into()), five_proofs, None, false)
+        .await?;
+
+    let wallet_bal = wallet.total_balance().await?;
+
+    // Check 1 sat was paid in fees for the swap
+    assert_eq!(wallet_bal, 9999.into());
+
+    let proofs = wallet
+        .localstore
+        .get_proofs(Some(MintUrl::from_str(MINT_URL)?), None, None, None)
+        .await?;
+
+    let proofs: Vec<cdk::nuts::Proof> = proofs.into_iter().map(|p| p.proof).collect();
+
+    let thousand_proofs = proofs[..1001].to_vec();
+
+    let fee = wallet.get_proofs_fee(&thousand_proofs).await?;
+
+    // Check wallet gets fee calc correct
+    assert_eq!(fee, 2.into());
+
+    let _swap = wallet
+        .swap(
+            None,
+            SplitTarget::Value(1.into()),
+            thousand_proofs,
+            None,
+            false,
+        )
+        .await?;
+
+    let wallet_bal = wallet.total_balance().await?;
+
+    // Check 1 sat was paid in fees for the swap
+    assert_eq!(wallet_bal, 9997.into());
+
+    Ok(())
+}

+ 9 - 4
crates/cdk-integration-tests/tests/mint.rs

@@ -1,5 +1,6 @@
 //! Mint integration tests
 
+use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -18,7 +19,12 @@ pub async fn test_mint_double_receive() -> Result<()> {
     tokio::spawn(async move {
         let ln_backends = create_backends_fake_wallet();
 
-        start_mint(ln_backends).await.expect("Could not start mint")
+        let mut supported_units = HashMap::new();
+        supported_units.insert(CurrencyUnit::Sat, (0, 64));
+
+        start_mint(ln_backends, supported_units)
+            .await
+            .expect("Could not start mint")
     });
 
     tokio::time::sleep(Duration::from_millis(500)).await;
@@ -35,7 +41,7 @@ pub async fn test_mint_double_receive() -> Result<()> {
 
     let wallet = Arc::new(wallet);
 
-    wallet_mint(Arc::clone(&wallet), 100.into()).await.unwrap();
+    wallet_mint(Arc::clone(&wallet), 100.into(), SplitTarget::default()).await?;
     println!("Minted");
 
     let token = wallet
@@ -47,8 +53,7 @@ pub async fn test_mint_double_receive() -> Result<()> {
             &SendKind::default(),
             false,
         )
-        .await
-        .unwrap();
+        .await?;
 
     let mnemonic = Mnemonic::generate(12)?;
 

+ 8 - 2
crates/cdk-integration-tests/tests/overflow.rs

@@ -1,9 +1,10 @@
+use std::collections::HashMap;
 use std::time::Duration;
 
 use anyhow::{bail, Result};
 use cdk::amount::SplitTarget;
 use cdk::dhke::construct_proofs;
-use cdk::nuts::{PreMintSecrets, SwapRequest};
+use cdk::nuts::{CurrencyUnit, PreMintSecrets, SwapRequest};
 use cdk::Amount;
 use cdk::HttpClient;
 use cdk_integration_tests::{create_backends_fake_wallet, mint_proofs, start_mint, MINT_URL};
@@ -98,7 +99,12 @@ pub async fn test_overflow() -> Result<()> {
     tokio::spawn(async move {
         let ln_backends = create_backends_fake_wallet();
 
-        start_mint(ln_backends).await.expect("Could not start mint")
+        let mut supported_units = HashMap::new();
+        supported_units.insert(CurrencyUnit::Sat, (0, 32));
+
+        start_mint(ln_backends, supported_units)
+            .await
+            .expect("Could not start mint")
     });
 
     // Wait for mint server to start

+ 8 - 2
crates/cdk-integration-tests/tests/p2pk.rs

@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -15,7 +16,12 @@ pub async fn test_p2pk_swap() -> Result<()> {
     tokio::spawn(async move {
         let ln_backends = create_backends_fake_wallet();
 
-        start_mint(ln_backends).await.expect("Could not start mint")
+        let mut supported_units = HashMap::new();
+        supported_units.insert(CurrencyUnit::Sat, (0, 32));
+
+        start_mint(ln_backends, supported_units)
+            .await
+            .expect("Could not start mint")
     });
     tokio::time::sleep(Duration::from_millis(500)).await;
 
@@ -32,7 +38,7 @@ pub async fn test_p2pk_swap() -> Result<()> {
     let wallet = Arc::new(wallet);
 
     // Mint 100 sats for the wallet
-    wallet_mint(Arc::clone(&wallet), 100.into()).await?;
+    wallet_mint(Arc::clone(&wallet), 100.into(), SplitTarget::default()).await?;
 
     let secret = SecretKey::generate();
 

+ 8 - 2
crates/cdk-integration-tests/tests/unbalanced.rs

@@ -1,10 +1,11 @@
 //! Test that if a wallet attempts to swap for less outputs then inputs correct error is returned
 
+use std::collections::HashMap;
 use std::time::Duration;
 
 use anyhow::{bail, Result};
 use cdk::amount::SplitTarget;
-use cdk::nuts::{PreMintSecrets, SwapRequest};
+use cdk::nuts::{CurrencyUnit, PreMintSecrets, SwapRequest};
 use cdk::Error;
 use cdk::HttpClient;
 use cdk_integration_tests::{create_backends_fake_wallet, mint_proofs, start_mint, MINT_URL};
@@ -14,7 +15,12 @@ pub async fn test_unbalanced_swap() -> Result<()> {
     tokio::spawn(async move {
         let ln_backends = create_backends_fake_wallet();
 
-        start_mint(ln_backends).await.expect("Could not start mint")
+        let mut supported_units = HashMap::new();
+        supported_units.insert(CurrencyUnit::Sat, (0, 32));
+
+        start_mint(ln_backends, supported_units)
+            .await
+            .expect("Could not start mint")
     });
 
     // Wait for mint server to start

+ 113 - 0
crates/cdk-integration-tests/tests/wrong_fee.rs

@@ -0,0 +1,113 @@
+//! Fee tests for over and underpaying
+
+use std::collections::HashMap;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::Duration;
+
+use anyhow::{bail, Result};
+use bip39::Mnemonic;
+use cdk::cdk_database::WalletMemoryDatabase;
+use cdk::mint_url::MintUrl;
+use cdk::nuts::{CurrencyUnit, SwapRequest};
+use cdk::wallet::client::HttpClient;
+use cdk::Wallet;
+use cdk::{amount::SplitTarget, nuts::PreMintSecrets};
+use cdk_integration_tests::{create_backends_fake_wallet, start_mint, wallet_mint, MINT_URL};
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
+pub async fn test_swap_overpay_underpay() -> Result<()> {
+    tokio::spawn(async move {
+        let ln_backends = create_backends_fake_wallet();
+
+        let mut supported_units = HashMap::new();
+        supported_units.insert(CurrencyUnit::Sat, (1, 32));
+
+        start_mint(ln_backends, supported_units)
+            .await
+            .expect("Could not start mint")
+    });
+
+    tokio::time::sleep(Duration::from_millis(500)).await;
+
+    let mnemonic = Mnemonic::generate(12)?;
+
+    let wallet = Wallet::new(
+        MINT_URL,
+        CurrencyUnit::Sat,
+        Arc::new(WalletMemoryDatabase::default()),
+        &mnemonic.to_seed_normalized(""),
+        None,
+    )?;
+
+    let wallet = Arc::new(wallet);
+
+    wallet_mint(
+        Arc::clone(&wallet),
+        10000.into(),
+        SplitTarget::Value(1.into()),
+    )
+    .await
+    .unwrap();
+    println!("Minted");
+
+    let proofs = wallet
+        .localstore
+        .get_proofs(Some(MintUrl::from_str(MINT_URL)?), None, None, None)
+        .await?;
+
+    let proofs: Vec<cdk::nuts::Proof> = proofs.into_iter().map(|p| p.proof).collect();
+
+    let keyset_id = proofs.first().unwrap().keyset_id;
+
+    let pre_swap_proofs = proofs[..1000].to_vec();
+
+    // Attempt to swap while overpaying fee
+
+    let pre_swap_secret = PreMintSecrets::random(keyset_id, 450.into(), &SplitTarget::default())?;
+
+    let swap_request =
+        SwapRequest::new(pre_swap_proofs.clone(), pre_swap_secret.blinded_messages());
+
+    let wallet_client = HttpClient::new();
+
+    match wallet_client
+        .post_swap(MINT_URL.parse()?, swap_request)
+        .await
+    {
+        Ok(_res) => {
+            bail!("Swap should not have succeeded");
+        }
+        Err(err) => match err {
+            cdk::error::Error::TransactionUnbalanced(_, _, _) => (),
+            _ => {
+                println!("{:?}", err);
+                bail!("Swap returned the wrong error when overpaying fee");
+            }
+        },
+    };
+
+    // Attempt to swap while underpaying fee
+
+    let pre_swap_secret = PreMintSecrets::random(keyset_id, 1000.into(), &SplitTarget::default())?;
+    let swap_request =
+        SwapRequest::new(pre_swap_proofs.clone(), pre_swap_secret.blinded_messages());
+    match wallet_client
+        .post_swap(MINT_URL.parse()?, swap_request)
+        .await
+    {
+        Ok(_res) => {
+            bail!("Swap should not have succeeded");
+        }
+        // In the context of this test an error response here is good.
+        // It means the mint does not allow us to swap for more then we should by overflowing
+        Err(err) => match err {
+            cdk::error::Error::TransactionUnbalanced(_, _, _) => (),
+            _ => {
+                println!("{:?}", err);
+                bail!("Swap returned the wrong error when underpaying fee");
+            }
+        },
+    };
+    Ok(())
+}

+ 13 - 10
crates/cdk/src/error.rs

@@ -8,7 +8,7 @@ use thiserror::Error;
 
 #[cfg(feature = "wallet")]
 use crate::wallet::multi_mint_wallet::WalletKey;
-use crate::{util::hex, Amount};
+use crate::{nuts::Id, util::hex, Amount};
 
 /// CDK Error
 #[derive(Debug, Error)]
@@ -16,6 +16,12 @@ pub enum Error {
     /// Mint does not have a key for amount
     #[error("No Key for Amount")]
     AmountKey,
+    /// Keyset is not known
+    #[error("Keyset id not known: `{0}`")]
+    KeysetUnknown(Id),
+    /// Unsupported unit
+    #[error("Unit unsupported")]
+    UnsupportedUnit,
     /// Payment failed
     #[error("Payment failed")]
     PaymentFailed,
@@ -72,9 +78,6 @@ pub enum Error {
     /// Inactive Keyset
     #[error("Inactive Keyset")]
     InactiveKeyset,
-    /// Not engough inputs provided
-    #[error("Inputs: `{0}`, Outputs: `{1}`, Expected Fee: `{2}`")]
-    InsufficientInputs(u64, u64, u64),
     /// Transaction unbalanced
     #[error("Inputs: `{0}`, Outputs: `{1}`, Expected Fee: `{2}`")]
     TransactionUnbalanced(u64, u64, u64),
@@ -286,7 +289,12 @@ impl ErrorResponse {
 impl From<Error> for ErrorResponse {
     fn from(err: Error) -> ErrorResponse {
         match err {
-            Error::UnitUnsupported => ErrorResponse {
+            Error::TokenAlreadySpent => ErrorResponse {
+                code: ErrorCode::TokenAlreadySpent,
+                error: Some(err.to_string()),
+                detail: None,
+            },
+            Error::UnsupportedUnit => ErrorResponse {
                 code: ErrorCode::UnitUnsupported,
                 error: Some(err.to_string()),
                 detail: None,
@@ -301,11 +309,6 @@ impl From<Error> for ErrorResponse {
                 error: Some("Invoice already paid.".to_string()),
                 detail: None,
             },
-            Error::TokenAlreadySpent => ErrorResponse {
-                code: ErrorCode::TokenAlreadySpent,
-                error: Some("Token is already spent.".to_string()),
-                detail: None,
-            },
             Error::TransactionUnbalanced(inputs_total, outputs_total, fee_expected) => {
                 ErrorResponse {
                     code: ErrorCode::TransactionUnbalanced,

+ 89 - 0
crates/cdk/src/fees.rs

@@ -0,0 +1,89 @@
+//! Calculate fees
+//!
+//! <https://github.com/cashubtc/nuts/blob/main/02.md>
+
+use std::collections::HashMap;
+
+use tracing::instrument;
+
+use crate::error::Error;
+use crate::nuts::Id;
+use crate::Amount;
+
+/// Fee required for proof set
+#[instrument(skip_all)]
+pub fn calculate_fee(
+    proofs_count: &HashMap<Id, u64>,
+    keyset_fee: &HashMap<Id, u64>,
+) -> Result<Amount, Error> {
+    let mut sum_fee = 0;
+
+    for (keyset_id, proof_count) in proofs_count {
+        let keyset_fee_ppk = keyset_fee
+            .get(keyset_id)
+            .ok_or(Error::KeysetUnknown(*keyset_id))?;
+
+        let proofs_fee = keyset_fee_ppk * proof_count;
+
+        sum_fee += proofs_fee;
+        println!("{}", sum_fee);
+    }
+
+    println!("{}", sum_fee);
+
+    let fee = (sum_fee + 999) / 1000;
+
+    Ok(fee.into())
+}
+
+#[cfg(test)]
+mod tests {
+
+    use std::str::FromStr;
+
+    use super::*;
+
+    #[test]
+    fn test_calc_fee() -> anyhow::Result<()> {
+        let keyset_id = Id::from_str("001711afb1de20cb").unwrap();
+
+        let fee = 2;
+
+        let mut keyset_fees = HashMap::new();
+        keyset_fees.insert(keyset_id, fee);
+
+        let mut proofs_count = HashMap::new();
+
+        proofs_count.insert(keyset_id, 1);
+
+        let sum_fee = calculate_fee(&proofs_count, &keyset_fees)?;
+
+        assert_eq!(sum_fee, 1.into());
+
+        proofs_count.insert(keyset_id, 500);
+
+        let sum_fee = calculate_fee(&proofs_count, &keyset_fees)?;
+
+        assert_eq!(sum_fee, 1.into());
+
+        proofs_count.insert(keyset_id, 1000);
+
+        let sum_fee = calculate_fee(&proofs_count, &keyset_fees)?;
+
+        assert_eq!(sum_fee, 2.into());
+
+        proofs_count.insert(keyset_id, 2000);
+        let sum_fee = calculate_fee(&proofs_count, &keyset_fees)?;
+        assert_eq!(sum_fee, 4.into());
+
+        proofs_count.insert(keyset_id, 3500);
+        let sum_fee = calculate_fee(&proofs_count, &keyset_fees)?;
+        assert_eq!(sum_fee, 7.into());
+
+        proofs_count.insert(keyset_id, 3501);
+        let sum_fee = calculate_fee(&proofs_count, &keyset_fees)?;
+        assert_eq!(sum_fee, 8.into());
+
+        Ok(())
+    }
+}

+ 2 - 0
crates/cdk/src/lib.rs

@@ -20,6 +20,8 @@ pub mod util;
 #[cfg(feature = "wallet")]
 pub mod wallet;
 
+pub mod fees;
+
 #[doc(hidden)]
 pub use bitcoin::secp256k1;
 #[doc(hidden)]

+ 23 - 28
crates/cdk/src/mint/mod.rs

@@ -15,6 +15,7 @@ use self::nut11::EnforceSigFlag;
 use crate::cdk_database::{self, MintDatabase};
 use crate::dhke::{hash_to_curve, sign_message, verify_message};
 use crate::error::Error;
+use crate::fees::calculate_fee;
 use crate::mint_url::MintUrl;
 use crate::nuts::nut11::enforce_sig_flag;
 use crate::nuts::*;
@@ -414,21 +415,30 @@ impl Mint {
     /// Fee required for proof set
     #[instrument(skip_all)]
     pub async fn get_proofs_fee(&self, proofs: &Proofs) -> Result<Amount, Error> {
-        let mut sum_fee = 0;
+        let mut proofs_per_keyset = HashMap::new();
+        let mut fee_per_keyset = HashMap::new();
 
         for proof in proofs {
-            let input_fee_ppk = self
-                .localstore
-                .get_keyset_info(&proof.keyset_id)
-                .await?
-                .ok_or(Error::UnknownKeySet)?;
+            if let std::collections::hash_map::Entry::Vacant(e) =
+                fee_per_keyset.entry(proof.keyset_id)
+            {
+                let mint_keyset_info = self
+                    .localstore
+                    .get_keyset_info(&proof.keyset_id)
+                    .await?
+                    .ok_or(Error::UnknownKeySet)?;
+                e.insert(mint_keyset_info.input_fee_ppk);
+            }
 
-            sum_fee += input_fee_ppk.input_fee_ppk;
+            proofs_per_keyset
+                .entry(proof.keyset_id)
+                .and_modify(|count| *count += 1)
+                .or_insert(1);
         }
 
-        let fee = (sum_fee + 999) / 1000;
+        let fee = calculate_fee(&proofs_per_keyset, &fee_per_keyset)?;
 
-        Ok(Amount::from(fee))
+        Ok(fee)
     }
 
     /// Check melt quote status
@@ -740,20 +750,6 @@ impl Mint {
 
         let total_with_fee = output_total.checked_add(fee).ok_or(Error::AmountOverflow)?;
 
-        if proofs_total < total_with_fee {
-            tracing::info!(
-                "Swap request without enough inputs: {}, outputs {}, fee {}",
-                proofs_total,
-                output_total,
-                fee
-            );
-            return Err(Error::InsufficientInputs(
-                proofs_total.into(),
-                output_total.into(),
-                fee.into(),
-            ));
-        }
-
         if proofs_total != total_with_fee {
             tracing::info!(
                 "Swap request unbalanced: {}, outputs {}, fee {}",
@@ -1038,17 +1034,16 @@ impl Mint {
 
         let required_total = quote.amount + quote.fee_reserve + fee;
 
-        if proofs_total < required_total {
+        if proofs_total != required_total {
             tracing::info!(
-                "Swap request without enough inputs: {}, quote amount {}, fee_reserve: {} fee {}",
+                "Swap request unbalanced: {}, outputs {}, fee {}",
                 proofs_total,
                 quote.amount,
-                quote.fee_reserve,
                 fee
             );
-            return Err(Error::InsufficientInputs(
+            return Err(Error::TransactionUnbalanced(
                 proofs_total.into(),
-                (quote.amount + quote.fee_reserve).into(),
+                quote.amount.into(),
                 fee.into(),
             ));
         }

+ 19 - 9
crates/cdk/src/wallet/mod.rs

@@ -15,6 +15,7 @@ use crate::amount::SplitTarget;
 use crate::cdk_database::{self, WalletDatabase};
 use crate::dhke::{construct_proofs, hash_to_curve};
 use crate::error::Error;
+use crate::fees::calculate_fee;
 use crate::mint_url::MintUrl;
 use crate::nuts::nut00::token::Token;
 use crate::nuts::{
@@ -100,21 +101,30 @@ impl Wallet {
     /// Fee required for proof set
     #[instrument(skip_all)]
     pub async fn get_proofs_fee(&self, proofs: &Proofs) -> Result<Amount, Error> {
-        let mut sum_fee = 0;
+        let mut proofs_per_keyset = HashMap::new();
+        let mut fee_per_keyset = HashMap::new();
 
         for proof in proofs {
-            let input_fee_ppk = self
-                .localstore
-                .get_keyset_by_id(&proof.keyset_id)
-                .await?
-                .ok_or(Error::UnknownKeySet)?;
+            if let std::collections::hash_map::Entry::Vacant(e) =
+                fee_per_keyset.entry(proof.keyset_id)
+            {
+                let mint_keyset_info = self
+                    .localstore
+                    .get_keyset_by_id(&proof.keyset_id)
+                    .await?
+                    .ok_or(Error::UnknownKeySet)?;
+                e.insert(mint_keyset_info.input_fee_ppk);
+            }
 
-            sum_fee += input_fee_ppk.input_fee_ppk;
+            proofs_per_keyset
+                .entry(proof.keyset_id)
+                .and_modify(|count| *count += 1)
+                .or_insert(1);
         }
 
-        let fee = (sum_fee + 999) / 1000;
+        let fee = calculate_fee(&proofs_per_keyset, &fee_per_keyset)?;
 
-        Ok(Amount::from(fee))
+        Ok(fee)
     }
 
     /// Get fee for count of proofs in a keyset