Browse Source

fix: Send current state of the subscriptions (#444)

César D. Rodas 5 months ago
parent
commit
cc5b267367

+ 19 - 1
crates/cdk-axum/src/ws/subscribe.rs

@@ -11,17 +11,26 @@ use cdk::{
 pub struct Method(Params);
 
 #[derive(Debug, Clone, serde::Serialize)]
+/// The response to a subscription request
 pub struct Response {
+    /// Status
     status: String,
+    /// Subscription ID
     #[serde(rename = "subId")]
     sub_id: SubId,
 }
 
 #[derive(Debug, Clone, serde::Serialize)]
+/// The notification
+///
+/// This is the notification that is sent to the client when an event matches a
+/// subscription
 pub struct Notification {
+    /// The subscription ID
     #[serde(rename = "subId")]
     pub sub_id: SubId,
 
+    /// The notification payload
     pub payload: NotificationPayload,
 }
 
@@ -39,12 +48,21 @@ impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
 impl WsHandle for Method {
     type Response = Response;
 
+    /// The `handle` method is called when a client sends a subscription request
     async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
         let sub_id = self.0.id.clone();
         if context.subscriptions.contains_key(&sub_id) {
+            // Subscription ID already exits. Returns an error instead of
+            // replacing the other subscription or avoiding it.
             return Err(WsError::InvalidParams);
         }
-        let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;
+
+        let mut subscription = context
+            .state
+            .mint
+            .pubsub_manager
+            .subscribe(self.0.clone())
+            .await;
         let publisher = context.publisher.clone();
         context.subscriptions.insert(
             sub_id.clone(),

+ 13 - 1
crates/cdk-integration-tests/tests/regtest.rs

@@ -35,7 +35,7 @@ async fn get_notification<T: StreamExt<Item = Result<Message, E>> + Unpin, E: De
         .unwrap();
 
     let mut response: serde_json::Value =
-        serde_json::from_str(&msg.to_text().unwrap()).expect("valid json");
+        serde_json::from_str(msg.to_text().unwrap()).expect("valid json");
 
     let mut params_raw = response
         .as_object_mut()
@@ -113,6 +113,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> {
     assert!(melt_response.state == MeltQuoteState::Paid);
 
     let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
+    // first message is the current state
+    assert_eq!("test-sub", sub_id);
+    let payload = match payload {
+        NotificationPayload::MeltQuoteBolt11Response(melt) => melt,
+        _ => panic!("Wrong payload"),
+    };
+    assert_eq!(payload.amount + payload.fee_reserve, 100.into());
+    assert_eq!(payload.quote, melt.id);
+    assert_eq!(payload.state, MeltQuoteState::Unpaid);
+
+    // get current state
+    let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
     assert_eq!("test-sub", sub_id);
     let payload = match payload {
         NotificationPayload::MeltQuoteBolt11Response(melt) => melt,

+ 1 - 1
crates/cdk/Cargo.toml

@@ -39,7 +39,7 @@ serde_json = "1"
 serde_with = "3"
 tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
 thiserror = "1"
-futures = { version = "0.3.28", default-features = false, optional = true }
+futures = { version = "0.3.28", default-features = false, optional = true, features = ["alloc"] }
 url = "2.3"
 utoipa = { version = "4", optional = true }
 uuid = { version = "1", features = ["v4"] }

+ 1 - 1
crates/cdk/src/mint/mod.rs

@@ -185,7 +185,7 @@ impl Mint {
         Ok(Self {
             mint_url: MintUrl::from_str(mint_url)?,
             keysets: Arc::new(RwLock::new(active_keysets)),
-            pubsub_manager: Default::default(),
+            pubsub_manager: Arc::new(localstore.clone().into()),
             secp_ctx,
             quote_ttl,
             xpriv,

+ 2 - 0
crates/cdk/src/nuts/mod.rs

@@ -18,6 +18,7 @@ pub mod nut12;
 pub mod nut13;
 pub mod nut14;
 pub mod nut15;
+#[cfg(feature = "mint")]
 pub mod nut17;
 pub mod nut18;
 
@@ -48,5 +49,6 @@ pub use nut11::{Conditions, P2PKWitness, SigFlag, SpendingConditions};
 pub use nut12::{BlindSignatureDleq, ProofDleq};
 pub use nut14::HTLCWitness;
 pub use nut15::{Mpp, MppMethodSettings, Settings as NUT15Settings};
+#[cfg(feature = "mint")]
 pub use nut17::{NotificationPayload, PubSubManager};
 pub use nut18::{PaymentRequest, PaymentRequestPayload, Transport};

+ 3 - 2
crates/cdk/src/nuts/nut06.rs

@@ -5,7 +5,7 @@
 use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
 use super::nut01::PublicKey;
-use super::{nut04, nut05, nut15, nut17, MppMethodSettings};
+use super::{nut04, nut05, nut15, MppMethodSettings};
 
 /// Mint Version
 #[derive(Debug, Clone, PartialEq, Eq, Hash)]
@@ -238,7 +238,8 @@ pub struct Nuts {
     /// NUT17 Settings
     #[serde(default)]
     #[serde(rename = "17")]
-    pub nut17: nut17::SupportedSettings,
+    #[cfg(feature = "mint")]
+    pub nut17: super::nut17::SupportedSettings,
 }
 
 impl Nuts {

+ 23 - 8
crates/cdk/src/nuts/nut17.rs → crates/cdk/src/nuts/nut17/mod.rs

@@ -1,5 +1,8 @@
 //! Specific Subscription for the cdk crate
 
+use super::{BlindSignature, CurrencyUnit, PaymentMethod};
+use crate::cdk_database::{self, MintDatabase};
+pub use crate::pub_sub::SubId;
 use crate::{
     nuts::{
         MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState,
@@ -8,7 +11,11 @@ use crate::{
     pub_sub::{self, Index, Indexable, SubscriptionGlobalId},
 };
 use serde::{Deserialize, Serialize};
-use std::ops::Deref;
+use std::{ops::Deref, sync::Arc};
+
+mod on_subscription;
+
+pub use on_subscription::OnSubscription;
 
 /// Subscription Parameter according to the standard
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -57,10 +64,6 @@ impl Default for SupportedMethods {
     }
 }
 
-pub use crate::pub_sub::SubId;
-
-use super::{BlindSignature, CurrencyUnit, PaymentMethod};
-
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(untagged)]
 /// Subscription response
@@ -145,15 +148,27 @@ impl From<Params> for Vec<Index<(String, Kind)>> {
 }
 
 /// Manager
-#[derive(Default)]
 /// Publish–subscribe manager
 ///
 /// Nut-17 implementation is system-wide and not only through the WebSocket, so
 /// it is possible for another part of the system to subscribe to events.
-pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind)>);
+pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>);
+
+#[allow(clippy::default_constructed_unit_structs)]
+impl Default for PubSubManager {
+    fn default() -> Self {
+        PubSubManager(OnSubscription::default().into())
+    }
+}
+
+impl From<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>> for PubSubManager {
+    fn from(val: Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>) -> Self {
+        PubSubManager(OnSubscription(Some(val)).into())
+    }
+}
 
 impl Deref for PubSubManager {
-    type Target = pub_sub::Manager<NotificationPayload, (String, Kind)>;
+    type Target = pub_sub::Manager<NotificationPayload, (String, Kind), OnSubscription>;
 
     fn deref(&self) -> &Self::Target {
         &self.0

+ 110 - 0
crates/cdk/src/nuts/nut17/on_subscription.rs

@@ -0,0 +1,110 @@
+//! On Subscription
+//!
+//! This module contains the code that is triggered when a new subscription is created.
+use super::{Kind, NotificationPayload};
+use crate::{
+    cdk_database::{self, MintDatabase},
+    nuts::{MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey},
+    pub_sub::OnNewSubscription,
+};
+use std::{collections::HashMap, sync::Arc};
+
+#[derive(Default)]
+/// Subscription Init
+///
+/// This struct triggers code when a new subscription is created.
+///
+/// It is used to send the initial state of the subscription to the client.
+pub struct OnSubscription(
+    pub(crate) Option<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>>,
+);
+
+#[async_trait::async_trait]
+impl OnNewSubscription for OnSubscription {
+    type Event = NotificationPayload;
+    type Index = (String, Kind);
+
+    async fn on_new_subscription(
+        &self,
+        request: &[&Self::Index],
+    ) -> Result<Vec<Self::Event>, String> {
+        let datastore = if let Some(localstore) = self.0.as_ref() {
+            localstore
+        } else {
+            return Ok(vec![]);
+        };
+
+        let mut to_return = vec![];
+
+        for (kind, values) in request.iter().fold(
+            HashMap::new(),
+            |mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| {
+                acc.entry(kind).or_default().push(data);
+                acc
+            },
+        ) {
+            match kind {
+                Kind::Bolt11MeltQuote => {
+                    let queries = values
+                        .iter()
+                        .map(|id| datastore.get_melt_quote(id))
+                        .collect::<Vec<_>>();
+
+                    to_return.extend(
+                        futures::future::try_join_all(queries)
+                            .await
+                            .map(|quotes| {
+                                quotes
+                                    .into_iter()
+                                    .filter_map(|quote| quote.map(|x| x.into()))
+                                    .map(|x: MeltQuoteBolt11Response| x.into())
+                                    .collect::<Vec<_>>()
+                            })
+                            .map_err(|e| e.to_string())?,
+                    );
+                }
+                Kind::Bolt11MintQuote => {
+                    let queries = values
+                        .iter()
+                        .map(|id| datastore.get_mint_quote(id))
+                        .collect::<Vec<_>>();
+
+                    to_return.extend(
+                        futures::future::try_join_all(queries)
+                            .await
+                            .map(|quotes| {
+                                quotes
+                                    .into_iter()
+                                    .filter_map(|quote| quote.map(|x| x.into()))
+                                    .map(|x: MintQuoteBolt11Response| x.into())
+                                    .collect::<Vec<_>>()
+                            })
+                            .map_err(|e| e.to_string())?,
+                    );
+                }
+                Kind::ProofState => {
+                    let public_keys = values
+                        .iter()
+                        .map(PublicKey::from_hex)
+                        .collect::<Result<Vec<PublicKey>, _>>()
+                        .map_err(|e| e.to_string())?;
+
+                    to_return.extend(
+                        datastore
+                            .get_proofs_states(&public_keys)
+                            .await
+                            .map_err(|e| e.to_string())?
+                            .into_iter()
+                            .enumerate()
+                            .filter_map(|(idx, state)| {
+                                state.map(|state| (public_keys[idx], state).into())
+                            })
+                            .map(|state: ProofState| state.into()),
+                    );
+                }
+            }
+        }
+
+        Ok(to_return)
+    }
+}

+ 64 - 5
crates/cdk/src/pub_sub/mod.rs

@@ -37,6 +37,25 @@ pub const DEFAULT_REMOVE_SIZE: usize = 10_000;
 /// Default channel size for subscription buffering
 pub const DEFAULT_CHANNEL_SIZE: usize = 10;
 
+#[async_trait::async_trait]
+/// On New Subscription trait
+///
+/// This trait is optional and it is used to notify the application when a new
+/// subscription is created. This is useful when the application needs to send
+/// the initial state to the subscriber upon subscription
+pub trait OnNewSubscription {
+    /// Index type
+    type Index;
+    /// Subscription event type
+    type Event;
+
+    /// Called when a new subscription is created
+    async fn on_new_subscription(
+        &self,
+        request: &[&Self::Index],
+    ) -> Result<Vec<Self::Event>, String>;
+}
+
 /// Subscription manager
 ///
 /// This object keep track of all subscription listener and it is also
@@ -45,21 +64,24 @@ pub const DEFAULT_CHANNEL_SIZE: usize = 10;
 /// The content of the notification is not relevant to this scope and it is up
 /// to the application, therefore the generic T is used instead of a specific
 /// type
-pub struct Manager<T, I>
+pub struct Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
     I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     indexes: IndexTree<T, I>,
+    on_new_subscription: Option<F>,
     unsubscription_sender: mpsc::Sender<(SubId, Vec<Index<I>>)>,
     active_subscriptions: Arc<AtomicUsize>,
     background_subscription_remover: Option<JoinHandle<()>>,
 }
 
-impl<T, I> Default for Manager<T, I>
+impl<T, I, F> Default for Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
     I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     fn default() -> Self {
         let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE);
@@ -72,6 +94,7 @@ where
                 storage.clone(),
                 active_subscriptions.clone(),
             ))),
+            on_new_subscription: None,
             unsubscription_sender: sender,
             active_subscriptions,
             indexes: storage,
@@ -79,10 +102,24 @@ where
     }
 }
 
-impl<T, I> Manager<T, I>
+impl<T, I, F> From<F> for Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
-    I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
+    I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
+{
+    fn from(value: F) -> Self {
+        let mut manager: Self = Default::default();
+        manager.on_new_subscription = Some(value);
+        manager
+    }
+}
+
+impl<T, I, F> Manager<T, I, F>
+where
+    T: Indexable<Type = I> + Clone + Send + Sync + 'static,
+    I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     #[inline]
     /// Broadcast an event to all listeners
@@ -132,8 +169,29 @@ where
     ) -> ActiveSubscription<T, I> {
         let (sender, receiver) = mpsc::channel(10);
         let sub_id: SubId = params.as_ref().clone();
+
         let indexes: Vec<Index<I>> = params.into();
 
+        if let Some(on_new_subscription) = self.on_new_subscription.as_ref() {
+            match on_new_subscription
+                .on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::<Vec<_>>())
+                .await
+            {
+                Ok(events) => {
+                    for event in events {
+                        let _ = sender.try_send((sub_id.clone(), event));
+                    }
+                }
+                Err(err) => {
+                    tracing::info!(
+                        "Failed to get initial state for subscription: {:?}, {}",
+                        sub_id,
+                        err
+                    );
+                }
+            }
+        }
+
         let mut index_storage = self.indexes.write().await;
         for index in indexes.clone() {
             index_storage.insert(index, sender.clone());
@@ -180,10 +238,11 @@ where
 }
 
 /// Manager goes out of scope, stop all background tasks
-impl<T, I> Drop for Manager<T, I>
+impl<T, I, F> Drop for Manager<T, I, F>
 where
     T: Indexable<Type = I> + Clone + Send + Sync + 'static,
     I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
+    F: OnNewSubscription<Index = I, Event = T> + 'static,
 {
     fn drop(&mut self) {
         if let Some(handler) = self.background_subscription_remover.take() {