瀏覽代碼

Add a better subscription support in client

Subscriptions are restored on reconnection.

There is a better API to subscribe and to post events.

When ActiveSubscription struct goes out of scope it is automatically desuscribed from relayer
Cesar Rodas 3 月之前
父節點
當前提交
b5e407f9b8
共有 5 個文件被更改,包括 61 次插入248 次删除
  1. 2 2
      crates/client/src/lib.rs
  2. 44 29
      crates/client/src/pool.rs
  3. 0 188
      crates/client/src/relayer.rs
  4. 14 28
      crates/dump/src/main.rs
  5. 1 1
      crates/types/src/types/subscription_id.rs

+ 2 - 2
crates/client/src/lib.rs

@@ -8,12 +8,12 @@
 //!
 //! It will also have reconnection logic built-in internally.
 #![deny(missing_docs, warnings)]
+mod client;
 mod error;
 mod pool;
-mod relayer;
 
 pub use self::{
+    client::Client,
     error::Error,
     pool::{Event, Pool},
-    relayer::Relayer,
 };

+ 44 - 29
crates/client/src/pool.rs

@@ -1,10 +1,14 @@
 //! Relayers
 //!
 //! This is the main entry point to the client library.
-use crate::Relayer;
-use futures::Future;
-use nostr_rs_types::{Request, Response};
-use std::{collections::HashMap, pin::Pin};
+use crate::{client::ActiveSubscription, Client, Error};
+use futures::future::join_all;
+use nostr_rs_types::{
+    client::{self, subscribe},
+    types::SubscriptionId,
+    Response,
+};
+use std::collections::HashMap;
 use tokio::sync::mpsc;
 use url::Url;
 
