Procházet zdrojové kódy

Better relayer API

Cesar Rodas před 3 měsíci
rodič
revize
cd4ba6982b
2 změnil soubory, kde provedl 79 přidání a 61 odebrání
  1. 3 0
      crates/relayer/src/error.rs
  2. 76 61
      crates/relayer/src/relayer.rs

+ 3 - 0
crates/relayer/src/error.rs

@@ -22,4 +22,7 @@ pub enum Error {
 
     #[error("Nostr client error: {0}")]
     Client(#[from] nostr_rs_client::Error),
+
+    #[error("Relayer already splitten")]
+    AlreadySplitted,
 }

+ 76 - 61
crates/relayer/src/relayer.rs

@@ -9,7 +9,7 @@ use nostr_rs_types::{
 };
 use std::{collections::HashMap, ops::Deref};
 use tokio::{
-    net::TcpStream,
+    net::{TcpListener, TcpStream},
     sync::{
         mpsc::{channel, Receiver, Sender},
         RwLockReadGuard,
@@ -24,7 +24,7 @@ type SubId = u128;
 
 type Subscriptions = HashMap<SubId, (SubscriptionId, Sender<Response>)>;
 
-pub struct Relayer<T: Storage> {
+pub struct Relayer<T: Storage + Send + Sync + 'static> {
     /// Storage engine, if provided the services are going to persisted in disk,
     /// otherwise all the messages are going to be ephemeral, making this
     /// relayer just a dumb proxy (that can be useful for privacy) but it won't
@@ -44,11 +44,17 @@ pub struct Relayer<T: Storage> {
     /// fast iteration and match quickly filters.
     subscriptions: RwLock<HashMap<Subscription, RwLock<Subscriptions>>>,
     clients: RwLock<HashMap<u128, Connection>>,
-    sender: Sender<(u128, Request)>,
+    /// This Sender can be used to send requests from anywhere to the relayer.
+    send_to_relayer: Sender<(u128, Request)>,
+    /// This Receiver is the relayer the way the relayer receives messages
+    relayer_receiver: Option<Receiver<(u128, Request)>>,
+    /// Client pool
+    ///
+    /// A relayer can optionally be connected to a pool of clients to get foreign events.
     client_pool: Option<(RwLock<Pool>, JoinHandle<()>)>,
 }
 
-impl<T: Storage> Drop for Relayer<T> {
+impl<T: Storage + Send + Sync + 'static> Drop for Relayer<T> {
     fn drop(&mut self) {
         if let Some((_, handle)) = self.client_pool.take() {
             let _ = handle.abort();
@@ -56,32 +62,64 @@ impl<T: Storage> Drop for Relayer<T> {
     }
 }
 
-impl<T: Storage> Relayer<T> {
-    pub fn new(
-        storage: Option<T>,
-        client_pool: Option<Pool>,
-    ) -> Result<(Self, Receiver<(u128, Request)>), Error> {
+impl<T: Storage + Send + Sync + 'static> Relayer<T> {
+    pub fn new(storage: Option<T>, client_pool: Option<Pool>) -> Result<Self, Error> {
         let (sender, receiver) = channel(100_000);
-        Ok((
-            Self {
-                storage,
-                sender: sender.clone(),
-                subscriptions: Default::default(),
-                subscriptions_ids_index: Default::default(),
-                clients: Default::default(),
-                client_pool: if let Some(client_pool) = client_pool {
-                    Some(Self::handle_client_pool(client_pool, sender)?)
-                } else {
-                    None
-                },
+        Ok(Self {
+            storage,
+            send_to_relayer: sender.clone(),
+            relayer_receiver: Some(receiver),
+            subscriptions: Default::default(),
+            subscriptions_ids_index: Default::default(),
+            clients: Default::default(),
+            client_pool: if let Some(client_pool) = client_pool {
+                Some(Self::handle_client_pool(client_pool, sender)?)
+            } else {
+                None
             },
-            receiver,
-        ))
+        })
     }
 
-    pub fn main(self) -> JoinHandle<()> {
-        let relayer = self;
-        tokio::spawn(async move { todo!() })
+    pub fn split(mut self) -> Result<(Self, Receiver<(u128, Request)>), Error> {
+        let receiver = self.relayer_receiver.take().ok_or(Error::AlreadySplitted)?;
+        Ok((self, receiver))
+    }
+
+    pub fn main(self, server: TcpListener) -> Result<JoinHandle<()>, Error> {
+        let (this, mut receiver) = self.split()?;
+        Ok(tokio::spawn(async move {
+            loop {
+                tokio::select! {
+                    Ok((stream, _)) = server.accept() => {
+                        // accept new external connections
+                        let _ = this.add_connection(None, stream).await;
+                    },
+                    Some((conn_id, request)) = receiver.recv() => {
+                        // receive messages from the connection pool
+                        if conn_id == 0 {
+                            // connection pool
+                            if let Request::Event(event) = request {
+                                this.store_and_broadcast(&event.deref()).await;
+                            }
+                            continue;
+                        }
+
+                        let connections = this.clients.read().await;
+                        let connection = if let Some(connection) = connections.get(&conn_id) {
+                            connection
+                        } else {
+                            continue;
+                        };
+
+                        // receive messages from clients
+                        let _ = this.progress_request_from_client(connection, request).await;
+                        drop(connections);
+                    }
+                    else => {
+                    }
+                }
+            }
+        }))
     }
 
     fn handle_client_pool(
@@ -117,14 +155,15 @@ impl<T: Storage> Relayer<T> {
         stream: TcpStream,
     ) -> Result<u128, Error> {
         let client =
-            Connection::new_connection(self.sender.clone(), disconnection_notify, stream).await?;
+            Connection::new_connection(self.send_to_relayer.clone(), disconnection_notify, stream)
+                .await?;
         let id = client.conn_id;
         self.clients.write().await.insert(id, client);
 
         Ok(id)
     }
 
-    async fn recv_request_from_client(
+    async fn progress_request_from_client(
         &self,
         connection: &Connection,
         request: Request,
@@ -213,34 +252,6 @@ impl<T: Storage> Relayer<T> {
         Ok(Some(request))
     }
 
-    pub async fn recv(
-        &self,
-        receiver: &mut Receiver<(u128, Request)>,
-    ) -> Result<Option<Request>, Error> {
-        let (conn_id, request) = if let Some(request) = receiver.recv().await {
-            request
-        } else {
-            return Ok(None);
-        };
-
-        if conn_id == 0 {
-            match request {
-                Request::Event(event) => {
-                    self.store_and_broadcast(&event.deref()).await;
-                }
-                _ => {}
-            };
-            return Ok(None);
-        }
-
-        let connections = self.clients.read().await;
-        let connection = connections
-            .get(&conn_id)
-            .ok_or(Error::UnknownConnection(conn_id))?;
-
-        self.recv_request_from_client(connection, request).await
-    }
-
     pub async fn send_to_conn(&self, conn_id: u128, response: Response) -> Result<(), Error> {
         let connections = self.clients.read().await;
         let connection = connections
@@ -380,9 +391,11 @@ mod test {
           }
         ]))
         .expect("valid object");
-        let (relayer, _) = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
+        let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
         let (connection, mut recv) = Connection::new_for_test();
-        let _ = relayer.recv_request_from_client(&connection, request).await;
+        let _ = relayer
+            .progress_request_from_client(&connection, request)
+            .await;
         // ev1
         assert_eq!(
             "9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42",
@@ -453,9 +466,11 @@ mod test {
     #[tokio::test]
     async fn server_listener_real_time() {
         let request: Request = serde_json::from_str("[\"REQ\",\"1298169700973717\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid object");
-        let (relayer, _) = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
+        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
         let (connection, mut recv) = Connection::new_for_test();
-        let _ = relayer.recv_request_from_client(&connection, request).await;
+        let _ = relayer
+            .progress_request_from_client(&connection, request)
+            .await;
         // eod
         assert!(recv
             .try_recv()
@@ -469,7 +484,7 @@ mod test {
         let new_event: Request = serde_json::from_str(r#"["EVENT", {"kind":1,"content":"Pong","tags":[["e","9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42","","root"],["e","2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a","","reply"],["p","39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],["p","ee7202ad91459e013bfef263c59e47deb0163a5e7651b026673765488bfaf102"]],"created_at":1681938616,"pubkey":"a42007e33cfa25673b26f46f39df039aa6003258a68dc88f1f1e0447607aedb3","id":"e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9","sig":"9036150a6c8a32933cffcc42aec4d2109a22e9f10d1c3860c0435a925e6386babb7df5c95fcf68c8ed6a9726a1f07225af663d0b068eb555014130aad21674fc","meta":{"revision":0,"created":1681939266488,"version":0},"$loki":108}]"#).expect("value");
 
         relayer
-            .recv_request_from_client(&connection, new_event)
+            .progress_request_from_client(&connection, new_event)
             .await
             .expect("process event");