|
@@ -9,7 +9,7 @@ use nostr_rs_types::{
|
|
|
Response,
|
|
|
};
|
|
|
use std::collections::HashMap;
|
|
|
-use tokio::sync::mpsc;
|
|
|
+use tokio::sync::{mpsc, RwLock};
|
|
|
use url::Url;
|
|
|
|
|
|
/// Clients
|
|
@@ -19,10 +19,10 @@ use url::Url;
|
|
|
/// time, and to receive messages
|
|
|
#[derive(Debug)]
|
|
|
pub struct Pool {
|
|
|
- clients: HashMap<Url, Client>,
|
|
|
+ clients: RwLock<HashMap<Url, Client>>,
|
|
|
sender: mpsc::Sender<(Response, Url)>,
|
|
|
receiver: mpsc::Receiver<(Response, Url)>,
|
|
|
- subscriptions: HashMap<SubscriptionId, Vec<ActiveSubscription>>,
|
|
|
+ subscriptions: RwLock<HashMap<SubscriptionId, Vec<ActiveSubscription>>>,
|
|
|
}
|
|
|
|
|
|
impl Default for Pool {
|
|
@@ -38,9 +38,25 @@ impl Pool {
|
|
|
pub fn new() -> Self {
|
|
|
let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
|
|
|
Self {
|
|
|
- clients: HashMap::new(),
|
|
|
+ clients: Default::default(),
|
|
|
+ subscriptions: Default::default(),
|
|
|
receiver,
|
|
|
+ sender,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Creates a new instance with a list of urls
|
|
|
+ pub fn new_with_clients(clients: Vec<Url>) -> Self {
|
|
|
+ let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
|
|
|
+ let clients = clients
|
|
|
+ .into_iter()
|
|
|
+ .map(|url| (url.clone(), Client::new(sender.clone(), url)))
|
|
|
+ .collect::<HashMap<_, _>>();
|
|
|
+
|
|
|
+ Self {
|
|
|
+ clients: RwLock::new(clients),
|
|
|
subscriptions: Default::default(),
|
|
|
+ receiver,
|
|
|
sender,
|
|
|
}
|
|
|
}
|
|
@@ -56,14 +72,15 @@ impl Pool {
|
|
|
}
|
|
|
|
|
|
/// Subscribes to all the connected relayers
|
|
|
- pub async fn subscribe(&mut self, subscription: subscribe::Subscribe) -> Result<(), Error> {
|
|
|
- let wait_all = self
|
|
|
- .clients
|
|
|
+ pub async fn subscribe(&self, subscription: subscribe::Subscribe) -> Result<(), Error> {
|
|
|
+ let clients = self.clients.read().await;
|
|
|
+
|
|
|
+ let wait_all = clients
|
|
|
.values()
|
|
|
.map(|sender| sender.subscribe(subscription.clone()))
|
|
|
.collect::<Vec<_>>();
|
|
|
|
|
|
- self.subscriptions.insert(
|
|
|
+ self.subscriptions.write().await.insert(
|
|
|
subscription.subscription_id,
|
|
|
join_all(wait_all)
|
|
|
.await
|
|
@@ -76,25 +93,23 @@ impl Pool {
|
|
|
|
|
|
/// Sends a request to all the connected relayers
|
|
|
pub async fn post(&self, request: client::Event) {
|
|
|
- let wait_all = self
|
|
|
- .clients
|
|
|
- .values()
|
|
|
- .map(|sender| sender.post(request.clone()))
|
|
|
- .collect::<Vec<_>>();
|
|
|
-
|
|
|
- join_all(wait_all).await;
|
|
|
- }
|
|
|
-
|
|
|
- /// Returns a vector to all outgoing connections
|
|
|
- pub fn get_connections(&self) -> Vec<&Client> {
|
|
|
- self.clients.values().collect::<Vec<&Client>>()
|
|
|
+ let clients = self.clients.read().await;
|
|
|
+ join_all(
|
|
|
+ clients
|
|
|
+ .values()
|
|
|
+ .map(|sender| sender.post(request.clone()))
|
|
|
+ .collect::<Vec<_>>(),
|
|
|
+ )
|
|
|
+ .await;
|
|
|
}
|
|
|
|
|
|
/// Returns the number of active connections.
|
|
|
- pub fn check_active_connections(&self) -> usize {
|
|
|
+ pub async fn check_active_connections(&self) -> usize {
|
|
|
self.clients
|
|
|
+ .read()
|
|
|
+ .await
|
|
|
.iter()
|
|
|
- .filter(|(_, relayer)| relayer.is_connected())
|
|
|
+ .filter(|(_, client)| client.is_connected())
|
|
|
.collect::<Vec<_>>()
|
|
|
.len()
|
|
|
}
|
|
@@ -103,13 +118,12 @@ impl Pool {
|
|
|
///
|
|
|
/// This function will open a connection at most once, if a connection
|
|
|
/// already exists false will be returned
|
|
|
- pub fn connect_to(mut self, url: Url) -> Self {
|
|
|
- if !self.clients.contains_key(&url) {
|
|
|
+ pub async fn connect_to(&self, url: Url) {
|
|
|
+ let mut clients = self.clients.write().await;
|
|
|
+
|
|
|
+ if !clients.contains_key(&url) {
|
|
|
log::warn!("Connecting to {}", url);
|
|
|
- self.clients
|
|
|
- .insert(url.clone(), Client::new(self.sender.clone(), url));
|
|
|
+ clients.insert(url.clone(), Client::new(self.sender.clone(), url));
|
|
|
}
|
|
|
-
|
|
|
- self
|
|
|
}
|
|
|
}
|