|
|
@@ -4,9 +4,12 @@ use std::sync::Arc;
|
|
|
|
|
|
use cdk_common::subscription::Params;
|
|
|
use cdk_common::ws::{WsMessageOrResponse, WsMethodRequest, WsRequest, WsUnsubscribeRequest};
|
|
|
+#[cfg(feature = "auth")]
|
|
|
+use cdk_common::{Method, RoutePath};
|
|
|
use futures::{SinkExt, StreamExt};
|
|
|
use tokio::sync::{mpsc, RwLock};
|
|
|
use tokio_tungstenite::connect_async;
|
|
|
+use tokio_tungstenite::tungstenite::client::IntoClientRequest;
|
|
|
use tokio_tungstenite::tungstenite::Message;
|
|
|
|
|
|
use super::http::http_main;
|
|
|
@@ -37,14 +40,64 @@ pub async fn ws_main(
|
|
|
url.set_scheme("ws").expect("Could not set scheme");
|
|
|
}
|
|
|
|
|
|
- let url = url.to_string();
|
|
|
+ let request = match url.to_string().into_client_request() {
|
|
|
+ Ok(req) => req,
|
|
|
+ Err(err) => {
|
|
|
+ tracing::error!("Failed to create client request: {:?}", err);
|
|
|
+ // Fallback to HTTP client if we can't create the WebSocket request
|
|
|
+ return http_main(
|
|
|
+ std::iter::empty(),
|
|
|
+ http_client,
|
|
|
+ subscriptions,
|
|
|
+ new_subscription_recv,
|
|
|
+ on_drop,
|
|
|
+ wallet,
|
|
|
+ )
|
|
|
+ .await;
|
|
|
+ }
|
|
|
+ };
|
|
|
|
|
|
let mut active_subscriptions = HashMap::<SubId, mpsc::Sender<_>>::new();
|
|
|
let mut failure_count = 0;
|
|
|
|
|
|
loop {
|
|
|
+ let mut request_clone = request.clone();
|
|
|
+ #[cfg(feature = "auth")]
|
|
|
+ {
|
|
|
+ let auth_wallet = http_client.get_auth_wallet().await;
|
|
|
+ let token = match auth_wallet.as_ref() {
|
|
|
+ Some(auth_wallet) => {
|
|
|
+ let endpoint = cdk_common::ProtectedEndpoint::new(Method::Get, RoutePath::Ws);
|
|
|
+ match auth_wallet.get_auth_for_request(&endpoint).await {
|
|
|
+ Ok(token) => token,
|
|
|
+ Err(err) => {
|
|
|
+ tracing::warn!("Failed to get auth token: {:?}", err);
|
|
|
+ None
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ None => None,
|
|
|
+ };
|
|
|
+
|
|
|
+ if let Some(auth_token) = token {
|
|
|
+ let header_key = match &auth_token {
|
|
|
+ cdk_common::AuthToken::ClearAuth(_) => "Clear-auth",
|
|
|
+ cdk_common::AuthToken::BlindAuth(_) => "Blind-auth",
|
|
|
+ };
|
|
|
+
|
|
|
+ match auth_token.to_string().parse() {
|
|
|
+ Ok(header_value) => {
|
|
|
+ request_clone.headers_mut().insert(header_key, header_value);
|
|
|
+ }
|
|
|
+ Err(err) => {
|
|
|
+ tracing::warn!("Failed to parse auth token as header value: {:?}", err);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
tracing::debug!("Connecting to {}", url);
|
|
|
- let ws_stream = match connect_async(&url).await {
|
|
|
+ let ws_stream = match connect_async(request_clone.clone()).await {
|
|
|
Ok((ws_stream, _)) => ws_stream,
|
|
|
Err(err) => {
|
|
|
failure_count += 1;
|