|
@@ -0,0 +1,349 @@
|
|
|
|
|
+//! WebSocket client abstraction
|
|
|
|
|
+//!
|
|
|
|
|
+//! This module provides a WebSocket client that abstracts the underlying
|
|
|
|
|
+//! tokio-tungstenite library, providing a clean API for establishing
|
|
|
|
|
+//! WebSocket connections with custom headers.
|
|
|
|
|
+
|
|
|
|
|
+mod error;
|
|
|
|
|
+
|
|
|
|
|
+pub use error::WsError;
|
|
|
|
|
+
|
|
|
|
|
+use futures::stream::{SplitSink, SplitStream};
|
|
|
|
|
+use futures::{SinkExt, StreamExt};
|
|
|
|
|
+use tokio::net::TcpStream;
|
|
|
|
|
+use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
|
|
|
|
+use tokio_tungstenite::tungstenite::http::HeaderValue;
|
|
|
|
|
+use tokio_tungstenite::tungstenite::Message;
|
|
|
|
|
+use tokio_tungstenite::{connect_async, MaybeTlsStream, WebSocketStream};
|
|
|
|
|
+
|
|
|
|
|
+type InnerStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
|
|
|
|
+
|
|
|
|
|
+/// WebSocket client for establishing connections
|
|
|
|
|
+///
|
|
|
|
|
+/// This client provides a simple API for connecting to WebSocket servers
|
|
|
|
|
+/// with optional custom headers (e.g., for authentication).
|
|
|
|
|
+#[derive(Debug, Clone)]
|
|
|
|
|
+pub struct WsClient {
|
|
|
|
|
+ headers: Vec<(String, String)>,
|
|
|
|
|
+ #[allow(dead_code)]
|
|
|
|
|
+ accept_invalid_certs: bool,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl Default for WsClient {
|
|
|
|
|
+ fn default() -> Self {
|
|
|
|
|
+ Self::new()
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl WsClient {
|
|
|
|
|
+ /// Create a new WebSocket client with default settings
|
|
|
|
|
+ pub fn new() -> Self {
|
|
|
|
|
+ Self {
|
|
|
|
|
+ headers: Vec::new(),
|
|
|
|
|
+ accept_invalid_certs: false,
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Create a new WebSocket client builder
|
|
|
|
|
+ pub fn builder() -> WsClientBuilder {
|
|
|
|
|
+ WsClientBuilder::default()
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Connect to a WebSocket server
|
|
|
|
|
+ ///
|
|
|
|
|
+ /// The URL can be in http/https or ws/wss format.
|
|
|
|
|
+ /// HTTP URLs are automatically converted to WebSocket URLs:
|
|
|
|
|
+ /// - `http://` → `ws://`
|
|
|
|
|
+ /// - `https://` → `wss://`
|
|
|
|
|
+ pub async fn connect(&self, url: &str) -> Result<WsConnection, WsError> {
|
|
|
|
|
+ let ws_url = convert_url_scheme(url)?;
|
|
|
|
|
+
|
|
|
|
|
+ let mut request = ws_url
|
|
|
|
|
+ .into_client_request()
|
|
|
|
|
+ .map_err(|e| WsError::InvalidUrl(format!("Failed to create request: {}", e)))?;
|
|
|
|
|
+
|
|
|
|
|
+ // Add custom headers
|
|
|
|
|
+ for (key, value) in &self.headers {
|
|
|
|
|
+ let header_name = key
|
|
|
|
|
+ .parse::<tokio_tungstenite::tungstenite::http::header::HeaderName>()
|
|
|
|
|
+ .map_err(|e| WsError::Build(format!("Invalid header name: {}", e)))?;
|
|
|
|
|
+ let header_value = HeaderValue::from_str(value)
|
|
|
|
|
+ .map_err(|e| WsError::Build(format!("Invalid header value: {}", e)))?;
|
|
|
|
|
+ request.headers_mut().insert(header_name, header_value);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ let (ws_stream, _response) = connect_async(request).await?;
|
|
|
|
|
+
|
|
|
|
|
+ Ok(WsConnection { inner: ws_stream })
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Builder for configuring a WebSocket client
|
|
|
|
|
+#[derive(Debug, Default)]
|
|
|
|
|
+pub struct WsClientBuilder {
|
|
|
|
|
+ headers: Vec<(String, String)>,
|
|
|
|
|
+ accept_invalid_certs: bool,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl WsClientBuilder {
|
|
|
|
|
+ /// Add a header to be sent with the WebSocket upgrade request
|
|
|
|
|
+ ///
|
|
|
|
|
+ /// This is useful for authentication headers like `Clear-auth` or `Blind-auth`.
|
|
|
|
|
+ pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
|
|
|
|
|
+ self.headers.push((key.into(), value.into()));
|
|
|
|
|
+ self
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Accept invalid TLS certificates
|
|
|
|
|
+ ///
|
|
|
|
|
+ /// **Warning**: This should only be used for testing purposes.
|
|
|
|
|
+ /// Using this in production is a security risk.
|
|
|
|
|
+ pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
|
|
|
|
|
+ self.accept_invalid_certs = accept;
|
|
|
|
|
+ self
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Build the WebSocket client
|
|
|
|
|
+ pub fn build(self) -> Result<WsClient, WsError> {
|
|
|
|
|
+ Ok(WsClient {
|
|
|
|
|
+ headers: self.headers,
|
|
|
|
|
+ accept_invalid_certs: self.accept_invalid_certs,
|
|
|
|
|
+ })
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Active WebSocket connection
|
|
|
|
|
+///
|
|
|
|
|
+/// Supports sending and receiving text messages.
|
|
|
|
|
+/// Can be split into separate sender and receiver halves for concurrent operations.
|
|
|
|
|
+#[derive(Debug)]
|
|
|
|
|
+pub struct WsConnection {
|
|
|
|
|
+ inner: InnerStream,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl WsConnection {
|
|
|
|
|
+ /// Send a text message
|
|
|
|
|
+ pub async fn send(&mut self, msg: &str) -> Result<(), WsError> {
|
|
|
|
|
+ self.inner
|
|
|
|
|
+ .send(Message::Text(msg.to_string().into()))
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map_err(|e| WsError::Send(e.to_string()))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Receive the next text message
|
|
|
|
|
+ ///
|
|
|
|
|
+ /// Returns `None` if the connection is closed.
|
|
|
|
|
+ /// Non-text messages (ping/pong/binary/close) are handled automatically.
|
|
|
|
|
+ pub async fn recv(&mut self) -> Result<Option<String>, WsError> {
|
|
|
|
|
+ loop {
|
|
|
|
|
+ match self.inner.next().await {
|
|
|
|
|
+ Some(Ok(Message::Text(text))) => return Ok(Some(text.to_string())),
|
|
|
|
|
+ Some(Ok(Message::Close(_))) => return Ok(None),
|
|
|
|
|
+ Some(Ok(Message::Ping(data))) => {
|
|
|
|
|
+ // Respond to ping with pong
|
|
|
|
|
+ self.inner
|
|
|
|
|
+ .send(Message::Pong(data))
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map_err(|e| WsError::Send(e.to_string()))?;
|
|
|
|
|
+ }
|
|
|
|
|
+ Some(Ok(Message::Pong(_))) => {
|
|
|
|
|
+ // Ignore pong messages
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Some(Ok(Message::Binary(_))) => {
|
|
|
|
|
+ // Skip binary messages (we only support text)
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Some(Ok(Message::Frame(_))) => {
|
|
|
|
|
+ // Skip raw frames
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Some(Err(e)) => return Err(WsError::Receive(e.to_string())),
|
|
|
|
|
+ None => return Ok(None),
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Split the connection into separate sender and receiver halves
|
|
|
|
|
+ ///
|
|
|
|
|
+ /// This allows concurrent sending and receiving on the same connection.
|
|
|
|
|
+ pub fn split(self) -> (WsSender, WsReceiver) {
|
|
|
|
|
+ let (sink, stream) = self.inner.split();
|
|
|
|
|
+ (WsSender { inner: sink }, WsReceiver { inner: stream })
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Close the connection gracefully
|
|
|
|
|
+ pub async fn close(mut self) -> Result<(), WsError> {
|
|
|
|
|
+ self.inner
|
|
|
|
|
+ .send(Message::Close(None))
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map_err(|e| WsError::Close(e.to_string()))
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Send half of a split WebSocket connection
|
|
|
|
|
+#[derive(Debug)]
|
|
|
|
|
+pub struct WsSender {
|
|
|
|
|
+ inner: SplitSink<InnerStream, Message>,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl WsSender {
|
|
|
|
|
+ /// Send a text message
|
|
|
|
|
+ pub async fn send(&mut self, msg: &str) -> Result<(), WsError> {
|
|
|
|
|
+ self.inner
|
|
|
|
|
+ .send(Message::Text(msg.to_string().into()))
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map_err(|e| WsError::Send(e.to_string()))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ /// Close the connection gracefully
|
|
|
|
|
+ pub async fn close(mut self) -> Result<(), WsError> {
|
|
|
|
|
+ self.inner
|
|
|
|
|
+ .send(Message::Close(None))
|
|
|
|
|
+ .await
|
|
|
|
|
+ .map_err(|e| WsError::Close(e.to_string()))
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Receive half of a split WebSocket connection
|
|
|
|
|
+#[derive(Debug)]
|
|
|
|
|
+pub struct WsReceiver {
|
|
|
|
|
+ inner: SplitStream<InnerStream>,
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+impl WsReceiver {
|
|
|
|
|
+ /// Receive the next text message
|
|
|
|
|
+ ///
|
|
|
|
|
+ /// Returns `None` if the connection is closed.
|
|
|
|
|
+ /// Non-text messages (ping/pong/binary/close) are skipped.
|
|
|
|
|
+ pub async fn recv(&mut self) -> Result<Option<String>, WsError> {
|
|
|
|
|
+ loop {
|
|
|
|
|
+ match self.inner.next().await {
|
|
|
|
|
+ Some(Ok(Message::Text(text))) => return Ok(Some(text.to_string())),
|
|
|
|
|
+ Some(Ok(Message::Close(_))) => return Ok(None),
|
|
|
|
|
+ Some(Ok(Message::Ping(_))) | Some(Ok(Message::Pong(_))) => {
|
|
|
|
|
+ // Note: When split, we can't respond to pings from the receiver
|
|
|
|
|
+ // The sender should be used to send pongs if needed
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Some(Ok(Message::Binary(_))) | Some(Ok(Message::Frame(_))) => {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Some(Err(e)) => return Err(WsError::Receive(e.to_string())),
|
|
|
|
|
+ None => return Ok(None),
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+/// Convert HTTP/HTTPS URLs to WebSocket URLs
|
|
|
|
|
+///
|
|
|
|
|
+/// - `http://` → `ws://`
|
|
|
|
|
+/// - `https://` → `wss://`
|
|
|
|
|
+/// - `ws://` and `wss://` are left unchanged
|
|
|
|
|
+fn convert_url_scheme(url: &str) -> Result<String, WsError> {
|
|
|
|
|
+ if url.starts_with("http://") {
|
|
|
|
|
+ Ok(url.replacen("http://", "ws://", 1))
|
|
|
|
|
+ } else if url.starts_with("https://") {
|
|
|
|
|
+ Ok(url.replacen("https://", "wss://", 1))
|
|
|
|
|
+ } else if url.starts_with("ws://") || url.starts_with("wss://") {
|
|
|
|
|
+ Ok(url.to_string())
|
|
|
|
|
+ } else {
|
|
|
|
|
+ Err(WsError::InvalidUrl(format!(
|
|
|
|
|
+ "URL must start with http://, https://, ws://, or wss://: {}",
|
|
|
|
|
+ url
|
|
|
|
|
+ )))
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+#[cfg(test)]
|
|
|
|
|
+mod tests {
|
|
|
|
|
+ use super::*;
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_url_scheme_conversion_http_to_ws() {
|
|
|
|
|
+ let result = convert_url_scheme("http://example.com/ws");
|
|
|
|
|
+ assert!(result.is_ok());
|
|
|
|
|
+ assert_eq!(result.expect("Should succeed"), "ws://example.com/ws");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_url_scheme_conversion_https_to_wss() {
|
|
|
|
|
+ let result = convert_url_scheme("https://example.com/ws");
|
|
|
|
|
+ assert!(result.is_ok());
|
|
|
|
|
+ assert_eq!(result.expect("Should succeed"), "wss://example.com/ws");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_url_scheme_conversion_ws_unchanged() {
|
|
|
|
|
+ let result = convert_url_scheme("ws://example.com/ws");
|
|
|
|
|
+ assert!(result.is_ok());
|
|
|
|
|
+ assert_eq!(result.expect("Should succeed"), "ws://example.com/ws");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_url_scheme_conversion_wss_unchanged() {
|
|
|
|
|
+ let result = convert_url_scheme("wss://example.com/ws");
|
|
|
|
|
+ assert!(result.is_ok());
|
|
|
|
|
+ assert_eq!(result.expect("Should succeed"), "wss://example.com/ws");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_url_scheme_conversion_invalid() {
|
|
|
|
|
+ let result = convert_url_scheme("ftp://example.com/ws");
|
|
|
|
|
+ assert!(result.is_err());
|
|
|
|
|
+ if let Err(WsError::InvalidUrl(msg)) = result {
|
|
|
|
|
+ assert!(msg.contains("must start with"));
|
|
|
|
|
+ } else {
|
|
|
|
|
+ panic!("Expected WsError::InvalidUrl");
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_client_new() {
|
|
|
|
|
+ let client = WsClient::new();
|
|
|
|
|
+ assert!(client.headers.is_empty());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_client_default() {
|
|
|
|
|
+ let client = WsClient::default();
|
|
|
|
|
+ assert!(client.headers.is_empty());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_builder_default() {
|
|
|
|
|
+ let builder = WsClientBuilder::default();
|
|
|
|
|
+ let client = builder.build();
|
|
|
|
|
+ assert!(client.is_ok());
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_builder_with_headers() {
|
|
|
|
|
+ let client = WsClient::builder()
|
|
|
|
|
+ .header("Authorization", "Bearer token")
|
|
|
|
|
+ .header("X-Custom", "value")
|
|
|
|
|
+ .build()
|
|
|
|
|
+ .expect("Should build successfully");
|
|
|
|
|
+
|
|
|
|
|
+ assert_eq!(client.headers.len(), 2);
|
|
|
|
|
+ assert_eq!(
|
|
|
|
|
+ client.headers[0],
|
|
|
|
|
+ ("Authorization".to_string(), "Bearer token".to_string())
|
|
|
|
|
+ );
|
|
|
|
|
+ assert_eq!(
|
|
|
|
|
+ client.headers[1],
|
|
|
|
|
+ ("X-Custom".to_string(), "value".to_string())
|
|
|
|
|
+ );
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[test]
|
|
|
|
|
+ fn test_builder_accept_invalid_certs() {
|
|
|
|
|
+ let client = WsClient::builder()
|
|
|
|
|
+ .danger_accept_invalid_certs(true)
|
|
|
|
|
+ .build()
|
|
|
|
|
+ .expect("Should build successfully");
|
|
|
|
|
+
|
|
|
|
|
+ assert!(client.accept_invalid_certs);
|
|
|
|
|
+ }
|
|
|
|
|
+}
|