Эх сурвалжийг харах

Move logic to its trait.

I moved the logic to send the initial values to the subscriptions onto a
generic trait implemented in nut17. The main goal is to have the same behavior
regardless of whether the subscriptions come from web sockets or internally
from other parts of the systems or other crates.
Cesar Rodas 4 сар өмнө
parent
commit
d9d35a9f41

+ 1 - 78
crates/cdk-axum/src/ws/subscribe.rs

@@ -3,10 +3,7 @@ use super::{
     WsContext, WsError, JSON_RPC_VERSION,
 };
 use cdk::{
-    nuts::{
-        nut17::{Kind, NotificationPayload, Params},
-        MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey,
-    },
+    nuts::nut17::{NotificationPayload, Params},
     pub_sub::SubId,
 };
 
@@ -67,80 +64,6 @@ impl WsHandle for Method {
             .subscribe(self.0.clone())
             .await;
         let publisher = context.publisher.clone();
-
-        let current_notification_to_send: Vec<NotificationPayload> = match self.0.kind {
-            Kind::Bolt11MeltQuote => {
-                let queries = self
-                    .0
-                    .filters
-                    .iter()
-                    .map(|id| context.state.mint.localstore.get_melt_quote(id))
-                    .collect::<Vec<_>>();
-
-                futures::future::try_join_all(queries)
-                    .await
-                    .map(|quotes| {
-                        quotes
-                            .into_iter()
-                            .filter_map(|quote| quote.map(|x| x.into()))
-                            .map(|x: MeltQuoteBolt11Response| x.into())
-                            .collect::<Vec<_>>()
-                    })
-                    .unwrap_or_default()
-            }
-            Kind::Bolt11MintQuote => {
-                let queries = self
-                    .0
-                    .filters
-                    .iter()
-                    .map(|id| context.state.mint.localstore.get_mint_quote(id))
-                    .collect::<Vec<_>>();
-
-                futures::future::try_join_all(queries)
-                    .await
-                    .map(|quotes| {
-                        quotes
-                            .into_iter()
-                            .filter_map(|quote| quote.map(|x| x.into()))
-                            .map(|x: MintQuoteBolt11Response| x.into())
-                            .collect::<Vec<_>>()
-                    })
-                    .unwrap_or_default()
-            }
-            Kind::ProofState => {
-                if let Ok(public_keys) = self
-                    .0
-                    .filters
-                    .iter()
-                    .map(PublicKey::from_hex)
-                    .collect::<Result<Vec<PublicKey>, _>>()
-                {
-                    context
-                        .state
-                        .mint
-                        .localstore
-                        .get_proofs_states(&public_keys)
-                        .await
-                        .map(|x| {
-                            x.into_iter()
-                                .enumerate()
-                                .filter_map(|(idx, state)| {
-                                    state.map(|state| (public_keys[idx], state).into())
-                                })
-                                .map(|x: ProofState| x.into())
-                                .collect::<Vec<_>>()
-                        })
-                        .unwrap_or_default()
-                } else {
-                    vec![]
-                }
-            }
-        };
-
-        for notification in current_notification_to_send.into_iter() {
-            let _ = publisher.send((sub_id.clone(), notification)).await;
-        }
-
         context.subscriptions.insert(
             sub_id.clone(),
             tokio::spawn(async move {

+ 1 - 1
crates/cdk-integration-tests/tests/regtest.rs

@@ -35,7 +35,7 @@ async fn get_notification<T: StreamExt<Item = Result<Message, E>> + Unpin, E: De
         .unwrap();
 
     let mut response: serde_json::Value =
-        serde_json::from_str(&msg.to_text().unwrap()).expect("valid json");
+        serde_json::from_str(msg.to_text().unwrap()).expect("valid json");
 
     let mut params_raw = response
         .as_object_mut()

+ 1 - 1
crates/cdk/src/mint/mod.rs

@@ -185,7 +185,7 @@ impl Mint {
         Ok(Self {
             mint_url: MintUrl::from_str(mint_url)?,
             keysets: Arc::new(RwLock::new(active_keysets)),
-            pubsub_manager: Default::default(),
+            pubsub_manager: Arc::new(localstore.clone().into()),
             secp_ctx,
             quote_ttl,
             xpriv,

+ 116 - 6
crates/cdk/src/nuts/nut17.rs

@@ -1,14 +1,15 @@
 //! Specific Subscription for the cdk crate
 
 use crate::{
+    cdk_database::{self, MintDatabase},
     nuts::{
         MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState,
         ProofState,
     },
-    pub_sub::{self, Index, Indexable, SubscriptionGlobalId},
+    pub_sub::{self, Index, Indexable, OnNewSubscription, SubscriptionGlobalId},
 };
 use serde::{Deserialize, Serialize};
-use std::ops::Deref;
+use std::{collections::HashMap, ops::Deref, sync::Arc};
 
 /// Subscription Parameter according to the standard
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -59,7 +60,7 @@ impl Default for SupportedMethods {
 
 pub use crate::pub_sub::SubId;
 
-use super::{BlindSignature, CurrencyUnit, PaymentMethod};
+use super::{BlindSignature, CurrencyUnit, PaymentMethod, PublicKey};
 
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(untagged)]
@@ -144,16 +145,125 @@ impl From<Params> for Vec<Index<(String, Kind)>> {
     }
 }
 
-/// Manager
 #[derive(Default)]
+/// Subscription Init
+///
+/// This struct triggers code when a new subscription is created.
+///
+/// It is used to send the initial state of the subscription to the client.
+pub struct SubscriptionInit(Option<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>>);
+
+#[async_trait::async_trait]
+impl OnNewSubscription for SubscriptionInit {
+    type Event = NotificationPayload;
+    type Index = (String, Kind);
+
+    async fn on_new_subscription(
+        &self,
+        request: &[&Self::Index],
+    ) -> Result<Vec<Self::Event>, String> {
+        let datastore = if let Some(localstore) = self.0.as_ref() {
+            localstore
+        } else {
+            return Ok(vec![]);
+        };
+
+        let mut to_return = vec![];
+
+        for (kind, values) in request.iter().fold(
+            HashMap::new(),
+            |mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| {
+                acc.entry(kind).or_default().push(data);
+                acc
+            },
+        ) {
+            match kind {
+                Kind::Bolt11MeltQuote => {
+                    let queries = values
+                        .iter()
+                        .map(|id| datastore.get_melt_quote(id))
+                        .collect::<Vec<_>>();
+
+                    to_return.extend(
+                        futures::future::try_join_all(queries)
+                            .await
+                            .map(|quotes| {
+                                quotes
+                                    .into_iter()
+                                    .filter_map(|quote| quote.map(|x| x.into()))
+                                    .map(|x: MeltQuoteBolt11Response| x.into())
+                                    .collect::<Vec<_>>()
+                            })
+                            .map_err(|e| e.to_string())?,
+                    );
+                }
+                Kind::Bolt11MintQuote => {
+                    let queries = values
+                        .iter()
+                        .map(|id| datastore.get_mint_quote(id))
+                        .collect::<Vec<_>>();
+
+                    to_return.extend(
+                        futures::future::try_join_all(queries)
+                            .await
+                            .map(|quotes| {
+                                quotes
+                                    .into_iter()
+                                    .filter_map(|quote| quote.map(|x| x.into()))
+                                    .map(|x: MintQuoteBolt11Response| x.into())
+                                    .collect::<Vec<_>>()
+                            })
+                            .map_err(|e| e.to_string())?,
+                    );
+                }
+                Kind::ProofState => {
+                    let public_keys = values
+                        .iter()
+                        .map(PublicKey::from_hex)
+                        .collect::<Result<Vec<PublicKey>, _>>()
+                        .map_err(|e| e.to_string())?;
+
+                    to_return.extend(
+                        datastore
+                            .get_proofs_states(&public_keys)
+                            .await
+                            .map_err(|e| e.to_string())?
+                            .into_iter()
+                            .enumerate()
+                            .filter_map(|(idx, state)| {
+                                state.map(|state| (public_keys[idx], state).into())
+                            })
+                            .map(|state: ProofState| state.into()),
+                    );
+                }
+            }
+        }
+
+        Ok(to_return)
+    }
+}
+
+/// Manager
 /// Publish–subscribe manager
 ///
 /// Nut-17 implementation is system-wide and not only through the WebSocket, so
 /// it is possible for another part of the system to subscribe to events.
-pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind)>);
+pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind), SubscriptionInit>);
+
+impl Default for PubSubManager {
+    fn default() -> Self {
+        PubSubManager(SubscriptionInit::default().into())
+    }
+}
+
+impl From<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>> for PubSubManager {
+    fn from(val: Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>) -> Self {
+        PubSubManager(SubscriptionInit(Some(val)).into())
+    }
+}
 
 impl Deref for PubSubManager {
-    type Target = pub_sub::Manager<NotificationPayload, (String, Kind)>;
+    type Target = pub_sub::Manager<NotificationPayload, (String, Kind), SubscriptionInit>;
 
     fn deref(&self) -> &Self::Target {
         &self.0

+ 64 - 5
crates/cdk/src/pub_sub/mod.rs

@@ -37,6 +37,25 @@ pub const DEFAULT_REMOVE_SIZE: usize = 10_000;
 /// Default channel size for subscription buffering
 pub const DEFAULT_CHANNEL_SIZE: usize = 10;
 
+#[async_trait::async_trait]
+/// On New Subscription trait
+///
+/// This trait is optional and it is used to notify the application when a new
+/// subscription is created. This is useful when the application needs to send
+/// the initial state to the subscriber upon subscription
+pub trait OnNewSubscription {
+    /// Index type
+    type Index;
+    /// Subscription event type
+    type Event;
+
+    /// Called when a new subscription is created
+    async fn on_new_subscription(
+        &self,
+        request: &[&Self::Index],
+    ) -> Result<Vec<Self::Event>, String>;
+}
+
 /// Subscription manager
 ///
 /// This object keep track of all subscription listener and it is also
@@ -45,21 +64,24 @@ pub const DEFAULT_CHANNEL_SIZE: usize = 10;
 /// The content of the notification is not relevant to this scope and it is up
 /// to the application, therefore the generic T is used instead of a specific
 /// type
-pub struct Manager<T, I>
+pub struct Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
     I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     indexes: IndexTree<T, I>,
+    on_new_subscription: Option<F>,
     unsubscription_sender: mpsc::Sender<(SubId, Vec<Index<I>>)>,
     active_subscriptions: Arc<AtomicUsize>,
     background_subscription_remover: Option<JoinHandle<()>>,
 }
 
-impl<T, I> Default for Manager<T, I>
+impl<T, I, F> Default for Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
     I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     fn default() -> Self {
         let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE);
@@ -72,6 +94,7 @@ where
                 storage.clone(),
                 active_subscriptions.clone(),
             ))),