@@ -15,9 +19,10 @@ use url::Url;
 /// time, and to receive messages
 #[derive(Debug)]
 pub struct Pool {
-    clients: HashMap<Url, Relayer>,
+    clients: HashMap<Url, Client>,
     sender: mpsc::Sender<(Event, Url)>,
     receiver: mpsc::Receiver<(Event, Url)>,
+    subscriptions: HashMap<SubscriptionId, Vec<ActiveSubscription>>,
 }
 
 impl Default for Pool {
@@ -47,6 +52,7 @@ impl Pool {
         Self {
             clients: HashMap::new(),
             receiver,
+            subscriptions: Default::default(),
             sender,
         }
     }
@@ -61,16 +67,39 @@ impl Pool {
         self.receiver.recv().await
     }
 
+    /// Subscribes to all the connected relayers
+    pub async fn subscribe(&mut self, subscription: subscribe::Subscribe) -> Result<(), Error> {
+        let wait_all = self
+            .clients
+            .values()
+            .map(|sender| sender.subscribe(subscription.clone()))
+            .collect::<Vec<_>>();
+
+        self.subscriptions.insert(
+            subscription.subscription_id,
+            join_all(wait_all)
+                .await
+                .into_iter()
+                .collect::<Result<Vec<_>, _>>()?,
+        );
+
+        Ok(())
+    }
+
     /// Sends a request to all the connected relayers
-    pub async fn send(&self, request: Request) {
-        for (_, sender) in self.clients.iter() {
-            let _ = sender.send(request.clone()).await;
-        }
+    pub async fn post(&self, request: client::Event) {
+        let wait_all = self
+            .clients
+            .values()
+            .map(|sender| sender.post(request.clone()))
+            .collect::<Vec<_>>();
+
+        join_all(wait_all).await;
     }
 
     /// Returns a vector to all outgoing connections
-    pub fn get_connections(&self) -> Vec<&Relayer> {
-        self.clients.values().collect::<Vec<&Relayer>>()
+    pub fn get_connections(&self) -> Vec<&Client> {
+        self.clients.values().collect::<Vec<&Client>>()
     }
 
     /// Returns the number of active connections.
@@ -82,29 +111,15 @@ impl Pool {
             .len()
     }
 
-    /// Creates a connection to a new relayer. If the connection is successful a
-    /// Callback will be called, with a list of previously sent requests, and a
-    /// Sender to send new requests to this relayer alone.
-    ///
-    /// The same callback will be called for every reconnection to the same
-    /// relayer, also the callback will be called, giving the chance to re-send
-    /// sent requests to the new connections
+    /// Creates a connection to a new relayer.
     ///
     /// This function will open a connection at most once, if a connection
     /// already exists false will be returned
-    pub fn connect_to<F>(mut self, url: Url, on_connection: Option<F>) -> Self
-    where
-        F: (Fn(&Url, mpsc::Sender<Request>) -> Pin<Box<dyn Future<Output = ()> + Send>>)
-            + Send
-            + Sync
-            + 'static,
-    {
+    pub fn connect_to(mut self, url: Url) -> Self {
         if !self.clients.contains_key(&url) {
             log::warn!("Connecting to {}", url);
-            self.clients.insert(
-                url.clone(),
-                Relayer::new(self.sender.clone(), url, on_connection),
-            );
+            self.clients
+                .insert(url.clone(), Client::new(self.sender.clone(), url));
         }
 
         self

+ 0 - 188
crates/client/src/relayer.rs

@@ -1,188 +0,0 @@
-use crate::{pool::Event, Error};
-use futures::Future;
-use futures_util::{SinkExt, StreamExt};
-use nostr_rs_types::{Request, Response};
-use std::{
-    pin::Pin,
-    sync::{
-        atomic::{AtomicBool, Ordering::Relaxed},
-        Arc,
-    },
-};
-use tokio::{
-    sync::mpsc,
-    task::JoinHandle,
-    time::{sleep, timeout, Duration},
-};
-use tokio_tungstenite::{connect_async, tungstenite::Message};
-use url::Url;
-
-/// Relayer object
-#[derive(Debug)]
-pub struct Relayer {
-    /// URL of the relayer
-    pub url: Url,
-    /// Sender to the relayer. This can be used to send a Requests to this
-    /// relayer
-    pub send_to_socket: mpsc::Sender<Request>,
-
-    worker: JoinHandle<()>,
-
-    is_connected: Arc<AtomicBool>,
-}
-
-const NO_ACTIVITY_TIMEOUT_SECS: u64 = 120;
-
-impl Drop for Relayer {
-    fn drop(&mut self) {
-        self.worker.abort()
-    }
-}
-
-impl Relayer {
-    /// Creates a new relayer
-    pub fn new<F>(
-        broadcast_to_listeners: mpsc::Sender<(Event, Url)>,
-        url: Url,
-        on_connection: Option<F>,
-    ) -> Self
-    where
-        F: (Fn(&Url, mpsc::Sender<Request>) -> Pin<Box<dyn Future<Output = ()> + Send>>)
-            + Send
-            + Sync
-            + 'static,
-    {
-        let (sender_to_socket, send_to_socket) = mpsc::channel(100_000);
-        let is_connected = Arc::new(AtomicBool::new(false));
-        let worker = Self::spawn_background_client(
-            broadcast_to_listeners,
-            sender_to_socket.clone(),
-            send_to_socket,
-            url.clone(),
-            is_connected.clone(),
-            on_connection,
-        );
-
-        Self {
-            url,
-            is_connected,
-            send_to_socket: sender_to_socket,
-            worker,
-        }
-    }
-
-    fn spawn_background_client<F>(
-        broadcast_to_listeners: mpsc::Sender<(Event, Url)>,
-        sender_to_socket: mpsc::Sender<Request>,
-        mut send_to_socket: mpsc::Receiver<Request>,
-        url: Url,
-        is_connected: Arc<AtomicBool>,
-        on_connection: Option<F>,
-    ) -> JoinHandle<()>
-    where
-        F: (Fn(&Url, mpsc::Sender<Request>) -> Pin<Box<dyn Future<Output = ()> + Send>>)
-            + Send
-            + Sync
-            + 'static,
-    {
-        is_connected.store(false, Relaxed);
-
-        tokio::spawn(async move {
-            let mut connection_attempts = 0;
-
-            loop {
-                log::warn!("{}: Connect attempt {}", url, connection_attempts);
-                connection_attempts += 1;
-                let mut socket = match connect_async(url.clone()).await {
-                    Ok(x) => x.0,
-                    Err(err) => {
-                        log::warn!("{}: Failed to connect: {}", url, err);
-                        sleep(Duration::from_secs(5)).await;
-                        continue;
-                    }
-                };
-
-                log::info!("Connected to {}", url);
-                connection_attempts = 0;
-
-                if let Some(on_connection) = &on_connection {
-                    on_connection(&url, sender_to_socket.clone()).await;
-                }
-
-                loop {
-                    tokio::select! {
-                        Some(msg) = send_to_socket.recv() => {
-                            if let Ok(json) = serde_json::to_string(&msg) {
-                                log::info!("{}: Sending {}", url, json);
-                                if let Err(x) = socket.send(Message::Text(json)).await {
-                                    log::error!("{} : Reconnecting due {}", url, x);
-                                    break;
-                                }
-                            }
-                        }
-                        msg = timeout(Duration::from_secs(NO_ACTIVITY_TIMEOUT_SECS), socket.next()) => {
-                            let msg = if let Ok(Some(Ok(msg))) = msg {
-                                is_connected.store(true, Relaxed);
-                                    match msg {
-                                        Message::Text(text) => text,
-                                        Message::Ping(msg) => {
-                                            if let Err(x) = socket.send(Message::Pong(msg)).await {
-                                                log::error!("{} : Reconnecting due error at sending Pong: {:?}", url, x);
-                                                break;
-                                            }
-                                            continue;
-                                        },
-                                        msg => {
-                                            log::error!("Unexpected {:?}", msg);
-                                            continue;
-                                        }
-                                    }
-                                } else {
-                                    log::error!("{} Reconnecting client due of empty recv: {:?}", url, msg);
-                                    break;
-                                };
-
-                            if msg.is_empty() {
-                                continue;
-                            }
-
-                            log::info!("New message: {}", msg);
-
-
-                            let msg: Result<Response, _> = serde_json::from_str(&msg);
-
-                            if let Ok(msg) = msg {
-                                if let Err(error) = broadcast_to_listeners.try_send((Event::Response(msg.into()), url.clone())) {
-                                    log::error!("{}: Reconnecting client because of {}", url, error);
-                                    break;
-                                }
-                            }
-                        }
-                        else => {
-                            log::warn!("{}: else", url);
-                            break;
-                        }
-                    }
-                }
-
-                is_connected.store(false, Relaxed);
-                // Throttle down to not spam the server with reconnections
-                sleep(Duration::from_millis(500)).await;
-            }
-        })
-    }
-
-    /// Checks if the relayer is connected. It is guaranteed that the relayer is
-    /// connected if this method returns true.
-    pub fn is_connected(&self) -> bool {
-        self.is_connected.load(Relaxed)
-    }
-
-    /// Sends a requests to this relayer
-    pub async fn send(&self, request: Request) -> Result<(), Error> {
-        self.send_to_socket
-            .send(request)
-            .await
-            .map_err(|e| Error::Sync(Box::new(e)))
-    }
-}

+ 14 - 28
crates/dump/src/main.rs

@@ -1,8 +1,5 @@
-use futures::Future;
 use nostr_rs_client::{Error as ClientError, Event, Pool};
-use nostr_rs_types::{client::Subscribe, Request, Response};
-use std::pin::Pin;
-use tokio::sync::mpsc;
+use nostr_rs_types::{client::Subscribe, Response};
 
 #[derive(Debug, thiserror::Error)]
 pub enum Error {
@@ -13,34 +10,23 @@ pub enum Error {
     Client(#[from] ClientError),
 }
 
-fn on_connection(
-    host: &str,
-    _socket: mpsc::Sender<Request>,
-) -> Pin<Box<dyn Future<Output = ()> + Send>> {
-    println!("Reconnecting to {}", host);
-    Box::pin(async move {
-        let _ = _socket.send(Subscribe::default().into()).await;
-    })
-}
-
 #[tokio::main]
 async fn main() {
     env_logger::init();
-    let mut clients = Pool::new()
-        .connect_to("wss://relay.damus.io/", Some(on_connection))
-        .expect("valid url")
-        .connect_to("wss://brb.io", Some(on_connection))
-        .expect("valid url")
-        .connect_to("wss://nos.lol", Some(on_connection))
-        .expect("valid url")
-        .connect_to("wss://relay.current.fyi", Some(on_connection))
-        .expect("valid url")
-        .connect_to("wss://eden.nostr.land", Some(on_connection))
-        .expect("valid url")
-        .connect_to("wss://relay.snort.social", Some(on_connection))
-        .expect("valid url");
+    let mut clients = vec![
+        "wss://relay.damus.io/",
+        "wss://brb.io",
+        "wss://nos.lol",
+        "wss://relay.current.fyi",
+        "wss://eden.nostr.land",
+        "wss://relay.snort.social",
+    ]
+    .into_iter()
+    .fold(Pool::new(), |clients, host| {
+        clients.connect_to(host.parse().expect("valid url"))
+    });
 
-    clients.send(Subscribe::default().into()).await;
+    let _ = clients.subscribe(Subscribe::default().into()).await;
 
     loop {
         if let Some((msg, relayed_by)) = clients.recv().await {

+ 1 - 1
crates/types/src/types/subscription_id.rs

@@ -23,7 +23,7 @@ pub enum Error {
 /// The rules are simple, any UTF-8 valid string with fewer than 32 characters
 ///
 /// By default a random ID will be created if needed.
-#[derive(Debug, Clone, PartialEq, Eq)]
+#[derive(Debug, Clone, Hash, PartialEq, Eq)]
 pub struct SubscriptionId(String);
 
 impl Deref for SubscriptionId {