Эх сурвалжийг харах

Improving relayer + client pool integration

Cesar Rodas 3 сар өмнө
parent
commit
ed5dc227d8

+ 1 - 1
Cargo.lock

@@ -943,10 +943,10 @@ version = "0.1.0"
 dependencies = [
  "futures-util",
  "log",
+ "nostr-rs-client",
  "nostr-rs-rocksdb",
  "nostr-rs-storage-base",
  "nostr-rs-types",
- "parking_lot",
  "rand",
  "serde_json",
  "thiserror",

+ 4 - 0
crates/client/src/error.rs

@@ -21,4 +21,8 @@ pub enum Error {
     /// The client has no connection to any relayer
     #[error("There is no connection")]
     Disconnected,
+
+    /// The pool was already splitted
+    #[error("The pool was already splitted")]
+    AlreadySplitted,
 }

+ 10 - 5
crates/client/src/pool.rs

@@ -21,7 +21,7 @@ use url::Url;
 pub struct Pool {
     clients: RwLock<HashMap<Url, Client>>,
     sender: mpsc::Sender<(Response, Url)>,
-    receiver: mpsc::Receiver<(Response, Url)>,
+    receiver: Option<mpsc::Receiver<(Response, Url)>>,
     subscriptions: RwLock<HashMap<SubscriptionId, Vec<ActiveSubscription>>>,
 }
 
@@ -39,8 +39,8 @@ impl Pool {
         let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
         Self {
             clients: Default::default(),
+            receiver: Some(receiver),
             subscriptions: Default::default(),
-            receiver,
             sender,
         }
     }
@@ -56,19 +56,24 @@ impl Pool {
         Self {
             clients: RwLock::new(clients),
             subscriptions: Default::default(),
-            receiver,
+            receiver: Some(receiver),
             sender,
         }
     }
 
+    /// Splits the pool removing the receiver to be used in a different context
+    pub fn split(mut self) -> Result<(mpsc::Receiver<(Response, Url)>, Self), Error> {
+        Ok((self.receiver.take().ok_or(Error::AlreadySplitted)?, self))
+    }
+
     /// Tries to receive a message from any of the connected relayers
     pub fn try_recv(&mut self) -> Option<(Response, Url)> {
-        self.receiver.try_recv().ok()
+        self.receiver.as_mut()?.try_recv().ok()
     }
 
     /// Receives a message from any of the connected relayers
     pub async fn recv(&mut self) -> Option<(Response, Url)> {
-        self.receiver.recv().await
+        self.receiver.as_mut()?.recv().await
     }
 
     /// Subscribes to all the connected relayers

+ 1 - 1
crates/relayer/Cargo.toml

@@ -8,8 +8,8 @@ edition = "2021"
 [dependencies]
 nostr-rs-types = { path = "../types" }
 nostr-rs-storage-base = { path = "../storage/base" }
+nostr-rs-client = { path = "../client" }
 futures-util = "0.3.27"
-parking_lot = "0.12.1"
 tokio = { version = "1.26.0", features = ["sync", "macros", "rt", "time"] }
 tokio-tungstenite = { version = "0.18.0", features = [
     "rustls",

+ 8 - 6
crates/relayer/src/connection.rs

@@ -5,11 +5,13 @@ use nostr_rs_types::{
     types::Addr,
     Request, Response,
 };
-use parking_lot::RwLock;
 use std::collections::HashMap;
 use tokio::{
     net::TcpStream,
-    sync::mpsc::{channel, Receiver, Sender},
+    sync::{
+        mpsc::{channel, Receiver, Sender},
+        RwLock,
+    },
 };
 #[allow(unused_imports)]
 use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
@@ -136,13 +138,13 @@ impl Connection {
         self.sender.clone()
     }
 
-    pub fn get_subscription_id(&self, id: &str) -> Option<u128> {
-        let subscriptions = self.subscriptions.read();
+    pub async fn get_subscription_id(&self, id: &str) -> Option<u128> {
+        let subscriptions = self.subscriptions.read().await;
         subscriptions.get(id).copied()
     }
 
-    pub fn create_subscription(&self, id: String) -> (u128, Sender<Response>) {
-        let mut subscriptions = self.subscriptions.write();
+    pub async fn create_subscription(&self, id: String) -> (u128, Sender<Response>) {
+        let mut subscriptions = self.subscriptions.write().await;
         let internal_id = subscriptions.entry(id).or_insert_with(get_id);
         (*internal_id, self.sender.clone())
     }

+ 3 - 0
crates/relayer/src/error.rs

@@ -19,4 +19,7 @@ pub enum Error {
 
     #[error("TrySendError: {0}")]
     TrySendError(#[from] Box<tokio::sync::mpsc::error::TrySendError<Response>>),
+
+    #[error("Nostr client error: {0}")]
+    Client(#[from] nostr_rs_client::Error),
 }

+ 98 - 55
crates/relayer/src/relayer.rs

@@ -1,18 +1,23 @@
 use crate::{Connection, Error, Subscription};
 use futures_util::StreamExt;
+use nostr_rs_client::{Error as ClientError, Pool};
 use nostr_rs_storage_base::Storage;
 use nostr_rs_types::{
     relayer,
     types::{Event, SubscriptionId},
     Request, Response,
 };
-use parking_lot::{RwLock, RwLockReadGuard};
 use std::{collections::HashMap, ops::Deref, sync::Arc};
-use tokio::sync::mpsc;
-#[allow(unused_imports)]
 use tokio::{
     net::TcpStream,
-    sync::mpsc::{channel, Receiver, Sender},
+    sync::{
+        mpsc::{channel, Receiver, Sender},
+        RwLockReadGuard,
+    },
+};
+use tokio::{
+    sync::{mpsc, RwLock},
+    task::JoinHandle,
 };
 
 type SubId = u128;
@@ -39,23 +44,61 @@ pub struct Relayer<T: Storage> {
     /// fast iteration and match quickly filters.
     subscriptions: RwLock<HashMap<Subscription, RwLock<Subscriptions>>>,
     clients: RwLock<HashMap<u128, Connection>>,
-    #[allow(dead_code)]
     sender: Sender<(u128, Request)>,
+    client_pool: Option<(Pool, JoinHandle<()>)>,
+}
+
+impl<T: Storage> Drop for Relayer<T> {
+    fn drop(&mut self) {
+        if let Some((_, handle)) = self.client_pool.take() {
+            let _ = handle.abort();
+        }
+    }
 }
 
 impl<T: Storage> Relayer<T> {
-    pub fn new(storage: Option<T>) -> (Arc<Self>, Receiver<(u128, Request)>) {
+    pub fn new(
+        storage: Option<T>,
+        client_pool: Option<Pool>,
+    ) -> Result<(Arc<Self>, Receiver<(u128, Request)>), Error> {
         let (sender, receiver) = channel(100_000);
-        (
+        Ok((
             Arc::new(Self {
                 storage,
-                subscriptions: RwLock::new(HashMap::new()),
-                subscriptions_ids_index: RwLock::new(HashMap::new()),
-                clients: RwLock::new(HashMap::new()),
-                sender,
+                sender: sender.clone(),
+                subscriptions: Default::default(),
+                subscriptions_ids_index: Default::default(),
+                clients: Default::default(),
+                client_pool: if let Some(client_pool) = client_pool {
+                    Some(Self::handle_client_pool(client_pool, sender)?)
+                } else {
+                    None
+                },
             }),
             receiver,
-        )
+        ))
+    }
+
+    fn handle_client_pool(
+        client_pool: Pool,
+        sender: Sender<(u128, Request)>,
+    ) -> Result<(Pool, JoinHandle<()>), ClientError> {
+        let (mut receiver, client_pool) = client_pool.split()?;
+
+        let handle = tokio::spawn(async move {
+            loop {
+                if let Some((response, _)) = receiver.recv().await {
+                    match response {
+                        Response::Event(event) => {
+                            let _ = sender.send((0, Request::Event(event.event.into()))).await;
+                        }
+                        _ => {}
+                    }
+                }
+            }
+        });
+
+        Ok((client_pool, handle))
     }
 
     /// Returns a reference to the internal database
@@ -70,7 +113,7 @@ impl<T: Storage> Relayer<T> {
     ) -> Result<u128, Error> {
         let client = Connection::new(self.sender.clone(), disconnection_notify, stream).await?;
         let id = client.conn_id;
-        let mut clients = self.clients.write();
+        let mut clients = self.clients.write().await;
         clients.insert(id, client);
 
         Ok(id)
@@ -87,35 +130,32 @@ impl<T: Storage> Relayer<T> {
             }
             Request::Request(request) => {
                 // Create subscription
-                let (sub_id, receiver) =
-                    connection.create_subscription(request.subscription_id.deref().to_owned());
-                let mut sub_index = self.subscriptions_ids_index.write();
-                let mut subscriptions = self.subscriptions.write();
+                let (sub_id, receiver) = connection
+                    .create_subscription(request.subscription_id.deref().to_owned())
+                    .await;
+                let mut sub_index = self.subscriptions_ids_index.write().await;
+                let mut subscriptions = self.subscriptions.write().await;
                 if let Some(prev_subs) = sub_index.remove(&sub_id) {
                     // remove any previous subscriptions
-                    prev_subs.iter().for_each(|index| {
+                    for index in prev_subs.iter() {
                         if let Some(subscriptions) = subscriptions.get_mut(index) {
-                            subscriptions.write().remove(&sub_id);
+                            subscriptions.write().await.remove(&sub_id);
                         }
-                    });
+                    }
                 }
-                sub_index.insert(
-                    sub_id,
-                    Subscription::from_filters(&request.filters)
-                        .into_iter()
-                        .map(|index| {
-                            subscriptions
-                                .entry(index.clone())
-                                .or_insert_with(|| RwLock::new(HashMap::new()))
-                                .write()
-                                .insert(
-                                    sub_id,
-                                    (request.subscription_id.clone(), receiver.clone()),
-                                );
-                            index
-                        })
-                        .collect::<Vec<_>>(),
-                );
+
+                let mut sub_index_values = vec![];
+                for index in Subscription::from_filters(&request.filters).into_iter() {
+                    subscriptions
+                        .entry(index.clone())
+                        .or_insert_with(|| RwLock::new(HashMap::new()))
+                        .write()
+                        .await
+                        .insert(sub_id, (request.subscription_id.clone(), receiver.clone()));
+                    sub_index_values.push(index);
+                }
+
+                sub_index.insert(sub_id, sub_index_values);
 
                 drop(subscriptions);
                 drop(sub_index);
@@ -141,15 +181,15 @@ impl<T: Storage> Relayer<T> {
                     .send(relayer::EndOfStoredEvents(request.subscription_id.clone()).into());
             }
             Request::Close(close) => {
-                if let Some(id) = connection.get_subscription_id(&close.0) {
-                    let mut subscriptions = self.subscriptions_ids_index.write();
+                if let Some(id) = connection.get_subscription_id(&close.0).await {
+                    let mut subscriptions = self.subscriptions_ids_index.write().await;
                     if let Some(indexes) = subscriptions.remove(&id) {
-                        let mut subscriptions = self.subscriptions.write();
-                        indexes.iter().for_each(|index| {
-                            if let Some(subscriptions) = subscriptions.get_mut(index) {
-                                subscriptions.write().remove(&id);
+                        let mut subscriptions = self.subscriptions.write().await;
+                        for index in indexes {
+                            if let Some(subscriptions) = subscriptions.get_mut(&index) {
+                                subscriptions.write().await.remove(&id);
                             }
-                        });
+                        }
                     }
                 }
             }
@@ -167,7 +207,7 @@ impl<T: Storage> Relayer<T> {
         } else {
             return Ok(None);
         };
-        let connections = self.clients.read();
+        let connections = self.clients.read().await;
         let connection = connections
             .get(&conn_id)
             .ok_or(Error::UnknownConnection(conn_id))?;
@@ -175,8 +215,8 @@ impl<T: Storage> Relayer<T> {
         self.recv_request_from_client(connection, request).await
     }
 
-    pub fn send_to_conn(&self, conn_id: u128, response: Response) -> Result<(), Error> {
-        let connections = self.clients.read();
+    pub async fn send_to_conn(&self, conn_id: u128, response: Response) -> Result<(), Error> {
+        let connections = self.clients.read().await;
         let connection = connections
             .get(&conn_id)
             .ok_or(Error::UnknownConnection(conn_id))?;
@@ -185,7 +225,10 @@ impl<T: Storage> Relayer<T> {
     }
 
     #[inline]
-    fn broadcast_to_subscribers(subscriptions: RwLockReadGuard<Subscriptions>, event: &Event) {
+    fn broadcast_to_subscribers<'a>(
+        subscriptions: RwLockReadGuard<'a, Subscriptions>,
+        event: &Event,
+    ) {
         for (_, receiver) in subscriptions.iter() {
             let _ = receiver.1.try_send(
                 relayer::Event {
@@ -202,25 +245,25 @@ impl<T: Storage> Relayer<T> {
         if let Some(storage) = self.storage.as_ref() {
             let _ = storage.store_local_event(event).await;
         }
-        let subscriptions = self.subscriptions.read();
+        let subscriptions = self.subscriptions.read().await;
 
         for subscription_type in Subscription::from_event(event) {
             if let Some(subscribers) = subscriptions.get(&subscription_type) {
-                Self::broadcast_to_subscribers(subscribers.read(), event);
+                Self::broadcast_to_subscribers(subscribers.read().await, event);
             }
         }
     }
 
     #[inline]
-    pub fn store_and_broadcast(&self, event: &Event) {
+    pub async fn store_and_broadcast(&self, event: &Event) {
         if let Some(storage) = self.storage.as_ref() {
             let _ = storage.store(event);
         }
-        let subscriptions = self.subscriptions.read();
+        let subscriptions = self.subscriptions.read().await;
 
         for subscription_type in Subscription::from_event(event) {
             if let Some(subscribers) = subscriptions.get(&subscription_type) {
-                Self::broadcast_to_subscribers(subscribers.read(), event);
+                Self::broadcast_to_subscribers(subscribers.read().await, event);
             }
         }
     }
@@ -258,7 +301,7 @@ mod test {
                 {\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},
                 {\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}
             ]").expect("valid object");
-        let (relayer, _) = Relayer::new(Some(get_db(true).await));
+        let (relayer, _) = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
         let (connection, mut recv) = Connection::new_for_test();
         let _ = relayer.recv_request_from_client(&connection, request).await;
         // ev1
@@ -307,7 +350,7 @@ mod test {
     #[tokio::test]
     async fn server_listener_real_time() {
         let request: Request = serde_json::from_str("[\"REQ\",\"1298169700973717\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid object");
-        let (relayer, _) = Relayer::new(Some(get_db(false).await));
+        let (relayer, _) = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
         let (connection, mut recv) = Connection::new_for_test();
         let _ = relayer.recv_request_from_client(&connection, request).await;
         // eod