Parcourir la source

Enable client pool to process requests and responses in a non-blocking way

Subscriptions may be slow, with this new change all incoming requests will be
processed in parallel.
Cesar Rodas il y a 1 mois
Parent
commit
6e33a10364

+ 1 - 1
crates/personal-relayer/src/lib.rs

@@ -40,7 +40,7 @@ pub struct PersonalRelayer<T: Storage + Send + Sync + 'static> {
 
 impl<T: Storage + Send + Sync + 'static> PersonalRelayer<T> {
     pub async fn new(storage: T, accounts: Vec<Id>, client_urls: Vec<Url>) -> Result<Self, Error> {
-        let pool = Pool::new_with_clients(client_urls);
+        let (pool, _) = Pool::new_with_clients(client_urls)?;
 
         join_all(
             accounts

+ 1 - 1
crates/relayer/src/connection/mod.rs

@@ -178,7 +178,7 @@ impl Connection {
 
     #[inline]
     /// Sends a message to this connection's websocket
-    pub fn send(&self, response: Response) -> Result<(), Error> {
+    pub fn respond(&self, response: Response) -> Result<(), Error> {
         self.sender
             .try_send(response)
             .map_err(|e| Error::TrySendError(Box::new(e)))

+ 2 - 2
crates/relayer/src/error.rs

@@ -33,8 +33,8 @@ pub enum Error {
     NoClient,
 
     /// Unknown connections
-    #[error("Unknown connection: {0}")]
-    UnknownConnection(u128),
+    #[error("Unknown connection: {0:?}")]
+    UnknownConnection(ConnectionId),
 
     /// The relayer is already splitten
     #[error("Relayer already splitten")]

+ 136 - 92
crates/relayer/src/relayer.rs

@@ -5,7 +5,6 @@ use crate::{
 use futures_util::StreamExt;
 use nostr_rs_client::{pool::subscription::PoolSubscriptionId, Pool, Url};
 use nostr_rs_storage_base::Storage;
-use nostr_rs_subscription_manager::SubscriptionManager;
 use nostr_rs_types::{
     relayer::{self, ROk, ROkStatus},
     types::{Addr, Event, SubscriptionId},
@@ -15,6 +14,7 @@ use std::{
     collections::{HashMap, HashSet},
     ops::Deref,
     sync::Arc,
+    time::Instant,
 };
 use tokio::{
     net::{TcpListener, TcpStream},
@@ -42,6 +42,12 @@ impl Default for RelayerSubscriptionId {
     }
 }
 
+type Connections = Arc<RwLock<HashMap<ConnectionId, Connection>>>;
+type SubscriptionManager =
+    Arc<nostr_rs_subscription_manager::SubscriptionManager<RelayerSubscriptionId, ()>>;
+type ClientPoolSubscriptions =
+    Arc<RwLock<HashMap<PoolSubscriptionId, (SubscriptionId, ConnectionId)>>>;
+
 /// Relayer struct
 ///
 pub struct Relayer<T: Storage + Send + Sync + 'static> {
@@ -51,9 +57,9 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
     /// be able to perform any optimization like prefetching content while offline
     storage: Arc<Option<T>>,
     /// Subscription manager
-    subscription_manager: Arc<SubscriptionManager<RelayerSubscriptionId, ()>>,
+    subscription_manager: SubscriptionManager,
     /// List of all active connections
-    connections: Arc<RwLock<HashMap<ConnectionId, Connection>>>,
+    connections: Connections,
     /// This Sender can be used to send requests from anywhere to the relayer.
     send_to_relayer: Sender<(ConnectionId, Request)>,
     /// This Receiver is the relayer the way the relayer receives messages
@@ -63,9 +69,9 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
     ///
     /// A relayer can optionally be connected to a pool of clients to get
     /// foreign events.
-    client_pool: Option<Pool>,
+    client_pool: Option<Arc<Pool>>,
     client_pool_receiver: Option<Receiver<(Response, Url)>>,
-    client_pool_subscriptions: RwLock<HashMap<PoolSubscriptionId, (SubscriptionId, ConnectionId)>>,
+    client_pool_subscriptions: ClientPoolSubscriptions,
 }
 
 impl<T: Storage + Send + Sync + 'static> Relayer<T> {
@@ -82,7 +88,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
 
         let (client_pool_receiver, client_pool) = if let Some(client_pool) = client_pool {
             let result = client_pool.split()?;
-            (result.0, Some(result.1))
+            (result.0, Some(Arc::new(result.1)))
         } else {
             let (_, receiver) = mpsc::channel(1);
             (receiver, None)
@@ -95,7 +101,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
             relayer_receiver: Some(relayer_receiver),
             connections: Default::default(),
             client_pool_receiver: Some(client_pool_receiver),
-            client_pool: client_pool,
+            client_pool,
             client_pool_subscriptions: Default::default(),
         })
     }
@@ -138,6 +144,8 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
 
         let handle = tokio::spawn(async move {
             loop {
+                let start = Instant::now();
+                println!("{}", client_pool_receiver.len());
                 tokio::select! {
                     Ok((stream, _)) = server.accept() => {
                         // accept new connections
@@ -149,7 +157,12 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                             Response::Event(event) => {
                                 // we received a message from the client pool, store it locally
                                 // and re-broadcast it.
-                                let _ = this.broadcast(event.deref()).await;
+                                tokio::spawn(Self::broadcast(
+                                    this.storage.clone(),
+                                    this.subscription_manager.clone(),
+                                    this.connections.clone(),
+                                    event.event
+                                ));
                             }
                             Response::EndOfStoredEvents(sub) => {
                                 let connections = this.connections.read().await;
@@ -163,23 +176,23 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                                     continue
                                 };
 
-                                let _ = connection.send(Response::EndOfStoredEvents(sub_id.into()));
+                                let _ = connection.respond(Response::EndOfStoredEvents(sub_id.into()));
+                                let duration = start.elapsed();
+                                println!("xTime elapsed: {} ms", duration.as_millis());
                             }
                             _ => {}
                         }
                     }
                     Some((conn_id, request)) = receiver.recv() => {
-                        // receive messages from our clients
-                        let connections = this.connections.read().await;
-                        let connection = if let Some(connection) = connections.get(&conn_id) {
-                            connection
-                        } else {
-                            continue;
-                        };
-
-                        // receive messages from clients
-                        let _ = this.process_request_from_client(connection, request).await;
-                        drop(connections);
+                        tokio::spawn(Self::process_request(
+                            this.storage.clone(),
+                            this.client_pool.clone(),
+                            this.client_pool_subscriptions.clone(),
+                            this.subscription_manager.clone(),
+                            this.connections.clone(),
+                            conn_id,
+                            request.clone()
+                        ));
                     }
                     else => {
                     }
@@ -242,17 +255,50 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
         Ok(id)
     }
 
-    /// Process a request from a connected client
+    #[cfg(test)]
     async fn process_request_from_client(
         &self,
-        connection: &Connection,
+        connection: &LocalConnection<T>,
+        request: Request,
+    ) -> Result<(), Error> {
+        Self::process_request(
+            self.storage.clone(),
+            self.client_pool.clone(),
+            self.client_pool_subscriptions.clone(),
+            self.subscription_manager.clone(),
+            self.connections.clone(),
+            connection.conn_id,
+            request,
+        )
+        .await
+    }
+
+    /// Process a request from a connected client
+    async fn process_request(
+        storage: Arc<Option<T>>,
+        client_pool: Option<Arc<Pool>>,
+        client_pool_subscriptions: ClientPoolSubscriptions,
+        subscription_manager: SubscriptionManager,
+        connections: Connections,
+        connection_id: ConnectionId,
         request: Request,
     ) -> Result<(), Error> {
         match request {
             Request::Event(event) => {
+                let read_connections = connections.read().await;
+                let connection = read_connections
+                    .get(&connection_id)
+                    .ok_or(Error::UnknownConnection(connection_id))?;
                 let event_id: Addr = event.id.clone().into();
-                if !self.broadcast(&event).await? {
-                    connection.send(
+                if !Self::broadcast(
+                    storage.clone(),
+                    subscription_manager.clone(),
+                    connections.clone(),
+                    event.deref().clone(),
+                )
+                .await?
+                {
+                    connection.respond(
                         ROk {
                             id: event_id,
                             status: ROkStatus::Duplicate,
@@ -262,17 +308,17 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                     return Ok(());
                 }
 
-                if let Some(storage) = self.storage.as_ref() {
+                if let Some(storage) = storage.as_ref() {
                     let _ = storage.store_local_event(&event).await;
                 }
 
-                if let Some(client_pool) = self.client_pool.as_ref() {
+                if let Some(client_pool) = client_pool.as_ref() {
                     // pass the event to the pool of clients, so this relayer can relay
                     // their local events to the clients in the network of relayers
                     let _ = client_pool.post(event).await;
                 }
 
-                connection.send(
+                connection.respond(
                     ROk {
                         id: event_id,
                         status: ROkStatus::Ok,
@@ -281,7 +327,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                 )?;
             }
             Request::Request(request) => {
-                let foreign_subscription = if let Some(client_pool) = self.client_pool.as_ref() {
+                let foreign_subscription = if let Some(client_pool) = client_pool.as_ref() {
                     // If this relay is connected to other relays through the
                     // client pool, create the same subscription in them as
                     // well, with the main goal of fetching any foreign event
@@ -295,9 +341,9 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                         .subscribe(request.filters.clone().into())
                         .await?;
 
-                    self.client_pool_subscriptions.write().await.insert(
+                    client_pool_subscriptions.write().await.insert(
                         foreign_sub_id.clone(),
-                        (request.subscription_id.clone(), connection.get_conn_id()),
+                        (request.subscription_id.clone(), connection_id),
                     );
 
                     Some(foreign_sub_id)
@@ -305,7 +351,12 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                     None
                 };
 
-                if let Some(storage) = self.storage.as_ref() {
+                let read_connections = connections.read().await;
+                let connection = read_connections
+                    .get(&connection_id)
+                    .ok_or(Error::UnknownConnection(connection_id))?;
+
+                if let Some(storage) = storage.as_ref() {
                     let mut sent = HashSet::new();
                     // Sent all events that match the filter that are stored in our database
                     for filter in request.filters.clone().into_iter() {
@@ -316,7 +367,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                                 continue;
                             }
                             sent.insert(event.id.clone());
-                            let _ = connection.send(
+                            let _ = connection.respond(
                                 relayer::Event {
                                     subscription_id: request.subscription_id.clone(),
                                     event,
@@ -330,8 +381,9 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                 if foreign_subscription.is_none() {
                     // If there is a foreign subscription, we shouldn't send a
                     // EOS until we have got EOS from all foreign relays
-                    let _ = connection
-                        .send(relayer::EndOfStoredEvents(request.subscription_id.clone()).into());
+                    let _ = connection.respond(
+                        relayer::EndOfStoredEvents(request.subscription_id.clone()).into(),
+                    );
                 }
 
                 connection
@@ -339,7 +391,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                         request.subscription_id.clone(),
                         (
                             foreign_subscription,
-                            self.subscription_manager
+                            subscription_manager
                                 .subscribe(
                                     (request.subscription_id, connection.get_conn_id()).into(),
                                     request.filters,
@@ -351,7 +403,13 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                     .await;
             }
             Request::Close(close) => {
-                connection.unsubscribe(&close).await;
+                connections
+                    .read()
+                    .await
+                    .get(&connection_id)
+                    .ok_or(Error::UnknownConnection(connection_id))?
+                    .unsubscribe(&close)
+                    .await;
             }
         };
 
@@ -360,51 +418,24 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
 
     #[inline]
     /// A non-blocking version of broadcast
-    #[allow(dead_code)]
-    fn broadcast_and_forget(&self, event: Event) {
-        let storage = self.storage.clone();
-        let connections = self.connections.clone();
-        let subscription_manager = self.subscription_manager.clone();
-
-        tokio::spawn(async move {
-            if let Some(storage) = storage.as_ref() {
-                if !storage.store(&event).await.unwrap_or_default() {
-                    return;
-                }
-            }
-
-            let connections = connections.read().await;
-            for RelayerSubscriptionId((sub_id, conn_id)) in
-                subscription_manager.get_subscribers(&event).await
-            {
-                if let Some(connection) = connections.get(&conn_id) {
-                    let _ = connection.send(
-                        relayer::Event {
-                            subscription_id: sub_id,
-                            event: event.clone(),
-                        }
-                        .into(),
-                    );
-                }
-            }
-        });
-    }
-
-    #[inline]
-    /// Broadcast a given event to all local subscribers
-    pub async fn broadcast(&self, event: &Event) -> Result<bool, Error> {
-        if let Some(storage) = self.storage.as_ref() {
-            if !storage.store(event).await? {
+    pub async fn broadcast(
+        storage: Arc<Option<T>>,
+        subscription_manager: SubscriptionManager,
+        connections: Connections,
+        event: Event,
+    ) -> Result<bool, Error> {
+        if let Some(storage) = storage.as_ref() {
+            if !storage.store(&event).await? {
                 return Ok(false);
             }
         }
 
-        let connections = self.connections.read().await;
+        let connections = connections.read().await;
         for RelayerSubscriptionId((sub_id, conn_id)) in
-            self.subscription_manager.get_subscribers(event).await
+            subscription_manager.get_subscribers(&event).await
         {
             if let Some(connection) = connections.get(&conn_id) {
-                let _ = connection.send(
+                let _ = connection.respond(
                     relayer::Event {
                         subscription_id: sub_id,
                         event: event.clone(),
@@ -413,14 +444,13 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                 );
             }
         }
+
         Ok(true)
     }
 }
 
 #[cfg(test)]
 mod test {
-    use std::time::Duration;
-
     use super::*;
     use futures::future::join_all;
     use nostr_rs_client::Url;
@@ -431,6 +461,7 @@ mod test {
         Request,
     };
     use serde_json::json;
+    use std::time::Duration;
     use tokio::time::sleep;
 
     async fn dummy_server(port: u16, client_pool: Option<Pool>) -> (Url, JoinHandle<()>) {
@@ -517,8 +548,10 @@ mod test {
           },
         ]))
         .expect("valid object");
-        let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_local_connection();
+        let relayer =
+            Arc::new(Relayer::new(Some(get_db(true).await), None).expect("valid relayer"));
+
+        let mut connection = relayer.create_new_local_connection().await;
 
         let note = get_note_with_custom_tags(json!([["f", "foo"]]));
 
@@ -535,7 +568,8 @@ mod test {
         // ev1
         assert_eq!(
             ROkStatus::Ok,
-            recv.try_recv()
+            connection
+                .try_recv()
                 .expect("valid")
                 .as_ok()
                 .cloned()
@@ -546,17 +580,22 @@ mod test {
         // ev1
         assert_eq!(
             note,
-            recv.try_recv().expect("valid").as_event().unwrap().event
+            connection
+                .try_recv()
+                .expect("valid")
+                .as_event()
+                .unwrap()
+                .event
         );
 
         // eod
-        assert!(recv
+        assert!(connection
             .try_recv()
             .expect("valid")
             .as_end_of_stored_events()
             .is_some());
 
-        assert!(recv.try_recv().is_err());
+        assert!(connection.try_recv().is_none());
     }
 
     #[tokio::test]
@@ -613,15 +652,17 @@ mod test {
           }
         ]))
         .expect("valid object");
-        let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_local_connection();
+        let relayer =
+            Arc::new(Relayer::new(Some(get_db(true).await), None).expect("valid relayer"));
+        let mut connection = relayer.create_new_local_connection().await;
         let _ = relayer
             .process_request_from_client(&connection, request)
             .await;
         // ev1
         assert_eq!(
             "9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42",
-            recv.try_recv()
+            connection
+                .try_recv()
                 .expect("valid")
                 .as_event()
                 .expect("event")
@@ -631,7 +672,8 @@ mod test {
         // ev3
         assert_eq!(
             "e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9",
-            recv.try_recv()
+            connection
+                .try_recv()
                 .expect("valid")
                 .as_event()
                 .expect("event")
@@ -641,7 +683,8 @@ mod test {
         // ev2
         assert_eq!(
             "2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a",
-            recv.try_recv()
+            connection
+                .try_recv()
                 .expect("valid")
                 .as_event()
                 .expect("event")
@@ -650,13 +693,13 @@ mod test {
         );
 
         // eod
-        assert!(recv
+        assert!(connection
             .try_recv()
             .expect("valid")
             .as_end_of_stored_events()
             .is_some());
 
-        assert!(recv.try_recv().is_err());
+        assert!(connection.try_recv().is_none());
     }
 
     #[tokio::test]
@@ -949,8 +992,9 @@ mod test {
 
     #[tokio::test]
     async fn posting_event_replies_ok() {
-        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_local_connection();
+        let relayer =
+            Arc::new(Relayer::new(Some(get_db(false).await), None).expect("valid relayer"));
+        let mut connection = relayer.create_new_local_connection().await;
 
         let note = get_note();
         let note_id = note.as_event().map(|x| x.id.clone()).unwrap();
@@ -970,7 +1014,7 @@ mod test {
                 }
                 .into()
             ),
-            recv.try_recv().ok()
+            connection.try_recv()
         );
     }