Преглед на файлове

Improvements to relayer

Detect when a client's connection drops
Cesar Rodas преди 1 година
родител
ревизия
a317dca219
променени са 5 файла, в които са добавени 141 реда и са изтрити 12 реда
  1. 27 6
      crates/relayer/src/connection.rs
  2. 18 5
      crates/relayer/src/relayer.rs
  3. 80 0
      crates/types/src/relayer/auth.rs
  4. 2 1
      crates/types/src/relayer/mod.rs
  5. 14 0
      crates/types/src/response.rs

+ 27 - 6
crates/relayer/src/connection.rs

@@ -1,6 +1,10 @@
 use crate::{get_id, Error};
 use futures_util::{SinkExt, StreamExt};
-use nostr_rs_types::{relayer::ROk, types::Addr, Request, Response};
+use nostr_rs_types::{
+    relayer::{Auth, ROk},
+    types::Addr,
+    Request, Response,
+};
 use parking_lot::RwLock;
 use std::collections::HashMap;
 use tokio::{
@@ -12,7 +16,7 @@ use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
 
 #[derive(Debug)]
 pub struct Connection {
-    #[allow(dead_code)]
+    #[allow(unused)]
     pub(crate) conn_id: u128,
     sender: Sender<Response>,
     subscriptions: RwLock<HashMap<String, u128>>,
@@ -22,7 +26,7 @@ const MAX_SUBSCRIPTIONS_BUFFER: usize = 100;
 
 impl Connection {
     #[cfg(test)]
-    pub fn new() -> (Self, Receiver<Response>) {
+    pub fn new_for_test() -> (Self, Receiver<Response>) {
         let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
         (
             Self {
@@ -34,15 +38,22 @@ impl Connection {
         )
     }
 
-    #[cfg(not(test))]
     pub async fn new(
         broadcast_request: Sender<(u128, Request)>,
+        disconnection_notify: Option<Sender<u128>>,
         stream: TcpStream,
     ) -> Result<Self, Error> {
         let websocket = accept_async(stream).await?;
         let conn_id = get_id();
         let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
-        Self::spawn(broadcast_request, websocket, receiver, conn_id);
+        Self::spawn(
+            broadcast_request,
+            websocket,
+            receiver,
+            disconnection_notify,
+            conn_id,
+        );
+        let _ = sender.send(Auth::default().into()).await;
         Ok(Self {
             conn_id,
             sender,
@@ -50,11 +61,12 @@ impl Connection {
         })
     }
 
-    #[cfg(not(test))]
+    #[allow(unused)]
     fn spawn(
         broadcast_request: Sender<(u128, Request)>,
         websocket: WebSocketStream<TcpStream>,
         mut receiver: Receiver<Response>,
+        disconnection_notify: Option<Sender<u128>>,
         conn_id: u128,
     ) {
         tokio::spawn(async move {
@@ -105,6 +117,10 @@ impl Connection {
                     }
                 }
             }
+
+            if let Some(disconnection_notify) = disconnection_notify {
+                let _ = disconnection_notify.try_send(conn_id);
+            }
         });
     }
 
@@ -113,6 +129,11 @@ impl Connection {
         Ok(self.sender.try_send(response)?)
     }
 
+    #[inline]
+    pub fn get_sender(&self) -> Sender<Response> {
+        self.sender.clone()
+    }
+
     pub fn get_subscription_id(&self, id: &str) -> Option<u128> {
         let subscriptions = self.subscriptions.read();
         subscriptions.get(id).copied()

+ 18 - 5
crates/relayer/src/relayer.rs

@@ -7,6 +7,7 @@ use nostr_rs_types::{
 };
 use parking_lot::{RwLock, RwLockReadGuard};
 use std::{collections::HashMap, ops::Deref, sync::Arc};
+use tokio::sync::mpsc;
 #[allow(unused_imports)]
 use tokio::{
     net::TcpStream,
@@ -57,9 +58,12 @@ impl Relayer {
         &self.storage
     }
 
-    #[cfg(not(test))]
-    pub async fn add_connection(&self, stream: TcpStream) -> Result<(), Error> {
-        let client = Connection::new(self.sender.clone(), stream).await?;
+    pub async fn add_connection(
+        &self,
+        disconnection_notify: Option<mpsc::Sender<u128>>,
+        stream: TcpStream,
+    ) -> Result<(), Error> {
+        let client = Connection::new(self.sender.clone(), disconnection_notify, stream).await?;
         let mut clients = self.clients.write();
         clients.insert(client.conn_id, client);
 
@@ -164,6 +168,15 @@ impl Relayer {
         self.recv_request_from_client(connection, request)
     }
 
+    pub fn send_to_conn(&self, conn_id: u128, response: Response) -> Result<(), Error> {
+        let connections = self.clients.read();
+        let connection = connections
+            .get(&conn_id)
+            .ok_or(Error::UnknownConnection(conn_id))?;
+
+        connection.send(response)
+    }
+
     #[inline]
     fn broadcast_to_subscribers(subscriptions: RwLockReadGuard<Subscriptions>, event: &Event) {
         for (_, receiver) in subscriptions.iter() {
@@ -233,7 +246,7 @@ mod test {
     async fn serve_listener_from_local_db() {
         let request: Request = serde_json::from_str("[\"REQ\",\"1298169700973717\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid object");
         let (relayer, _) = Relayer::new(get_db(true));
-        let (connection, mut recv) = Connection::new();
+        let (connection, mut recv) = Connection::new_for_test();
         let _ = relayer.recv_request_from_client(&connection, request);
         // ev1
         assert_eq!(
@@ -282,7 +295,7 @@ mod test {
     async fn server_listener_real_time() {
         let request: Request = serde_json::from_str("[\"REQ\",\"1298169700973717\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid object");
         let (relayer, _) = Relayer::new(get_db(false));
-        let (connection, mut recv) = Connection::new();
+        let (connection, mut recv) = Connection::new_for_test();
         let _ = relayer.recv_request_from_client(&connection, request);
         // eod
         assert!(recv

+ 80 - 0
crates/types/src/relayer/auth.rs

@@ -0,0 +1,80 @@
+//! Auth message
+//!
+//! This message is being used to authenticate the client to the relayer, it
+//! works by giving a random challenge to the client, that must be returned
+//! signed by the client
+use crate::common::SerializeDeserialize;
+use rand::{distributions::Alphanumeric, Rng};
+use serde_json::Value;
+use std::{collections::VecDeque, ops::Deref};
+
+/// This is how a relayer sends an authentication challenge to the client. The
+/// challenge must be returned signed by the client
+#[derive(Clone, Debug)]
+pub struct Auth(pub String);
+
+impl Default for Auth {
+    fn default() -> Self {
+        Self(
+            rand::thread_rng()
+                .sample_iter(&Alphanumeric)
+                .take(64)
+                .map(char::from)
+                .collect(),
+        )
+    }
+}
+
+impl Deref for Auth {
+    type Target = str;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl SerializeDeserialize for Auth {
+    fn get_tag() -> &'static str {
+        "AUTH"
+    }
+
+    fn serialize(&self) -> Result<Vec<Value>, String> {
+        Ok(vec![
+            Value::String(Self::get_tag().to_owned()),
+            Value::String((*self.0).to_owned()),
+        ])
+    }
+
+    fn deserialize(args: VecDeque<Value>) -> Result<Self, String> {
+        if args.is_empty() {
+            return Err("Invalid length".to_owned());
+        }
+        let challenge = args[0]
+            .as_str()
+            .ok_or_else(|| "Invalid message, expected string".to_owned())?;
+        Ok(Auth(challenge.to_owned()))
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::Response;
+
+    #[test]
+    fn parse() {
+        let json = r#"["AUTH", "test"]"#;
+        let message: Response = serde_json::from_str(json).expect("message");
+        assert!(message.as_auth().is_some());
+    }
+
+    #[test]
+    fn serialize() {
+        let auth = Auth("test".to_owned());
+        let m: Response = auth.into();
+        assert_eq!(
+            r#"["AUTH","test"]"#,
+            serde_json::to_string(&m).expect("valid json")
+        );
+    }
+}

+ 2 - 1
crates/types/src/relayer/mod.rs

@@ -2,9 +2,10 @@
 //!
 //! This mod has all the messages that relayers may send to clients
 
+pub mod auth;
 pub mod eose;
 pub mod event;
 pub mod notice;
 pub mod ok;
 
-pub use self::{eose::EndOfStoredEvents, event::Event, notice::Notice, ok::ROk};
+pub use self::{auth::Auth, eose::EndOfStoredEvents, event::Event, notice::Notice, ok::ROk};

+ 14 - 0
crates/types/src/response.rs

@@ -15,6 +15,8 @@ custom_derive! {
     ///
     /// All responses from relayers to clients are abstracted in this struct
     pub enum Response {
+        /// Server Authentication challenge
+        Auth(relayer::Auth),
         /// The server replies an OK
         Ok(relayer::ROk),
         /// This is how server communicates about errors (most likely protocol
@@ -48,6 +50,14 @@ impl Response {
         }
     }
 
+    /// Returns the message as auth challenge, if possible
+    pub fn as_auth(&self) -> Option<&relayer::Auth> {
+        match self {
+            Self::Auth(auth) => Some(auth),
+            _ => None,
+        }
+    }
+
     /// Returns the current message as a notice, if possible
     pub fn as_notice(&self) -> Option<&relayer::Notice> {
         match self {
@@ -71,6 +81,7 @@ impl ser::Serialize for Response {
         S: Serializer,
     {
         let values = match self {
+            Self::Auth(t) => t.serialize(),
             Self::Ok(t) => t.serialize(),
             Self::Notice(t) => t.serialize(),
             Self::Event(t) => t.serialize(),
@@ -107,6 +118,9 @@ impl<'de> de::Deserialize<'de> for Response {
             .ok_or_else(|| de::Error::custom("Invalid type for element 0 of the array"))?;
 
         match tag {
+            "AUTH" => Ok(relayer::Auth::deserialize(array)
+                .map_err(de::Error::custom)?
+                .into()),
             "EOSE" => Ok(relayer::EndOfStoredEvents::deserialize(array)
                 .map_err(de::Error::custom)?
                 .into()),