Browse Source

Added tests for relayer + subscribe pool

Cesar Rodas 3 tháng trước cách đây
mục cha
commit
bbe2047601

+ 5 - 1
crates/client/src/lib.rs

@@ -14,4 +14,8 @@ mod pool;
 
 pub use url::Url;
 
-pub use self::{client::Client, error::Error, pool::Pool};
+pub use self::{
+    client::Client,
+    error::Error,
+    pool::{Pool, PoolSubscription},
+};

+ 1 - 0
crates/client/src/pool.rs

@@ -43,6 +43,7 @@ impl Default for Pool {
 }
 
 /// Return a subscription that will be removed when dropped
+#[derive(Debug)]
 pub struct PoolSubscription {
     subscription_id: SubscriptionId,
     subscriptions: Subscriptions,

+ 12 - 5
crates/relayer/src/connection.rs

@@ -1,5 +1,6 @@
 use crate::{subscription::ActiveSubscription, Error};
 use futures_util::{SinkExt, StreamExt};
+use nostr_rs_client::PoolSubscription;
 use nostr_rs_types::{
     relayer::ROk,
     types::{Addr, SubscriptionId},
@@ -54,7 +55,8 @@ impl ConnectionId {
 pub struct Connection {
     conn_id: ConnectionId,
     sender: Sender<Response>,
-    subscriptions: RwLock<HashMap<SubscriptionId, Vec<ActiveSubscription>>>,
+    subscriptions:
+        RwLock<HashMap<SubscriptionId, (Option<PoolSubscription>, Vec<ActiveSubscription>)>>,
     handler: Option<JoinHandle<()>>,
 }
 
@@ -76,7 +78,7 @@ impl Connection {
             Self {
                 conn_id: ConnectionId::default(),
                 sender,
-                subscriptions: RwLock::new(HashMap::new()),
+                subscriptions: Default::default(),
                 handler: None,
             },
             receiver,
@@ -95,7 +97,7 @@ impl Connection {
         Ok(Self {
             conn_id,
             sender,
-            subscriptions: RwLock::new(HashMap::new()),
+            subscriptions: Default::default(),
             handler: Some(Self::spawn(
                 send_message_to_relayer,
                 websocket,
@@ -187,11 +189,16 @@ impl Connection {
     }
 
     /// Create a subscription for this connection
-    pub async fn keep_track_subscription(
+    pub async fn subscribe(
         &self,
         id: SubscriptionId,
-        subscriptions: Vec<ActiveSubscription>,
+        subscriptions: (Option<PoolSubscription>, Vec<ActiveSubscription>),
     ) {
         self.subscriptions.write().await.insert(id, subscriptions);
     }
+
+    /// Remove a subscription for this connection
+    pub async fn unsubscribe(&self, id: &SubscriptionId) {
+        self.subscriptions.write().await.remove(id);
+    }
 }

+ 174 - 14
crates/relayer/src/relayer.rs

@@ -122,6 +122,9 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
         }))
     }
 
+    /// Handle the client pool
+    ///
+    /// Main loop to consume messages from the client pool and broadcast them to the local subscribers
     fn handle_client_pool(
         client_pool: Pool,
         sender: Sender<(ConnectionId, Request)>,
@@ -140,6 +143,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                                 ))
                                 .await;
                         }
+                        Response::EndOfStoredEvents(_) => {}
                         x => {
                             println!("x => {:?}", x);
                         }
@@ -195,11 +199,18 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                 }
             }
             Request::Request(request) => {
-                if let Some((client_pool, _)) = self.client_pool.as_ref() {
+                let foreign_subscription = if let Some((client_pool, _)) = self.client_pool.as_ref()
+                {
                     // pass the subscription request to the pool of clients, so this relayer
                     // can relay any unknown event to the clients through their subscriptions
-                    let _ = client_pool.subscribe(request.filters.clone().into()).await;
-                }
+                    Some(
+                        client_pool
+                            .subscribe(request.filters.clone().into())
+                            .await?,
+                    )
+                } else {
+                    None
+                };
 
                 if let Some(storage) = self.storage.as_ref() {
                     // Sent all events that match the filter that are stored in our database
@@ -222,20 +233,23 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                     .send(relayer::EndOfStoredEvents(request.subscription_id.clone()).into());
 
                 connection
-                    .keep_track_subscription(
+                    .subscribe(
                         request.subscription_id.clone(),
-                        self.subscriptions
-                            .subscribe(
-                                connection.get_conn_id(),
-                                connection.get_sender(),
-                                request.clone(),
-                            )
-                            .await,
+                        (
+                            foreign_subscription,
+                            self.subscriptions
+                                .subscribe(
+                                    connection.get_conn_id(),
+                                    connection.get_sender(),
+                                    request.clone(),
+                                )
+                                .await,
+                        ),
                     )
                     .await;
             }
-            Request::Close(_close) => {
-                todo!()
+            Request::Close(close) => {
+                connection.unsubscribe(&*close).await;
             }
         };
 
@@ -259,11 +273,27 @@ mod test {
 
     use super::*;
     use futures::future::join_all;
+    use nostr_rs_client::Url;
     use nostr_rs_memory::Memory;
-    use nostr_rs_types::Request;
+    use nostr_rs_types::{account::Account, types::Content, Request};
     use serde_json::json;
     use tokio::time::sleep;
 
+    async fn dummy_server(port: u16, client_pool: Option<Pool>) -> (Url, JoinHandle<()>) {
+        let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
+            .await
+            .unwrap();
+        let local_addr = listener.local_addr().expect("addr");
+
+        let relayer =
+            Relayer::new(Some(Memory::default()), client_pool).expect("valid dummy server");
+        let stopper = relayer.main(listener).expect("valid main loop");
+        (
+            Url::parse(&format!("ws://{}", local_addr)).expect("valid url"),
+            stopper,
+        )
+    }
+
     fn get_note() -> Request {
         serde_json::from_value(json!(
             [
@@ -732,4 +762,134 @@ mod test {
 
         assert_eq!(relayer.total_subscribers(), 0);
     }
+
+    #[tokio::test]
+    async fn relayer_posts_to_custom_posts_to_all_clients() {
+        let (relayer1, _) = dummy_server(0, None).await;
+        let (relayer2, _) = dummy_server(0, None).await;
+        let (relayer3, _) = dummy_server(0, None).await;
+        let (main_relayer, _) = dummy_server(
+            0,
+            Some(Pool::new_with_clients(vec![
+                relayer1.clone(),
+                relayer2.clone(),
+                relayer3.clone(),
+            ])),
+        )
+        .await;
+
+        let mut reader_client =
+            Pool::new_with_clients(vec![relayer1.clone(), relayer2.clone(), relayer3.clone()]);
+        let main_client = Pool::new_with_clients(vec![main_relayer]);
+
+        let _sub = reader_client
+            .subscribe(Default::default())
+            .await
+            .expect("valid subscription");
+
+        sleep(Duration::from_millis(20)).await;
+
+        for _ in 0..3 {
+            assert!(reader_client
+                .try_recv()
+                .map(|(r, _)| r)
+                .expect("valid message")
+                .as_end_of_stored_events()
+                .is_some());
+        }
+        assert!(reader_client.try_recv().is_none());
+
+        let account1 = Account::default();
+        let signed_content = account1
+            .sign_content(vec![], Content::ShortTextNote("test 0".to_owned()), None)
+            .expect("valid signed content");
+
+        // account1 posts a new note into the relayer1, and the main relayer
+        // should get a copy of it, as well as it is connected to relayer2 and
+        // relayer1.
+        main_client.post(signed_content.clone().into()).await;
+
+        sleep(Duration::from_millis(10)).await;
+
+        let responses = (0..3)
+            .map(|_| reader_client.try_recv().expect("valid message"))
+            .filter_map(|(r, url)| {
+                r.as_event()
+                    .map(|r| (url.port().expect("port"), r.to_owned()))
+            })
+            .collect::<HashMap<_, _>>();
+
+        assert!(reader_client.try_recv().is_none());
+
+        assert_eq!(responses.len(), 3);
+        assert_eq!(
+            responses
+                .get(&relayer1.port().expect("port"))
+                .map(|x| x.event.id.clone()),
+            Some(signed_content.id.clone())
+        );
+        assert_eq!(
+            responses
+                .get(&relayer2.port().expect("port"))
+                .map(|x| x.event.id.clone()),
+            Some(signed_content.id.clone())
+        );
+        assert_eq!(
+            responses
+                .get(&relayer3.port().expect("port"))
+                .map(|x| x.event.id.clone()),
+            Some(signed_content.id)
+        );
+    }
+
+    #[tokio::test]
+    async fn relayer_with_client_pool() {
+        let (relayer1, _) = dummy_server(0, None).await;
+        let (relayer2, _) = dummy_server(0, None).await;
+        let (main_relayer, _) = dummy_server(
+            0,
+            Some(Pool::new_with_clients(vec![relayer1.clone(), relayer2])),
+        )
+        .await;
+
+        let secondary_client = Pool::new_with_clients(vec![relayer1]);
+
+        // Create a subscription in the main relayer, main_client is only
+        // connected to the main relayer
+        let mut main_client = Pool::new_with_clients(vec![main_relayer]);
+        let _sub = main_client
+            .subscribe(Default::default())
+            .await
+            .expect("valid subscription");
+
+        sleep(Duration::from_millis(10)).await;
+        assert!(main_client
+            .try_recv()
+            .map(|(r, _)| r)
+            .expect("valid message")
+            .as_end_of_stored_events()
+            .is_some());
+        assert!(main_client.try_recv().is_none());
+
+        let account1 = Account::default();
+        let signed_content = account1
+            .sign_content(vec![], Content::ShortTextNote("test 0".to_owned()), None)
+            .expect("valid signed content");
+
+        // account1 posts a new note into the relayer1, and the main relayer
+        // should get a copy of it, as well as it is connected to relayer2 and
+        // relayer1.
+        secondary_client.post(signed_content.clone().into()).await;
+
+        // wait for the note to be delivered
+        sleep(Duration::from_millis(10)).await;
+        assert_eq!(
+            Some((signed_content.id, signed_content.signature)),
+            main_client
+                .try_recv()
+                .and_then(|(r, _)| r.as_event().cloned().map(|x| x.event))
+                .map(|x| (x.id, x.signature))
+        );
+        assert!(main_client.try_recv().is_none());
+    }
 }