+            on_new_subscription: None,
             unsubscription_sender: sender,
             active_subscriptions,
             indexes: storage,
@@ -79,10 +102,24 @@ where
     }
 }
 
-impl<T, I> Manager<T, I>
+impl<T, I, F> From<F> for Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
-    I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
+    I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
+{
+    fn from(value: F) -> Self {
+        let mut manager: Self = Default::default();
+        manager.on_new_subscription = Some(value);
+        manager
+    }
+}
+
+impl<T, I, F> Manager<T, I, F>
+where
+    T: Indexable<Type = I> + Clone + Send + Sync + 'static,
+    I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     #[inline]
     /// Broadcast an event to all listeners
@@ -132,8 +169,29 @@ where
     ) -> ActiveSubscription<T, I> {
         let (sender, receiver) = mpsc::channel(10);
         let sub_id: SubId = params.as_ref().clone();
+
         let indexes: Vec<Index<I>> = params.into();
 
+        if let Some(on_new_subscription) = self.on_new_subscription.as_ref() {
+            match on_new_subscription
+                .on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::<Vec<_>>())
+                .await
+            {
+                Ok(events) => {
+                    for event in events {
+                        let _ = sender.try_send((sub_id.clone(), event));
+                    }
+                }
+                Err(err) => {
+                    tracing::info!(
+                        "Failed to get initial state for subscription: {:?}, {}",
+                        sub_id,
+                        err
+                    );
+                }
+            }
+        }
+
         let mut index_storage = self.indexes.write().await;
         for index in indexes.clone() {
             index_storage.insert(index, sender.clone());
@@ -180,10 +238,11 @@ where
 }
 
 /// Manager goes out of scope, stop all background tasks
-impl<T, I> Drop for Manager<T, I>
+impl<T, I, F> Drop for Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
     I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     fn drop(&mut self) {
         if let Some(handler) = self.background_subscription_remover.take() {