Ver código fonte

Merge branch 'client-pool-non-mutable-self' of cesar/nostr-prototype into main

Cesar Rodas 1 ano atrás
pai
commit
bf4a4a4a2f
5 arquivos alterados com 49 adições e 41 exclusões
  1. 0 1
      Cargo.lock
  2. 0 1
      crates/client/Cargo.toml
  3. 4 4
      crates/client/src/client.rs
  4. 42 28
      crates/client/src/pool.rs
  5. 3 7
      src/main.rs

+ 0 - 1
Cargo.lock

@@ -901,7 +901,6 @@ dependencies = [
  "log",
  "nostr-rs-relayer",
  "nostr-rs-types",
- "parking_lot",
  "serde_json",
  "thiserror",
  "tokio",

+ 0 - 1
crates/client/Cargo.toml

@@ -15,7 +15,6 @@ tokio-tungstenite = { version = "0.18.0", features = [
 url = "2.3.1"
 serde_json = "1.0.94"
 futures-util = "0.3.27"
-parking_lot = "0.12.1"
 log = "0.4.17"
 futures = "0.3.28"
 

+ 4 - 4
crates/client/src/client.rs

@@ -68,14 +68,14 @@ impl Drop for Client {
 
 impl Client {
     /// Creates a new relayer
-    pub fn new(broadcast_to_listeners: mpsc::Sender<(Response, Url)>, url: Url) -> Self {
+    pub fn new(send_message_to_listener: mpsc::Sender<(Response, Url)>, url: Url) -> Self {
         let (sender_to_socket, send_to_socket) = mpsc::channel(100_000);
         let is_connected = Arc::new(AtomicBool::new(false));
 
         let subscriptions = Arc::new(RwLock::new(HashMap::new()));
 
         let worker = Self::spawn_background_client(
-            broadcast_to_listeners,
+            send_message_to_listener,
             send_to_socket,
             url.clone(),
             is_connected.clone(),
@@ -92,7 +92,7 @@ impl Client {
     }
 
     fn spawn_background_client(
-        broadcast_to_listeners: mpsc::Sender<(Response, Url)>,
+        send_message_to_listener: mpsc::Sender<(Response, Url)>,
         mut send_to_socket: mpsc::Receiver<Request>,
         url: Url,
         is_connected: Arc<AtomicBool>,
@@ -175,7 +175,7 @@ impl Client {
                             let msg: Result<Response, _> = serde_json::from_str(&msg);
 
                             if let Ok(msg) = msg {
-                                if let Err(error) = broadcast_to_listeners.try_send((msg.into(), url.clone())) {
+                                if let Err(error) = send_message_to_listener.try_send((msg.into(), url.clone())) {
                                     log::error!("{}: Reconnecting client because of {}", url, error);
                                     break;
                                 }

+ 42 - 28
crates/client/src/pool.rs

@@ -9,7 +9,7 @@ use nostr_rs_types::{
     Response,
 };
 use std::collections::HashMap;
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, RwLock};
 use url::Url;
 
 /// Clients
@@ -19,10 +19,10 @@ use url::Url;
 /// time, and to receive messages
 #[derive(Debug)]
 pub struct Pool {
-    clients: HashMap<Url, Client>,
+    clients: RwLock<HashMap<Url, Client>>,
     sender: mpsc::Sender<(Response, Url)>,
     receiver: mpsc::Receiver<(Response, Url)>,
-    subscriptions: HashMap<SubscriptionId, Vec<ActiveSubscription>>,
+    subscriptions: RwLock<HashMap<SubscriptionId, Vec<ActiveSubscription>>>,
 }
 
 impl Default for Pool {
@@ -38,9 +38,25 @@ impl Pool {
     pub fn new() -> Self {
         let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
         Self {
-            clients: HashMap::new(),
+            clients: Default::default(),
+            subscriptions: Default::default(),
             receiver,
+            sender,
+        }
+    }
+
+    /// Creates a new instance with a list of urls
+    pub fn new_with_clients(clients: Vec<Url>) -> Self {
+        let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
+        let clients = clients
+            .into_iter()
+            .map(|url| (url.clone(), Client::new(sender.clone(), url)))
+            .collect::<HashMap<_, _>>();
+
+        Self {
+            clients: RwLock::new(clients),
             subscriptions: Default::default(),
+            receiver,
             sender,
         }
     }
@@ -56,14 +72,15 @@ impl Pool {
     }
 
     /// Subscribes to all the connected relayers
-    pub async fn subscribe(&mut self, subscription: subscribe::Subscribe) -> Result<(), Error> {
-        let wait_all = self
-            .clients
+    pub async fn subscribe(&self, subscription: subscribe::Subscribe) -> Result<(), Error> {
+        let clients = self.clients.read().await;
+
+        let wait_all = clients
             .values()
             .map(|sender| sender.subscribe(subscription.clone()))
             .collect::<Vec<_>>();
 
-        self.subscriptions.insert(
+        self.subscriptions.write().await.insert(
             subscription.subscription_id,
             join_all(wait_all)
                 .await
@@ -76,25 +93,23 @@ impl Pool {
 
     /// Sends a request to all the connected relayers
     pub async fn post(&self, request: client::Event) {
-        let wait_all = self
-            .clients
-            .values()
-            .map(|sender| sender.post(request.clone()))
-            .collect::<Vec<_>>();
-
-        join_all(wait_all).await;
-    }
-
-    /// Returns a vector to all outgoing connections
-    pub fn get_connections(&self) -> Vec<&Client> {
-        self.clients.values().collect::<Vec<&Client>>()
+        let clients = self.clients.read().await;
+        join_all(
+            clients
+                .values()
+                .map(|sender| sender.post(request.clone()))
+                .collect::<Vec<_>>(),
+        )
+        .await;
     }
 
     /// Returns the number of active connections.
-    pub fn check_active_connections(&self) -> usize {
+    pub async fn check_active_connections(&self) -> usize {
         self.clients
+            .read()
+            .await
             .iter()
-            .filter(|(_, relayer)| relayer.is_connected())
+            .filter(|(_, client)| client.is_connected())
             .collect::<Vec<_>>()
             .len()
     }
@@ -103,13 +118,12 @@ impl Pool {
     ///
     /// This function will open a connection at most once, if a connection
     /// already exists false will be returned
-    pub fn connect_to(mut self, url: Url) -> Self {
-        if !self.clients.contains_key(&url) {
+    pub async fn connect_to(&self, url: Url) {
+        let mut clients = self.clients.write().await;
+
+        if !clients.contains_key(&url) {
             log::warn!("Connecting to {}", url);
-            self.clients
-                .insert(url.clone(), Client::new(self.sender.clone(), url));
+            clients.insert(url.clone(), Client::new(self.sender.clone(), url));
         }
-
-        self
     }
 }

+ 3 - 7
src/main.rs

@@ -1,6 +1,7 @@
 use futures::Future;
+use nostr_rs_client::Pool;
 use nostr_rs_rocksdb::RocksDb;
-use nostr_rs_types::{types::Filter, Request, Response};
+use nostr_rs_types::{relayer, types::Filter, Request, Response};
 use std::{collections::HashMap, env, fs, pin::Pin, sync::Arc};
 use tokio::{
     net::TcpListener,
@@ -63,12 +64,7 @@ async fn main() {
     println!("{:#?}", config);
 
     let db = RocksDb::new(&config.db_path).expect("db");
-    let mut client_pool = config
-        .relayers
-        .iter()
-        .fold(nostr_rs_client::Pool::new(), |clients, relayer_url| {
-            clients.connect_to(relayer_url.clone())
-        });
+    let mut client_pool = Pool::new_with_clients(config.relayers);
 
     let initial_subscription = vec![
         Filter {