浏览代码

fix: make http wallet subscriptions wasm compatible

gudnuf 1 月之前
父节点
当前提交
817dd29c5a
共有 3 个文件被更改,包括 144 次插入77 次删除
  1. 3 0
      crates/cdk/Cargo.toml
  2. 123 64
      crates/cdk/src/wallet/subscription/http.rs
  3. 18 13
      crates/cdk/src/wallet/subscription/mod.rs

+ 3 - 0
crates/cdk/Cargo.toml

@@ -74,6 +74,9 @@ cdk-signatory = { workspace = true, default-features = false }
 getrandom = { version = "0.2", features = ["js"] }
 ring = { version = "0.17.14", features = ["wasm32_unknown_unknown_js"] }
 uuid = { workspace = true, features = ["js"] }
+wasm-bindgen = "0.2"
+wasm-bindgen-futures = "0.4"
+gloo-timers = { version = "0.3", features = ["futures"] }
 
 [[example]]
 name = "mint-token"

+ 123 - 64
crates/cdk/src/wallet/subscription/http.rs

@@ -4,6 +4,7 @@ use std::time::Duration;
 
 use cdk_common::MintQuoteBolt12Response;
 use tokio::sync::{mpsc, RwLock};
+#[cfg(not(target_arch = "wasm32"))]
 use tokio::time;
 
 use super::WsSubscriptionBody;
@@ -84,6 +85,7 @@ async fn convert_subscription(
     Some(())
 }
 
