use crate::{get_id, Error}; use futures_util::{SinkExt, StreamExt}; use nostr_rs_types::{ relayer::{Auth, ROk}, types::Addr, Request, Response, }; use parking_lot::RwLock; use std::collections::HashMap; use tokio::{ net::TcpStream, sync::mpsc::{channel, Receiver, Sender}, }; #[allow(unused_imports)] use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream}; #[derive(Debug)] pub struct Connection { #[allow(unused)] pub(crate) conn_id: u128, sender: Sender, subscriptions: RwLock>, } const MAX_SUBSCRIPTIONS_BUFFER: usize = 100; impl Connection { #[cfg(test)] pub fn new_for_test() -> (Self, Receiver) { let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER); ( Self { conn_id: 0, sender, subscriptions: RwLock::new(HashMap::new()), }, receiver, ) } pub async fn new( broadcast_request: Sender<(u128, Request)>, disconnection_notify: Option>, stream: TcpStream, ) -> Result { let websocket = accept_async(stream).await?; let conn_id = get_id(); let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER); Self::spawn( broadcast_request, websocket, receiver, disconnection_notify, conn_id, ); let _ = sender.send(Auth::default().into()).await; Ok(Self { conn_id, sender, subscriptions: RwLock::new(HashMap::new()), }) } #[allow(unused)] fn spawn( broadcast_request: Sender<(u128, Request)>, websocket: WebSocketStream, mut receiver: Receiver, disconnection_notify: Option>, conn_id: u128, ) { tokio::spawn(async move { let mut _subscriptions: HashMap)> = HashMap::new(); let (mut writer, mut reader) = websocket.split(); loop { tokio::select! { Some(msg) = receiver.recv() => { let msg = if let Ok(msg) = serde_json::to_string(&msg) { msg } else { continue; }; if let Err(err) = writer.send(Message::Text(msg)).await { log::error!("Error sending message to client: {}", err); break; } } Some(msg) = reader.next() => { if let Ok(Message::Text(msg)) = msg { let msg: Result = serde_json::from_str(&msg); match msg { Ok(msg) => { let _ = broadcast_request.send((conn_id, msg)).await; }, Err(err) => { log::error!("Error parsing message from client: {}", err); let reply: Response = ROk { id: Addr::default(), status: false, message: "Error parsing message".to_owned(), }.into(); let reply = if let Ok(reply) = serde_json::to_string(&reply) { reply } else { continue; }; if let Err(err) = writer.send(Message::Text(reply)).await { log::error!("Error sending message to client: {}", err); break; } } }; } } else => { break; } } } if let Some(disconnection_notify) = disconnection_notify { let _ = disconnection_notify.try_send(conn_id); } }); } #[inline] pub fn send(&self, response: Response) -> Result<(), Error> { self.sender .try_send(response) .map_err(|e| Error::TrySendError(Box::new(e))) } #[inline] pub fn get_sender(&self) -> Sender { self.sender.clone() } pub fn get_subscription_id(&self, id: &str) -> Option { let subscriptions = self.subscriptions.read(); subscriptions.get(id).copied() } pub fn create_subscription(&self, id: String) -> (u128, Sender) { let mut subscriptions = self.subscriptions.write(); let internal_id = subscriptions.entry(id).or_insert_with(get_id); (*internal_id, self.sender.clone()) } }