소스 검색

Add support for ws in wasm

Cesar Rodas 4 일 전
부모
커밋
f913187852

+ 3 - 1
Cargo.lock

@@ -1194,6 +1194,7 @@ dependencies = [
  "cbor-diag",
  "cdk-common",
  "cdk-fake-wallet",
+ "cdk-http-client",
  "cdk-npubcash",
  "cdk-prometheus",
  "cdk-signatory",
@@ -1221,7 +1222,6 @@ dependencies = [
  "tls-api",
  "tls-api-native-tls",
  "tokio",
- "tokio-tungstenite 0.26.2",
  "tokio-util",
  "tor-rtcompat",
  "tracing",
@@ -1386,6 +1386,7 @@ dependencies = [
 name = "cdk-http-client"
 version = "0.14.0"
 dependencies = [
+ "futures",
  "mockito",
  "regex",
  "reqwest",
@@ -1393,6 +1394,7 @@ dependencies = [
  "serde_json",
  "thiserror 2.0.18",
  "tokio",
+ "tokio-tungstenite 0.26.2",
  "url",
 ]
 

+ 3 - 0
crates/cdk-http-client/Cargo.toml

@@ -19,6 +19,9 @@ url.workspace = true
 [target.'cfg(not(target_arch = "wasm32"))'.dependencies]
 reqwest = { workspace = true }
 regex = { workspace = true }
+tokio-tungstenite = { workspace = true, features = ["rustls", "rustls-tls-native-roots", "connect"] }
+futures = { workspace = true }
+tokio = { workspace = true }
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 reqwest = { version = "0.12", default-features = false, features = ["json"] }

+ 6 - 0
crates/cdk-http-client/src/lib.rs

@@ -25,7 +25,13 @@ mod error;
 mod request;
 mod response;
 
+#[cfg(not(target_arch = "wasm32"))]
+mod ws;
+
 pub use client::{fetch, HttpClient, HttpClientBuilder};
 pub use error::HttpError;
 pub use request::RequestBuilder;
 pub use response::{RawResponse, Response};
+
+#[cfg(not(target_arch = "wasm32"))]
+pub use ws::{WsClient, WsClientBuilder, WsConnection, WsError, WsReceiver, WsSender};

+ 118 - 0
crates/cdk-http-client/src/ws/error.rs

@@ -0,0 +1,118 @@
+//! WebSocket error types
+
+use thiserror::Error;
+
+/// WebSocket errors that can occur during connections and message handling
+#[derive(Debug, Error)]
+pub enum WsError {
+    /// Connection error
+    #[error("WebSocket connection error: {0}")]
+    Connection(String),
+
+    /// Error sending a message
+    #[error("WebSocket send error: {0}")]
+    Send(String),
+
+    /// Error receiving a message
+    #[error("WebSocket receive error: {0}")]
+    Receive(String),
+
+    /// Error closing the connection
+    #[error("WebSocket close error: {0}")]
+    Close(String),
+
+    /// Invalid URL
+    #[error("Invalid WebSocket URL: {0}")]
+    InvalidUrl(String),
+
+    /// Client build error
+    #[error("WebSocket client build error: {0}")]
+    Build(String),
+
+    /// Protocol error
+    #[error("WebSocket protocol error: {0}")]
+    Protocol(String),
+}
+
+impl From<tokio_tungstenite::tungstenite::Error> for WsError {
+    fn from(err: tokio_tungstenite::tungstenite::Error) -> Self {
+        use tokio_tungstenite::tungstenite::Error;
+
+        match err {
+            Error::ConnectionClosed | Error::AlreadyClosed => {
+                WsError::Connection("Connection closed".to_string())
+            }
+            Error::Io(e) => WsError::Connection(e.to_string()),
+            Error::Tls(e) => WsError::Connection(format!("TLS error: {}", e)),
+            Error::Protocol(e) => WsError::Protocol(e.to_string()),
+            Error::Url(e) => WsError::InvalidUrl(e.to_string()),
+            Error::Http(response) => {
+                WsError::Connection(format!("HTTP error: status {}", response.status()))
+            }
+            Error::HttpFormat(e) => WsError::Connection(format!("HTTP format error: {}", e)),
+            _ => WsError::Connection(err.to_string()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_ws_error_connection_display() {
+        let error = WsError::Connection("connection refused".to_string());
+        assert_eq!(
+            format!("{}", error),
+            "WebSocket connection error: connection refused"
+        );
+    }
+
+    #[test]
+    fn test_ws_error_send_display() {
+        let error = WsError::Send("buffer full".to_string());
+        assert_eq!(format!("{}", error), "WebSocket send error: buffer full");
+    }
+
+    #[test]
+    fn test_ws_error_receive_display() {
+        let error = WsError::Receive("stream ended".to_string());
+        assert_eq!(
+            format!("{}", error),
+            "WebSocket receive error: stream ended"
+        );
+    }
+
+    #[test]
+    fn test_ws_error_close_display() {
+        let error = WsError::Close("abrupt close".to_string());
+        assert_eq!(format!("{}", error), "WebSocket close error: abrupt close");
+    }
+
+    #[test]
+    fn test_ws_error_invalid_url_display() {
+        let error = WsError::InvalidUrl("missing scheme".to_string());
+        assert_eq!(
+            format!("{}", error),
+            "Invalid WebSocket URL: missing scheme"
+        );
+    }
+
+    #[test]
+    fn test_ws_error_build_display() {
+        let error = WsError::Build("invalid config".to_string());
+        assert_eq!(
+            format!("{}", error),
+            "WebSocket client build error: invalid config"
+        );
+    }
+
+    #[test]
+    fn test_ws_error_protocol_display() {
+        let error = WsError::Protocol("invalid frame".to_string());
+        assert_eq!(
+            format!("{}", error),
+            "WebSocket protocol error: invalid frame"
+        );
+    }
+}

+ 349 - 0
crates/cdk-http-client/src/ws/mod.rs

@@ -0,0 +1,349 @@
+//! WebSocket client abstraction
+//!
+//! This module provides a WebSocket client that abstracts the underlying
+//! tokio-tungstenite library, providing a clean API for establishing
+//! WebSocket connections with custom headers.
+
+mod error;
+
+pub use error::WsError;
+
+use futures::stream::{SplitSink, SplitStream};
+use futures::{SinkExt, StreamExt};
+use tokio::net::TcpStream;
+use tokio_tungstenite::tungstenite::client::IntoClientRequest;
+use tokio_tungstenite::tungstenite::http::HeaderValue;
+use tokio_tungstenite::tungstenite::Message;
+use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
+
+type InnerStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
+
+/// WebSocket client for establishing connections
+///
+/// This client provides a simple API for connecting to WebSocket servers
+/// with optional custom headers (e.g., for authentication).
+#[derive(Debug, Clone)]
+pub struct WsClient {
+    headers: Vec<(String, String)>,
+    #[allow(dead_code)]
+    accept_invalid_certs: bool,
+}
+
+impl Default for WsClient {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl WsClient {
+    /// Create a new WebSocket client with default settings
+    pub fn new() -> Self {
+        Self {
+            headers: Vec::new(),
+            accept_invalid_certs: false,
+        }
+    }
+
+    /// Create a new WebSocket client builder
+    pub fn builder() -> WsClientBuilder {
+        WsClientBuilder::default()
+    }
+
+    /// Connect to a WebSocket server
+    ///
+    /// The URL can be in http/https or ws/wss format.
+    /// HTTP URLs are automatically converted to WebSocket URLs:
+    /// - `http://` → `ws://`
+    /// - `https://` → `wss://`
+    pub async fn connect(&self, url: &str) -> Result<WsConnection, WsError> {
+        let ws_url = convert_url_scheme(url)?;
+
+        let mut request = ws_url
+            .into_client_request()
+            .map_err(|e| WsError::InvalidUrl(format!("Failed to create request: {}", e)))?;
+
+        // Add custom headers
+        for (key, value) in &self.headers {
+            let header_name = key
+                .parse::<tokio_tungstenite::tungstenite::http::header::HeaderName>()
+                .map_err(|e| WsError::Build(format!("Invalid header name: {}", e)))?;
+            let header_value = HeaderValue::from_str(value)
+                .map_err(|e| WsError::Build(format!("Invalid header value: {}", e)))?;
+            request.headers_mut().insert(header_name, header_value);
+        }
+
+        let (ws_stream, _response) = connect_async(request).await?;
+
+        Ok(WsConnection { inner: ws_stream })
+    }
+}
+
+/// Builder for configuring a WebSocket client
+#[derive(Debug, Default)]
+pub struct WsClientBuilder {
+    headers: Vec<(String, String)>,
+    accept_invalid_certs: bool,
+}
+
+impl WsClientBuilder {
+    /// Add a header to be sent with the WebSocket upgrade request
+    ///
+    /// This is useful for authentication headers like `Clear-auth` or `Blind-auth`.
+    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
+        self.headers.push((key.into(), value.into()));
+        self
+    }
+
+    /// Accept invalid TLS certificates
+    ///
+    /// **Warning**: This should only be used for testing purposes.
+    /// Using this in production is a security risk.
+    pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
+        self.accept_invalid_certs = accept;
+        self
+    }
+
+    /// Build the WebSocket client
+    pub fn build(self) -> Result<WsClient, WsError> {
+        Ok(WsClient {
+            headers: self.headers,
+            accept_invalid_certs: self.accept_invalid_certs,
+        })
+    }
+}
+
+/// Active WebSocket connection
+///
+/// Supports sending and receiving text messages.
+/// Can be split into separate sender and receiver halves for concurrent operations.
+#[derive(Debug)]
+pub struct WsConnection {
+    inner: InnerStream,
+}
+
+impl WsConnection {
+    /// Send a text message
+    pub async fn send(&mut self, msg: &str) -> Result<(), WsError> {
+        self.inner
+            .send(Message::Text(msg.to_string().into()))
+            .await
+            .map_err(|e| WsError::Send(e.to_string()))
+    }
+
+    /// Receive the next text message
+    ///
+    /// Returns `None` if the connection is closed.
+    /// Non-text messages (ping/pong/binary/close) are handled automatically.
+    pub async fn recv(&mut self) -> Result<Option<String>, WsError> {
+        loop {
+            match self.inner.next().await {
+                Some(Ok(Message::Text(text))) => return Ok(Some(text.to_string())),
+                Some(Ok(Message::Close(_))) => return Ok(None),
+                Some(Ok(Message::Ping(data))) => {
+                    // Respond to ping with pong
+                    self.inner
+                        .send(Message::Pong(data))
+                        .await
+                        .map_err(|e| WsError::Send(e.to_string()))?;
+                }
+                Some(Ok(Message::Pong(_))) => {
+                    // Ignore pong messages
+                    continue;
+                }
+                Some(Ok(Message::Binary(_))) => {
+                    // Skip binary messages (we only support text)
+                    continue;
+                }
+                Some(Ok(Message::Frame(_))) => {
+                    // Skip raw frames
+                    continue;
+                }
+                Some(Err(e)) => return Err(WsError::Receive(e.to_string())),
+                None => return Ok(None),
+            }
+        }
+    }
+
+    /// Split the connection into separate sender and receiver halves
+    ///
+    /// This allows concurrent sending and receiving on the same connection.
+    pub fn split(self) -> (WsSender, WsReceiver) {
+        let (sink, stream) = self.inner.split();
+        (WsSender { inner: sink }, WsReceiver { inner: stream })
+    }
+
+    /// Close the connection gracefully
+    pub async fn close(mut self) -> Result<(), WsError> {
+        self.inner
+            .send(Message::Close(None))
+            .await
+            .map_err(|e| WsError::Close(e.to_string()))
+    }
+}
+
+/// Send half of a split WebSocket connection
+#[derive(Debug)]
+pub struct WsSender {
+    inner: SplitSink<InnerStream, Message>,
+}
+
+impl WsSender {
+    /// Send a text message
+    pub async fn send(&mut self, msg: &str) -> Result<(), WsError> {
+        self.inner
+            .send(Message::Text(msg.to_string().into()))
+            .await
+            .map_err(|e| WsError::Send(e.to_string()))
+    }
+
+    /// Close the connection gracefully
+    pub async fn close(mut self) -> Result<(), WsError> {
+        self.inner
+            .send(Message::Close(None))
+            .await
+            .map_err(|e| WsError::Close(e.to_string()))
+    }
+}
+
+/// Receive half of a split WebSocket connection
+#[derive(Debug)]
+pub struct WsReceiver {
+    inner: SplitStream<InnerStream>,
+}
+
+impl WsReceiver {
+    /// Receive the next text message
+    ///
+    /// Returns `None` if the connection is closed.
+    /// Non-text messages (ping/pong/binary/close) are skipped.
+    pub async fn recv(&mut self) -> Result<Option<String>, WsError> {
+        loop {
+            match self.inner.next().await {
+                Some(Ok(Message::Text(text))) => return Ok(Some(text.to_string())),
+                Some(Ok(Message::Close(_))) => return Ok(None),
+                Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {
+                    // Note: When split, we can't respond to pings from the receiver
+                    // The sender should be used to send pongs if needed
+                    continue;
+                }
+                Some(Ok(Message::Binary(_))) | Some(Ok(Message::Frame(_))) => {
+                    continue;
+                }
+                Some(Err(e)) => return Err(WsError::Receive(e.to_string())),
+                None => return Ok(None),
+            }
+        }
+    }
+}
+
+/// Convert HTTP/HTTPS URLs to WebSocket URLs
+///
+/// - `http://` → `ws://`
+/// - `https://` → `wss://`
+/// - `ws://` and `wss://` are left unchanged
+fn convert_url_scheme(url: &str) -> Result<String, WsError> {
+    if url.starts_with("http://") {
+        Ok(url.replacen("http://", "ws://", 1))
+    } else if url.starts_with("https://") {
+        Ok(url.replacen("https://", "wss://", 1))
+    } else if url.starts_with("ws://") || url.starts_with("wss://") {
+        Ok(url.to_string())
+    } else {
+        Err(WsError::InvalidUrl(format!(
+            "URL must start with http://, https://, ws://, or wss://: {}",
+            url
+        )))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_url_scheme_conversion_http_to_ws() {
+        let result = convert_url_scheme("http://example.com/ws");
+        assert!(result.is_ok());
+        assert_eq!(result.expect("Should succeed"), "ws://example.com/ws");
+    }
+
+    #[test]
+    fn test_url_scheme_conversion_https_to_wss() {
+        let result = convert_url_scheme("https://example.com/ws");
+        assert!(result.is_ok());
+        assert_eq!(result.expect("Should succeed"), "wss://example.com/ws");
+    }
+
+    #[test]
+    fn test_url_scheme_conversion_ws_unchanged() {
+        let result = convert_url_scheme("ws://example.com/ws");
+        assert!(result.is_ok());
+        assert_eq!(result.expect("Should succeed"), "ws://example.com/ws");
+    }
+
+    #[test]
+    fn test_url_scheme_conversion_wss_unchanged() {
+        let result = convert_url_scheme("wss://example.com/ws");
+        assert!(result.is_ok());
+        assert_eq!(result.expect("Should succeed"), "wss://example.com/ws");
+    }
+
+    #[test]
+    fn test_url_scheme_conversion_invalid() {
+        let result = convert_url_scheme("ftp://example.com/ws");
+        assert!(result.is_err());
+        if let Err(WsError::InvalidUrl(msg)) = result {
+            assert!(msg.contains("must start with"));
+        } else {
+            panic!("Expected WsError::InvalidUrl");
+        }
+    }
+
+    #[test]
+    fn test_client_new() {
+        let client = WsClient::new();
+        assert!(client.headers.is_empty());
+    }
+
+    #[test]
+    fn test_client_default() {
+        let client = WsClient::default();
+        assert!(client.headers.is_empty());
+    }
+
+    #[test]
+    fn test_builder_default() {
+        let builder = WsClientBuilder::default();
+        let client = builder.build();
+        assert!(client.is_ok());
+    }
+
+    #[test]
+    fn test_builder_with_headers() {
+        let client = WsClient::builder()
+            .header("Authorization", "Bearer token")
+            .header("X-Custom", "value")
+            .build()
+            .expect("Should build successfully");
+
+        assert_eq!(client.headers.len(), 2);
+        assert_eq!(
+            client.headers[0],
+            ("Authorization".to_string(), "Bearer token".to_string())
+        );
+        assert_eq!(
+            client.headers[1],
+            ("X-Custom".to_string(), "value".to_string())
+        );
+    }
+
+    #[test]
+    fn test_builder_accept_invalid_certs() {
+        let client = WsClient::builder()
+            .danger_accept_invalid_certs(true)
+            .build()
+            .expect("Should build successfully");
+
+        assert!(client.accept_invalid_certs);
+    }
+}

+ 1 - 5
crates/cdk/Cargo.toml

@@ -74,11 +74,7 @@ tokio = { workspace = true, features = [
 ] }
 getrandom = { version = "0.2" }
 cdk-signatory = { workspace = true, features = ["grpc"], optional = true }
-tokio-tungstenite = { workspace = true, features = [
-    "rustls",
-    "rustls-tls-native-roots",
-    "connect"
-] }
+cdk-http-client.workspace = true
 # Tor dependencies (optional; enabled by feature "tor")
 hyper = { version = "0.14", optional = true, features = ["client", "http1", "http2"] }
 http = { version = "0.2", optional = true }

+ 32 - 54
crates/cdk/src/wallet/subscription/ws.rs

@@ -3,11 +3,8 @@ use cdk_common::pub_sub::remote_consumer::{InternalRelay, StreamCtrl, SubscribeM
 use cdk_common::pub_sub::Error as PubsubError;
 #[cfg(feature = "auth")]
 use cdk_common::{Method, RoutePath};
-use futures::{SinkExt, StreamExt};
+use cdk_http_client::{WsClient, WsError};
 use tokio::sync::mpsc;
-use tokio_tungstenite::connect_async;
-use tokio_tungstenite::tungstenite::client::IntoClientRequest;
-use tokio_tungstenite::tungstenite::Message;
 
 use super::{MintSubTopics, SubscriptionClient};
 
@@ -18,33 +15,19 @@ pub(crate) async fn stream_client(
     topics: Vec<SubscribeMessage<MintSubTopics>>,
     reply_to: InternalRelay<MintSubTopics>,
 ) -> Result<(), PubsubError> {
-    let mut url = client
+    let url = client
         .mint_url
         .join_paths(&["v1", "ws"])
         .expect("Could not join paths");
 
-    if url.scheme() == "https" {
-        url.set_scheme("wss").expect("Could not set scheme");
-    } else {
-        url.set_scheme("ws").expect("Could not set scheme");
-    }
-
+    // Build WsClient with auth headers if enabled
     #[cfg(not(feature = "auth"))]
-    let request = url.to_string().into_client_request().map_err(|err| {
-        tracing::error!("Failed to create client request: {:?}", err);
-        // Fallback to HTTP client if we can't create the WebSocket request
-        cdk_common::pub_sub::Error::NotSupported
-    })?;
+    let ws_client = WsClient::new();
 
     #[cfg(feature = "auth")]
-    let mut request = url.to_string().into_client_request().map_err(|err| {
-        tracing::error!("Failed to create client request: {:?}", err);
-        // Fallback to HTTP client if we can't create the WebSocket request
-        cdk_common::pub_sub::Error::NotSupported
-    })?;
+    let ws_client = {
+        let mut builder = WsClient::builder();
 
-    #[cfg(feature = "auth")]
-    {
         let auth_wallet = client.http_client.get_auth_wallet().await;
         let token = match auth_wallet.as_ref() {
             Some(auth_wallet) => {
@@ -66,29 +49,26 @@ pub(crate) async fn stream_client(
                 cdk_common::AuthToken::BlindAuth(_) => "Blind-auth",
             };
 
-            match auth_token.to_string().parse() {
-                Ok(header_value) => {
-                    request.headers_mut().insert(header_key, header_value);
-                }
-                Err(err) => {
-                    tracing::warn!("Failed to parse auth token as header value: {:?}", err);
-                }
-            }
+            builder = builder.header(header_key, auth_token.to_string());
         }
-    }
+
+        builder.build().map_err(|e: WsError| {
+            tracing::error!("Failed to build WsClient: {:?}", e);
+            PubsubError::NotSupported
+        })?
+    };
 
     tracing::debug!("Connecting to {}", url);
-    let ws_stream = connect_async(request)
+    let ws_conn = ws_client
+        .connect(url.as_str())
         .await
-        .map(|(ws_stream, _)| ws_stream)
-        .map_err(|err| {
+        .map_err(|err: WsError| {
             tracing::error!("Error connecting: {err:?}");
-
-            cdk_common::pub_sub::Error::Internal(Box::new(err))
+            PubsubError::Internal(Box::new(err))
         })?;
 
     tracing::debug!("Connected to {}", url);
-    let (mut write, mut read) = ws_stream.split();
+    let (mut write, mut read) = ws_conn.split();
 
     for (name, index) in topics {
         let (_, req) = if let Some(req) = client.get_sub_request(name, index) {
@@ -97,7 +77,7 @@ pub(crate) async fn stream_client(
             continue;
         };
 
-        let _ = write.send(Message::Text(req.into())).await;
+        let _ = write.send(&req).await;
     }
 
     loop {
@@ -110,7 +90,7 @@ pub(crate) async fn stream_client(
                         } else {
                             continue;
                         };
-                        let _ = write.send(Message::Text(req.into())).await;
+                        let _ = write.send(&req).await;
                     }
                     StreamCtrl::Unsubscribe(msg) => {
                         let req = if let Some(req) = client.get_unsub_request(msg) {
@@ -118,30 +98,29 @@ pub(crate) async fn stream_client(
                         } else {
                             continue;
                         };
-                        let _ = write.send(Message::Text(req.into())).await;
+                        let _ = write.send(&req).await;
                     }
                     StreamCtrl::Stop => {
-                        if let Err(err) = write.send(Message::Close(None)).await {
+                        if let Err(err) = write.close().await {
                             tracing::error!("Closing error {err:?}");
                         }
                         break;
                     }
                 };
             }
-            Some(msg) = read.next() => {
-                let msg = match msg {
-                    Ok(msg) => msg,
-                    Err(_) => {
-                        if let Err(err) = write.send(Message::Close(None)).await {
-                            tracing::error!("Closing error {err:?}");
-                        }
+            msg_result = read.recv() => {
+                let msg = match msg_result {
+                    Ok(Some(msg)) => msg,
+                    Ok(None) => {
+                        // Connection closed
+                        break;
+                    }
+                    Err(err) => {
+                        tracing::error!("Receive error: {err:?}");
                         break;
                     }
                 };
-                let msg = match msg {
-                    Message::Text(msg) => msg,
-                    _ => continue,
-                };
+
                 let msg = match serde_json::from_str::<WsMessageOrResponse<String>>(&msg) {
                     Ok(msg) => msg,
                     Err(_) => continue,
@@ -159,7 +138,6 @@ pub(crate) async fn stream_client(
                         return Err(PubsubError::InternalStr(error.error.message));
                     }
                 }
-
             }
         }
     }