+#[cfg(not(target_arch = "wasm32"))]
 #[inline]
 pub async fn http_main<S: IntoIterator<Item = SubId>>(
     initial_state: S,
@@ -94,7 +96,7 @@ pub async fn http_main<S: IntoIterator<Item = SubId>>(
     _wallet: Arc<Wallet>,
 ) {
     let mut interval = time::interval(Duration::from_secs(2));
-    let mut subscribed_to = HashMap::<UrlType, (mpsc::Sender<_>, _, AnyState)>::new();
+    let mut subscribed_to = SubscribedTo::new();
 
     for sub_id in initial_state {
         convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await;
@@ -103,76 +105,133 @@ pub async fn http_main<S: IntoIterator<Item = SubId>>(
     loop {
         tokio::select! {
             _ = interval.tick() => {
-                for (url, (sender, _, last_state)) in subscribed_to.iter_mut() {
-                    tracing::debug!("Polling: {:?}", url);
-                    match url {
-                        UrlType::MintBolt12(id) => {
-                            let response = http_client.get_mint_quote_bolt12_status(id).await;
-                            if let Ok(response) = response {
-                                if *last_state == AnyState::MintBolt12QuoteState(response.clone()) {
-                                    continue;
-                                }
-                                *last_state = AnyState::MintBolt12QuoteState(response.clone());
-                                if let Err(err) = sender.try_send(NotificationPayload::MintQuoteBolt12Response(response)) {
-                                    tracing::error!("Error sending mint quote response: {:?}", err);
-                                }
-                            }
-                        },
-                        UrlType::Mint(id) => {
+                poll_subscriptions(&http_client, &mut subscribed_to).await;
+            }
+            Some(subid) = new_subscription_recv.recv() => {
+                convert_subscription(subid, &subscriptions, &mut subscribed_to).await;
+            }
+            Some(id) = on_drop.recv() => {
+                subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id);
+            }
+        }
+    }
+}
 
-                            let response = http_client.get_mint_quote_status(id).await;
-                            if let Ok(response) = response {
-                                if *last_state == AnyState::MintQuoteState(response.state) {
-                                    continue;
-                                }
-                                *last_state = AnyState::MintQuoteState(response.state);
-                                if let Err(err) = sender.try_send(NotificationPayload::MintQuoteBolt11Response(response)) {
-                                    tracing::error!("Error sending mint quote response: {:?}", err);
-                                }
-                            }
-                        }
-                        UrlType::Melt(id) => {
+#[cfg(target_arch = "wasm32")]
+#[inline]
+pub async fn http_main<S: IntoIterator<Item = SubId>>(
+    initial_state: S,
+    http_client: Arc<dyn MintConnector + Send + Sync>,
+    subscriptions: Arc<RwLock<HashMap<SubId, WsSubscriptionBody>>>,
+    mut new_subscription_recv: mpsc::Receiver<SubId>,
+    mut on_drop: mpsc::Receiver<SubId>,
+    _wallet: Arc<Wallet>,
+) {
+    let mut subscribed_to = SubscribedTo::new();
 
-                            let response = http_client.get_melt_quote_status(id).await;
-                            if let Ok(response) = response {
-                                if *last_state == AnyState::MeltQuoteState(response.state) {
-                                    continue;
-                                }
-                                *last_state = AnyState::MeltQuoteState(response.state);
-                                if let Err(err) =  sender.try_send(NotificationPayload::MeltQuoteBolt11Response(response)) {
-                                    tracing::error!("Error sending melt quote response: {:?}", err);
-                                }
-                            }
-                        }
-                        UrlType::PublicKey(id) => {
-                            let responses = http_client.post_check_state(CheckStateRequest {
-                                ys: vec![*id],
-                            }
-                            ).await;
-                            if let Ok(mut responses) = responses {
-                                let response = if let Some(state) = responses.states.pop() {
-                                    state
-                                } else {
-                                    continue;
-                                };
+    for sub_id in initial_state {
+        convert_subscription(sub_id, &subscriptions, &mut subscribed_to).await;
+    }
 
-                                if *last_state == AnyState::PublicKey(response.state) {
-                                    continue;
-                                }
-                                *last_state = AnyState::PublicKey(response.state);
-                                if let Err(err) = sender.try_send(NotificationPayload::ProofState(response)) {
-                                    tracing::error!("Error sending proof state response: {:?}", err);
-                                }
-                            }
-                        }
+    loop {
+        tokio::select! {
+            _ = gloo_timers::future::sleep(Duration::from_secs(2)) => {
+                poll_subscriptions(&http_client, &mut subscribed_to).await;
+            }
+            subid = new_subscription_recv.recv() => {
+                match subid {
+                    Some(subid) => {
+                        convert_subscription(subid, &subscriptions, &mut subscribed_to).await;
+                    }
+                    None => {
+                        // New subscription channel closed - SubscriptionClient was dropped, terminate worker
+                        break;
                     }
                 }
             }
-            Some(subid) = new_subscription_recv.recv() => {
-                convert_subscription(subid, &subscriptions, &mut subscribed_to).await;
+            id = on_drop.recv() => {
+                match id {
+                    Some(id) => {
+                        subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id);
+                    }
+                    None => {
+                        // Drop notification channel closed - SubscriptionClient was dropped, terminate worker
+                        break;
+                    }
+                }
             }
-            Some(id) = on_drop.recv() => {
-                subscribed_to.retain(|_, (_, sub_id, _)| *sub_id != id);
+        }
+    }
+}
+
+async fn poll_subscriptions(
+    http_client: &Arc<dyn MintConnector + Send + Sync>,
+    subscribed_to: &mut SubscribedTo,
+) {
+    for (url, (sender, _, last_state)) in subscribed_to.iter_mut() {
+        tracing::debug!("Polling: {:?}", url);
+        match url {
+            UrlType::MintBolt12(id) => {
+                let response = http_client.get_mint_quote_bolt12_status(id).await;
+                if let Ok(response) = response {
+                    if *last_state == AnyState::MintBolt12QuoteState(response.clone()) {
+                        continue;
+                    }
+                    *last_state = AnyState::MintBolt12QuoteState(response.clone());
+                    if let Err(err) =
+                        sender.try_send(NotificationPayload::MintQuoteBolt12Response(response))
+                    {
+                        tracing::error!("Error sending mint quote response: {:?}", err);
+                    }
+                }
+            }
+            UrlType::Mint(id) => {
+                let response = http_client.get_mint_quote_status(id).await;
+                if let Ok(response) = response {
+                    if *last_state == AnyState::MintQuoteState(response.state) {
+                        continue;
+                    }
+                    *last_state = AnyState::MintQuoteState(response.state);
+                    if let Err(err) =
+                        sender.try_send(NotificationPayload::MintQuoteBolt11Response(response))
+                    {
+                        tracing::error!("Error sending mint quote response: {:?}", err);
+                    }
+                }
+            }
+            UrlType::Melt(id) => {
+                let response = http_client.get_melt_quote_status(id).await;
+                if let Ok(response) = response {
+                    if *last_state == AnyState::MeltQuoteState(response.state) {
+                        continue;
+                    }
+                    *last_state = AnyState::MeltQuoteState(response.state);
+                    if let Err(err) =
+                        sender.try_send(NotificationPayload::MeltQuoteBolt11Response(response))
+                    {
+                        tracing::error!("Error sending melt quote response: {:?}", err);
+                    }
+                }
+            }
+            UrlType::PublicKey(id) => {
+                let responses = http_client
+                    .post_check_state(CheckStateRequest { ys: vec![*id] })
+                    .await;
+                if let Ok(mut responses) = responses {
+                    let response = if let Some(state) = responses.states.pop() {
+                        state
+                    } else {
+                        continue;
+                    };
+
+                    if *last_state == AnyState::PublicKey(response.state) {
+                        continue;
+                    }
+                    *last_state = AnyState::PublicKey(response.state);
+                    if let Err(err) = sender.try_send(NotificationPayload::ProofState(response)) {
+                        tracing::error!("Error sending proof state response: {:?}", err);
+                    }
+                }
             }
         }
     }

+ 18 - 13
crates/cdk/src/wallet/subscription/mod.rs

@@ -13,6 +13,8 @@ use cdk_common::subscription::Params;
 use tokio::sync::{mpsc, RwLock};
 use tokio::task::JoinHandle;
 use tracing::error;
+#[cfg(target_arch = "wasm32")]
+use wasm_bindgen_futures;
 
 use super::Wallet;
 use crate::mint_url::MintUrl;
@@ -207,7 +209,7 @@ impl SubscriptionClient {
             new_subscription_notif,
             on_drop_notif,
             subscriptions: subscriptions.clone(),
-            worker: Some(Self::start_worker(
+            worker: Self::start_worker(
                 prefer_ws_method,
                 http_client,
                 url,
@@ -215,7 +217,7 @@ impl SubscriptionClient {
                 new_subscription_recv,
                 on_drop_recv,
                 wallet,
-            )),
+            ),
         }
     }
 
@@ -228,7 +230,7 @@ impl SubscriptionClient {
         new_subscription_recv: mpsc::Receiver<SubId>,
         on_drop_recv: mpsc::Receiver<SubId>,
         wallet: Arc<Wallet>,
-    ) -> JoinHandle<()> {
+    ) -> Option<JoinHandle<()>> {
         #[cfg(any(
             feature = "http_subscription",
             not(feature = "mint"),
@@ -293,7 +295,7 @@ impl SubscriptionClient {
         new_subscription_recv: mpsc::Receiver<SubId>,
         on_drop: mpsc::Receiver<SubId>,
         wallet: Arc<Wallet>,
-    ) -> JoinHandle<()> {
+    ) -> Option<JoinHandle<()>> {
         let http_worker = http::http_main(
             vec![],
             http_client,
@@ -304,12 +306,15 @@ impl SubscriptionClient {
         );
 
         #[cfg(target_arch = "wasm32")]
-        let ret = tokio::task::spawn_local(http_worker);
+        {
+            wasm_bindgen_futures::spawn_local(http_worker);
+            None
+        }
 
         #[cfg(not(target_arch = "wasm32"))]
-        let ret = tokio::spawn(http_worker);
-
-        ret
+        {
+            Some(tokio::spawn(http_worker))
+        }
     }
 
     /// WebSocket subscription client
@@ -328,22 +333,22 @@ impl SubscriptionClient {
         new_subscription_recv: mpsc::Receiver<SubId>,
         on_drop: mpsc::Receiver<SubId>,
         wallet: Arc<Wallet>,
-    ) -> JoinHandle<()> {
-        tokio::spawn(ws::ws_main(
+    ) -> Option<JoinHandle<()>> {
+        Some(tokio::spawn(ws::ws_main(
             http_client,
             url,
             subscriptions,
             new_subscription_recv,
             on_drop,
             wallet,
-        ))
+        )))
     }
 }
 
 impl Drop for SubscriptionClient {
     fn drop(&mut self) {
-        if let Some(sender) = self.worker.take() {
-            sender.abort();
+        if let Some(handle) = self.worker.take() {
+            handle.abort();
         }
     }
 }