Quellcode durchsuchen

Use parking_lot

Cesar Rodas vor 3 Monaten
Ursprung
Commit
9e1e122d0a

+ 1 - 0
crates/cdk-common/Cargo.toml

@@ -41,6 +41,7 @@ serde_json.workspace = true
 serde_with.workspace = true
 web-time.workspace = true
 tokio.workspace = true
+parking_lot = "0.12.5"
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 uuid = { workspace = true, features = ["js"], optional = true }

+ 2 - 0
crates/cdk-common/src/lib.rs

@@ -33,3 +33,5 @@ pub use cashu::nuts::{self, *};
 pub use cashu::quote_id::{self, *};
 pub use cashu::{dhke, ensure_cdk, mint_url, secret, util, SECP256K1};
 pub use error::Error;
+/// Re-export parking_lot for reuse
+pub use parking_lot;

+ 0 - 4
crates/cdk-common/src/pub_sub/error.rs

@@ -5,10 +5,6 @@ use tokio::sync::mpsc::error::TrySendError;
 #[derive(thiserror::Error, Debug)]
 /// Error
 pub enum Error {
-    /// Poison locked
-    #[error("Poisoned lock")]
-    Poison,
-
     /// Already subscribed
     #[error("Already subscribed")]
     AlreadySubscribed,

+ 4 - 3
crates/cdk-common/src/pub_sub/pubsub.rs

@@ -3,8 +3,9 @@
 use std::cmp::Ordering;
 use std::collections::{BTreeMap, HashSet};
 use std::sync::atomic::AtomicUsize;
-use std::sync::{Arc, RwLock};
+use std::sync::Arc;
 
+use parking_lot::RwLock;
 use tokio::sync::mpsc;
 
 use super::subscriber::{ActiveSubscription, SubscriptionRequest};
@@ -61,7 +62,7 @@ where
     /// Publish an event to all listenrs
     #[inline(always)]
     fn publish_internal(event: S::Event, listeners_index: &TopicTree<S>) -> Result<(), Error> {
-        let index_storage = listeners_index.read().map_err(|_| Error::Poison)?;
+        let index_storage = listeners_index.read();
 
         let mut sent = HashSet::new();
         for topic in event.get_topics() {
@@ -131,7 +132,7 @@ where
     {
         let subscription_name = request.subscription_name();
         let sender = Subscriber::new(subscription_name.clone(), sender);
-        let mut index_storage = self.listeners_topics.write().map_err(|_| Error::Poison)?;
+        let mut index_storage = self.listeners_topics.write();
         let subscription_internal_id = self
             .unique_subscription_counter
             .fetch_add(1, std::sync::atomic::Ordering::Relaxed);

+ 25 - 64
crates/cdk-common/src/pub_sub/remote_consumer.rs

@@ -3,9 +3,10 @@
 //! Consumers are designed to connect to a producer, through a transport, and subscribe to events.
 use std::collections::{HashMap, VecDeque};
 use std::sync::atomic::AtomicBool;
-use std::sync::{Arc, RwLock};
+use std::sync::Arc;
 use std::time::Duration;
 
+use parking_lot::RwLock;
 use tokio::sync::mpsc;
 use tokio::time::{sleep, Instant};
 
@@ -162,11 +163,7 @@ where
         X: Into<S::Event>,
     {
         let event = event.into();
-        let mut cached_events = self.cached_events.write().unwrap_or_else(|mut err| {
-            **err.get_mut() = HashMap::new();
-            self.cached_events.clear_poison();
-            err.into_inner()
-        });
+        let mut cached_events = self.cached_events.write();
 
         for topic in event.get_topics() {
             cached_events.insert(topic, event.clone());
@@ -222,13 +219,7 @@ where
                 break;
             }
 
-            if instance
-                .remote_subscriptions
-                .read()
-                .map(|x| x.len())
-                .unwrap_or_default()
-                == 0
-            {
+            if instance.remote_subscriptions.read().is_empty() {
                 sleep(Duration::from_millis(100)).await;
                 continue;
             }
@@ -242,22 +233,16 @@ where
                 let (sender, receiver) = mpsc::channel(INTERNAL_POLL_SIZE);
 
                 {
-                    *(instance.stream_ctrl.write().unwrap_or_else(|mut err| {
-                        **err.get_mut() = None;
-                        instance.stream_ctrl.clear_poison();
-                        err.into_inner()
-                    })) = Some(sender);
+                    *instance.stream_ctrl.write() = Some(sender);
                 }
 
                 let current_subscriptions = {
-                    if let Ok(remote_subscriptions) = instance.remote_subscriptions.read() {
-                        remote_subscriptions
-                            .iter()
-                            .map(|(key, name)| (name.name.clone(), key.clone()))
-                            .collect::<Vec<_>>()
-                    } else {
-                        vec![]
-                    }
+                    instance
+                        .remote_subscriptions
+                        .read()
+                        .iter()
+                        .map(|(key, name)| (name.name.clone(), key.clone()))
+                        .collect::<Vec<_>>()
                 };
 
                 if let Err(err) = instance
@@ -285,27 +270,17 @@ where
                 }
 
                 // remove sender to stream, as there is no stream
-                let _ = instance
-                    .stream_ctrl
-                    .write()
-                    .unwrap_or_else(|mut err| {
-                        **err.get_mut() = None;
-                        instance.stream_ctrl.clear_poison();
-                        err.into_inner()
-                    })
-                    .take();
+                let _ = instance.stream_ctrl.write().take();
             }
 
             if poll_supported {
                 let current_subscriptions = {
-                    if let Ok(remote_subscriptions) = instance.remote_subscriptions.read() {
-                        remote_subscriptions
-                            .iter()
-                            .map(|(key, name)| (name.name.clone(), key.clone()))
-                            .collect::<Vec<_>>()
-                    } else {
-                        vec![]
-                    }
+                    instance
+                        .remote_subscriptions
+                        .read()
+                        .iter()
+                        .map(|(key, name)| (name.name.clone(), key.clone()))
+                        .collect::<Vec<_>>()
                 };
 
                 if let Err(err) = instance
@@ -339,14 +314,10 @@ where
         let topics = self
             .subscriptions
             .write()
-            .map_err(|_| Error::Poison)?
             .remove(&subscription_name)
             .ok_or(Error::AlreadySubscribed)?;
 
-        let mut remote_subscriptions = self
-            .remote_subscriptions
-            .write()
-            .map_err(|_| Error::Poison)?;
+        let mut remote_subscriptions = self.remote_subscriptions.write();
 
         for topic in topics {
             let mut remote_subscription =
@@ -362,11 +333,7 @@ where
                 .unwrap_or_default();
 
             if remote_subscription.total_subscribers == 0 {
-                let mut cached_events = self.cached_events.write().unwrap_or_else(|mut err| {
-                    **err.get_mut() = HashMap::new();
-                    self.cached_events.clear_poison();
-                    err.into_inner()
-                });
+                let mut cached_events = self.cached_events.write();
 
                 cached_events.remove(&topic);
 
@@ -385,7 +352,7 @@ where
 
     #[inline(always)]
     fn message_to_stream(&self, message: StreamCtrl<T::Spec>) -> Result<(), Error> {
-        let to_stream = self.stream_ctrl.read().map_err(|_| Error::Poison)?;
+        let to_stream = self.stream_ctrl.read();
 
         if let Some(to_stream) = to_stream.as_ref() {
             Ok(to_stream.try_send(message)?)
@@ -414,21 +381,15 @@ where
         let subscription_name = request.subscription_name();
         let topics = request.try_get_topics()?;
 
-        let mut remote_subscriptions = self
-            .remote_subscriptions
-            .write()
-            .map_err(|_| Error::Poison)?;
-        let mut subscriptions = self.subscriptions.write().map_err(|_| Error::Poison)?;
+        let mut remote_subscriptions = self.remote_subscriptions.write();
+        let mut subscriptions = self.subscriptions.write();
 
         if subscriptions.get(&subscription_name).is_some() {
             return Err(Error::AlreadySubscribed);
         }
 
         let mut previous_messages = Vec::new();
-        let cached_events = self.cached_events.read().unwrap_or_else(|e| {
-            self.cached_events.clear_poison();
-            e.into_inner()
-        });
+        let cached_events = self.cached_events.read();
 
         for topic in topics.iter() {
             if let Some(subscription) = remote_subscriptions.get_mut(topic) {
@@ -470,7 +431,7 @@ where
     fn drop(&mut self) {
         self.still_running
             .store(false, std::sync::atomic::Ordering::Release);
-        if let Ok(Some(to_stream)) = self.stream_ctrl.read().map(|sender| sender.clone()) {
+        if let Some(to_stream) = self.stream_ctrl.read().as_ref() {
             let _ = to_stream.try_send(StreamCtrl::Stop).inspect_err(|err| {
                 tracing::error!("Failed to send message LongPoll::Stop due to {err:?}")
             });

+ 1 - 1
crates/cdk-common/src/pub_sub/subscriber.rs

@@ -85,7 +85,7 @@ where
 {
     fn drop(&mut self) {
         // remove the listener
-        let mut topics = self.topics.write().unwrap();
+        let mut topics = self.topics.write();
         for index in self.subscribed_to.drain(..) {
             topics.remove(&(index, self.id));
         }

+ 5 - 4
crates/cdk/src/wallet/subscription/mod.rs

@@ -8,10 +8,11 @@
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::sync::atomic::AtomicUsize;
-use std::sync::{Arc, RwLock};
+use std::sync::Arc;
 
 use cdk_common::nut17::ws::{WsMethodRequest, WsRequest, WsUnsubscribeRequest};
 use cdk_common::nut17::{Kind, NotificationId};
+use cdk_common::parking_lot::RwLock;
 use cdk_common::pub_sub::remote_consumer::{
     Consumer, InternalRelay, RemoteActiveConsumer, StreamCtrl, SubscribeMessage, Transport,
 };
@@ -63,8 +64,9 @@ impl Debug for SubscriptionManager {
             "Subscription Manager connected to {:?}",
             self.all_connections
                 .write()
-                .map(|connections| connections.keys().cloned().collect::<Vec<_>>())
-                .unwrap_or_default()
+                .keys()
+                .cloned()
+                .collect::<Vec<_>>()
         )
     }
 }
@@ -87,7 +89,6 @@ impl SubscriptionManager {
     ) -> Result<RemoteActiveConsumer<SubscriptionClient>, PubsubError> {
         self.all_connections
             .write()
-            .map_err(|_| PubsubError::Poison)?
             .entry(mint_url.clone())
             .or_insert_with(|| {
                 Consumer::new(