Browse Source

Send current state of the subscriptions

Cesar Rodas 4 months ago
parent
commit
7fc962025b
2 changed files with 109 additions and 2 deletions
  1. 97 2
      crates/cdk-axum/src/ws/subscribe.rs
  2. 12 0
      crates/cdk-integration-tests/tests/regtest.rs

+ 97 - 2
crates/cdk-axum/src/ws/subscribe.rs

@@ -3,7 +3,10 @@ use super::{
     WsContext, WsError, JSON_RPC_VERSION,
 };
 use cdk::{
-    nuts::nut17::{NotificationPayload, Params},
+    nuts::{
+        nut17::{Kind, NotificationPayload, Params},
+        MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey,
+    },
     pub_sub::SubId,
 };
 
@@ -11,17 +14,26 @@ use cdk::{
 pub struct Method(Params);
 
 #[derive(Debug, Clone, serde::Serialize)]
+/// The response to a subscription request
 pub struct Response {
+    /// Status
     status: String,
+    /// Subscription ID
     #[serde(rename = "subId")]
     sub_id: SubId,
 }
 
 #[derive(Debug, Clone, serde::Serialize)]
+/// The notification
+///
+/// This is the notification that is sent to the client when an event matches a
+/// subscription
 pub struct Notification {
+    /// The subscription ID
     #[serde(rename = "subId")]
     pub sub_id: SubId,
 
+    /// The notification payload
     pub payload: NotificationPayload,
 }
 
@@ -39,13 +51,96 @@ impl From<(SubId, NotificationPayload)> for WsNotification<Notification> {
 impl WsHandle for Method {
     type Response = Response;
 
+    /// The `handle` method is called when a client sends a subscription request
     async fn handle(self, context: &mut WsContext) -> Result<Self::Response, WsError> {
         let sub_id = self.0.id.clone();
         if context.subscriptions.contains_key(&sub_id) {
+            // Subscription ID already exits. Returns an error instead of
+            // replacing the other subscription or avoiding it.
             return Err(WsError::InvalidParams);
         }
-        let mut subscription = context.state.mint.pubsub_manager.subscribe(self.0).await;
+
+        let mut subscription = context
+            .state
+            .mint
+            .pubsub_manager
+            .subscribe(self.0.clone())
+            .await;
         let publisher = context.publisher.clone();
+
+        let current_notification_to_send: Vec<NotificationPayload> = match self.0.kind {
+            Kind::Bolt11MeltQuote => {
+                let queries = self
+                    .0
+                    .filters
+                    .iter()
+                    .map(|id| context.state.mint.localstore.get_melt_quote(id))
+                    .collect::<Vec<_>>();
+
+                futures::future::try_join_all(queries)
+                    .await
+                    .map(|quotes| {
+                        quotes
+                            .into_iter()
+                            .filter_map(|quote| quote.map(|x| x.into()))
+                            .map(|x: MeltQuoteBolt11Response| x.into())
+                            .collect::<Vec<_>>()
+                    })
+                    .unwrap_or_default()
+            }
+            Kind::Bolt11MintQuote => {
+                let queries = self
+                    .0
+                    .filters
+                    .iter()
+                    .map(|id| context.state.mint.localstore.get_mint_quote(id))
+                    .collect::<Vec<_>>();
+
+                futures::future::try_join_all(queries)
+                    .await
+                    .map(|quotes| {
+                        quotes
+                            .into_iter()
+                            .filter_map(|quote| quote.map(|x| x.into()))
+                            .map(|x: MintQuoteBolt11Response| x.into())
+                            .collect::<Vec<_>>()
+                    })
+                    .unwrap_or_default()
+            }
+            Kind::ProofState => {
+                if let Ok(public_keys) = self
+                    .0
+                    .filters
+                    .iter()
+                    .map(PublicKey::from_hex)
+                    .collect::<Result<Vec<PublicKey>, _>>()
+                {
+                    context
+                        .state
+                        .mint
+                        .localstore
+                        .get_proofs_states(&public_keys)
+                        .await
+                        .map(|x| {
+                            x.into_iter()
+                                .enumerate()
+                                .filter_map(|(idx, state)| {
+                                    state.map(|state| (public_keys[idx], state).into())
+                                })
+                                .map(|x: ProofState| x.into())
+                                .collect::<Vec<_>>()
+                        })
+                        .unwrap_or_default()
+                } else {
+                    vec![]
+                }
+            }
+        };
+
+        for notification in current_notification_to_send.into_iter() {
+            let _ = publisher.send((sub_id.clone(), notification)).await;
+        }
+
         context.subscriptions.insert(
             sub_id.clone(),
             tokio::spawn(async move {

+ 12 - 0
crates/cdk-integration-tests/tests/regtest.rs

@@ -113,6 +113,18 @@ async fn test_regtest_mint_melt_round_trip() -> Result<()> {
     assert!(melt_response.state == MeltQuoteState::Paid);
 
     let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
+    // first message is the current state
+    assert_eq!("test-sub", sub_id);
+    let payload = match payload {
+        NotificationPayload::MeltQuoteBolt11Response(melt) => melt,
+        _ => panic!("Wrong payload"),
+    };
+    assert_eq!(payload.amount + payload.fee_reserve, 100.into());
+    assert_eq!(payload.quote, melt.id);
+    assert_eq!(payload.state, MeltQuoteState::Unpaid);
+
+    // get current state
+    let (sub_id, payload) = get_notification(&mut reader, Duration::from_millis(15000)).await;
     assert_eq!("test-sub", sub_id);
     let payload = match payload {
         NotificationPayload::MeltQuoteBolt11Response(melt) => melt,