浏览代码

Use `RwLock` instead of `Mutex`.

A mutex is too aggressive when the shared object can be accessed read-only most
of the time instead of requiring exclusive access.
Cesar Rodas 1 周之前
父节点
当前提交
2e2ce0f621
共有 1 个文件被更改,包括 44 次插入45 次删除
  1. 44 45
      crates/cdk/src/wallet/multi_mint_wallet.rs

+ 44 - 45
crates/cdk/src/wallet/multi_mint_wallet.rs

@@ -11,7 +11,7 @@ use anyhow::Result;
 use cdk_common::database;
 use cdk_common::database::WalletDatabase;
 use cdk_common::wallet::{Transaction, TransactionDirection, WalletKey};
-use tokio::sync::Mutex;
+use tokio::sync::RwLock;
 use tracing::instrument;
 
 use super::receive::ReceiveOptions;
@@ -31,7 +31,7 @@ pub struct MultiMintWallet {
     pub localstore: Arc<dyn WalletDatabase<Err = database::Error> + Send + Sync>,
     seed: Arc<[u8]>,
     /// Wallets
-    pub wallets: Arc<Mutex<BTreeMap<WalletKey, Wallet>>>,
+    pub wallets: Arc<RwLock<BTreeMap<WalletKey, Wallet>>>,
 }
 
 impl MultiMintWallet {
@@ -44,7 +44,7 @@ impl MultiMintWallet {
         Self {
             localstore,
             seed,
-            wallets: Arc::new(Mutex::new(
+            wallets: Arc::new(RwLock::new(
                 wallets
                     .into_iter()
                     .map(|w| (WalletKey::new(w.mint_url.clone(), w.unit.clone()), w))
@@ -58,7 +58,7 @@ impl MultiMintWallet {
     pub async fn add_wallet(&self, wallet: Wallet) {
         let wallet_key = WalletKey::new(wallet.mint_url.clone(), wallet.unit.clone());
 
-        let mut wallets = self.wallets.lock().await;
+        let mut wallets = self.wallets.write().await;
 
         wallets.insert(wallet_key, wallet);
     }
@@ -88,7 +88,7 @@ impl MultiMintWallet {
     /// Remove Wallet from MultiMintWallet
     #[instrument(skip(self))]
     pub async fn remove_wallet(&self, wallet_key: &WalletKey) {
-        let mut wallets = self.wallets.lock().await;
+        let mut wallets = self.wallets.write().await;
 
         wallets.remove(wallet_key);
     }
@@ -96,21 +96,19 @@ impl MultiMintWallet {
     /// Get Wallets from MultiMintWallet
     #[instrument(skip(self))]
     pub async fn get_wallets(&self) -> Vec<Wallet> {
-        self.wallets.lock().await.values().cloned().collect()
+        self.wallets.read().await.values().cloned().collect()
     }
 
     /// Get Wallet from MultiMintWallet
     #[instrument(skip(self))]
     pub async fn get_wallet(&self, wallet_key: &WalletKey) -> Option<Wallet> {
-        let wallets = self.wallets.lock().await;
-
-        wallets.get(wallet_key).cloned()
+        self.wallets.read().await.get(wallet_key).cloned()
     }
 
     /// Check if mint unit pair is in wallet
     #[instrument(skip(self))]
     pub async fn has(&self, wallet_key: &WalletKey) -> bool {
-        self.wallets.lock().await.contains_key(wallet_key)
+        self.wallets.read().await.contains_key(wallet_key)
     }
 
     /// Get wallet balances
@@ -121,7 +119,7 @@ impl MultiMintWallet {
     ) -> Result<BTreeMap<MintUrl, Amount>, Error> {
         let mut balances = BTreeMap::new();
 
-        for (WalletKey { mint_url, unit: u }, wallet) in self.wallets.lock().await.iter() {
+        for (WalletKey { mint_url, unit: u }, wallet) in self.wallets.read().await.iter() {
             if unit == u {
                 let wallet_balance = wallet.total_balance().await?;
                 balances.insert(mint_url.clone(), wallet_balance);
@@ -138,7 +136,7 @@ impl MultiMintWallet {
     ) -> Result<BTreeMap<MintUrl, (Vec<Proof>, CurrencyUnit)>, Error> {
         let mut mint_proofs = BTreeMap::new();
 
-        for (WalletKey { mint_url, unit: u }, wallet) in self.wallets.lock().await.iter() {
+        for (WalletKey { mint_url, unit: u }, wallet) in self.wallets.read().await.iter() {
             let wallet_proofs = wallet.get_unspent_proofs().await?;
             mint_proofs.insert(mint_url.clone(), (wallet_proofs, u.clone()));
         }
@@ -153,7 +151,7 @@ impl MultiMintWallet {
     ) -> Result<Vec<Transaction>, Error> {
         let mut transactions = Vec::new();
 
-        for (_, wallet) in self.wallets.lock().await.iter() {
+        for (_, wallet) in self.wallets.read().await.iter() {
             let wallet_transactions = wallet.list_transactions(direction).await?;
             transactions.extend(wallet_transactions);
         }
@@ -171,9 +169,9 @@ impl MultiMintWallet {
         amount: Amount,
         opts: SendOptions,
     ) -> Result<PreparedSend, Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         wallet.prepare_send(amount, opts).await
@@ -187,9 +185,9 @@ impl MultiMintWallet {
         send: PreparedSend,
         memo: Option<SendMemo>,
     ) -> Result<Token, Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         wallet.send(send, memo).await
@@ -203,9 +201,9 @@ impl MultiMintWallet {
         amount: Amount,
         description: Option<String>,
     ) -> Result<MintQuote, Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         wallet.mint_quote(amount, description).await
@@ -221,16 +219,16 @@ impl MultiMintWallet {
         let mut amount_minted = HashMap::new();
         match wallet_key {
             Some(wallet_key) => {
-                let wallet = self
-                    .get_wallet(&wallet_key)
-                    .await
+                let wallets = self.wallets.read().await;
+                let wallet = wallets
+                    .get(&wallet_key)
                     .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
                 let amount = wallet.check_all_mint_quotes().await?;
-                amount_minted.insert(wallet.unit, amount);
+                amount_minted.insert(wallet.unit.clone(), amount);
             }
             None => {
-                for (_, wallet) in self.wallets.lock().await.iter() {
+                for (_, wallet) in self.wallets.read().await.iter() {
                     let amount = wallet.check_all_mint_quotes().await?;
 
                     amount_minted
@@ -252,10 +250,11 @@ impl MultiMintWallet {
         quote_id: &str,
         conditions: Option<SpendingConditions>,
     ) -> Result<Proofs, Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
+
         wallet
             .mint(quote_id, SplitTarget::default(), conditions)
             .await
@@ -287,9 +286,9 @@ impl MultiMintWallet {
         }
 
         let wallet_key = WalletKey::new(mint_url.clone(), unit);
-        let wallet = self
-            .get_wallet(&wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(&wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         match wallet
@@ -320,9 +319,9 @@ impl MultiMintWallet {
         wallet_key: &WalletKey,
         max_fee: Option<Amount>,
     ) -> Result<Melted, Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         let quote = wallet.melt_quote(bolt11.to_string(), options).await?;
@@ -336,9 +335,9 @@ impl MultiMintWallet {
     /// Restore
     #[instrument(skip(self))]
     pub async fn restore(&self, wallet_key: &WalletKey) -> Result<Amount, Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         wallet.restore().await
@@ -352,9 +351,9 @@ impl MultiMintWallet {
         token: &Token,
         conditions: SpendingConditions,
     ) -> Result<(), Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         wallet.verify_token_p2pk(token, conditions)
@@ -367,9 +366,9 @@ impl MultiMintWallet {
         wallet_key: &WalletKey,
         token: &Token,
     ) -> Result<(), Error> {
-        let wallet = self
-            .get_wallet(wallet_key)
-            .await
+        let wallets = self.wallets.read().await;
+        let wallet = wallets
+            .get(wallet_key)
             .ok_or(Error::UnknownWallet(wallet_key.clone()))?;
 
         wallet.verify_token_dleq(token).await