|
@@ -5,7 +5,6 @@ use crate::{
|
|
|
use futures_util::StreamExt;
|
|
|
use nostr_rs_client::{pool::subscription::PoolSubscriptionId, Pool, Url};
|
|
|
use nostr_rs_storage_base::Storage;
|
|
|
-use nostr_rs_subscription_manager::SubscriptionManager;
|
|
|
use nostr_rs_types::{
|
|
|
relayer::{self, ROk, ROkStatus},
|
|
|
types::{Addr, Event, SubscriptionId},
|
|
@@ -15,6 +14,7 @@ use std::{
|
|
|
collections::{HashMap, HashSet},
|
|
|
ops::Deref,
|
|
|
sync::Arc,
|
|
|
+ time::Instant,
|
|
|
};
|
|
|
use tokio::{
|
|
|
net::{TcpListener, TcpStream},
|
|
@@ -42,6 +42,12 @@ impl Default for RelayerSubscriptionId {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+type Connections = Arc<RwLock<HashMap<ConnectionId, Connection>>>;
|
|
|
+type SubscriptionManager =
|
|
|
+ Arc<nostr_rs_subscription_manager::SubscriptionManager<RelayerSubscriptionId, ()>>;
|
|
|
+type ClientPoolSubscriptions =
|
|
|
+ Arc<RwLock<HashMap<PoolSubscriptionId, (SubscriptionId, ConnectionId)>>>;
|
|
|
+
|
|
|
/// Relayer struct
|
|
|
///
|
|
|
pub struct Relayer<T: Storage + Send + Sync + 'static> {
|
|
@@ -51,9 +57,9 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
|
|
|
/// be able to perform any optimization like prefetching content while offline
|
|
|
storage: Arc<Option<T>>,
|
|
|
/// Subscription manager
|
|
|
- subscription_manager: Arc<SubscriptionManager<RelayerSubscriptionId, ()>>,
|
|
|
+ subscription_manager: SubscriptionManager,
|
|
|
/// List of all active connections
|
|
|
- connections: Arc<RwLock<HashMap<ConnectionId, Connection>>>,
|
|
|
+ connections: Connections,
|
|
|
/// This Sender can be used to send requests from anywhere to the relayer.
|
|
|
send_to_relayer: Sender<(ConnectionId, Request)>,
|
|
|
/// This Receiver is the relayer the way the relayer receives messages
|
|
@@ -63,9 +69,9 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
|
|
|
///
|
|
|
/// A relayer can optionally be connected to a pool of clients to get
|
|
|
/// foreign events.
|
|
|
- client_pool: Option<Pool>,
|
|
|
+ client_pool: Option<Arc<Pool>>,
|
|
|
client_pool_receiver: Option<Receiver<(Response, Url)>>,
|
|
|
- client_pool_subscriptions: RwLock<HashMap<PoolSubscriptionId, (SubscriptionId, ConnectionId)>>,
|
|
|
+ client_pool_subscriptions: ClientPoolSubscriptions,
|
|
|
}
|
|
|
|
|
|
impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
@@ -82,7 +88,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
|
|
|
let (client_pool_receiver, client_pool) = if let Some(client_pool) = client_pool {
|
|
|
let result = client_pool.split()?;
|
|
|
- (result.0, Some(result.1))
|
|
|
+ (result.0, Some(Arc::new(result.1)))
|
|
|
} else {
|
|
|
let (_, receiver) = mpsc::channel(1);
|
|
|
(receiver, None)
|
|
@@ -95,7 +101,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
relayer_receiver: Some(relayer_receiver),
|
|
|
connections: Default::default(),
|
|
|
client_pool_receiver: Some(client_pool_receiver),
|
|
|
- client_pool: client_pool,
|
|
|
+ client_pool,
|
|
|
client_pool_subscriptions: Default::default(),
|
|
|
})
|
|
|
}
|
|
@@ -138,6 +144,8 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
|
|
|
let handle = tokio::spawn(async move {
|
|
|
loop {
|
|
|
+ let start = Instant::now();
|
|
|
+ println!("{}", client_pool_receiver.len());
|
|
|
tokio::select! {
|
|
|
Ok((stream, _)) = server.accept() => {
|
|
|
// accept new connections
|
|
@@ -149,7 +157,12 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
Response::Event(event) => {
|
|
|
// we received a message from the client pool, store it locally
|
|
|
// and re-broadcast it.
|
|
|
- let _ = this.broadcast(event.deref()).await;
|
|
|
+ tokio::spawn(Self::broadcast(
|
|
|
+ this.storage.clone(),
|
|
|
+ this.subscription_manager.clone(),
|
|
|
+ this.connections.clone(),
|
|
|
+ event.event
|
|
|
+ ));
|
|
|
}
|
|
|
Response::EndOfStoredEvents(sub) => {
|
|
|
let connections = this.connections.read().await;
|
|
@@ -163,23 +176,23 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
continue
|
|
|
};
|
|
|
|
|
|
- let _ = connection.send(Response::EndOfStoredEvents(sub_id.into()));
|
|
|
+ let _ = connection.respond(Response::EndOfStoredEvents(sub_id.into()));
|
|
|
+ let duration = start.elapsed();
|
|
|
+ println!("xTime elapsed: {} ms", duration.as_millis());
|
|
|
}
|
|
|
_ => {}
|
|
|
}
|
|
|
}
|
|
|
Some((conn_id, request)) = receiver.recv() => {
|
|
|
- // receive messages from our clients
|
|
|
- let connections = this.connections.read().await;
|
|
|
- let connection = if let Some(connection) = connections.get(&conn_id) {
|
|
|
- connection
|
|
|
- } else {
|
|
|
- continue;
|
|
|
- };
|
|
|
-
|
|
|
- // receive messages from clients
|
|
|
- let _ = this.process_request_from_client(connection, request).await;
|
|
|
- drop(connections);
|
|
|
+ tokio::spawn(Self::process_request(
|
|
|
+ this.storage.clone(),
|
|
|
+ this.client_pool.clone(),
|
|
|
+ this.client_pool_subscriptions.clone(),
|
|
|
+ this.subscription_manager.clone(),
|
|
|
+ this.connections.clone(),
|
|
|
+ conn_id,
|
|
|
+ request.clone()
|
|
|
+ ));
|
|
|
}
|
|
|
else => {
|
|
|
}
|
|
@@ -242,17 +255,50 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
Ok(id)
|
|
|
}
|
|
|
|
|
|
- /// Process a request from a connected client
|
|
|
+ #[cfg(test)]
|
|
|
async fn process_request_from_client(
|
|
|
&self,
|
|
|
- connection: &Connection,
|
|
|
+ connection: &LocalConnection<T>,
|
|
|
+ request: Request,
|
|
|
+ ) -> Result<(), Error> {
|
|
|
+ Self::process_request(
|
|
|
+ self.storage.clone(),
|
|
|
+ self.client_pool.clone(),
|
|
|
+ self.client_pool_subscriptions.clone(),
|
|
|
+ self.subscription_manager.clone(),
|
|
|
+ self.connections.clone(),
|
|
|
+ connection.conn_id,
|
|
|
+ request,
|
|
|
+ )
|
|
|
+ .await
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Process a request from a connected client
|
|
|
+ async fn process_request(
|
|
|
+ storage: Arc<Option<T>>,
|
|
|
+ client_pool: Option<Arc<Pool>>,
|
|
|
+ client_pool_subscriptions: ClientPoolSubscriptions,
|
|
|
+ subscription_manager: SubscriptionManager,
|
|
|
+ connections: Connections,
|
|
|
+ connection_id: ConnectionId,
|
|
|
request: Request,
|
|
|
) -> Result<(), Error> {
|
|
|
match request {
|
|
|
Request::Event(event) => {
|
|
|
+ let read_connections = connections.read().await;
|
|
|
+ let connection = read_connections
|
|
|
+ .get(&connection_id)
|
|
|
+ .ok_or(Error::UnknownConnection(connection_id))?;
|
|
|
let event_id: Addr = event.id.clone().into();
|
|
|
- if !self.broadcast(&event).await? {
|
|
|
- connection.send(
|
|
|
+ if !Self::broadcast(
|
|
|
+ storage.clone(),
|
|
|
+ subscription_manager.clone(),
|
|
|
+ connections.clone(),
|
|
|
+ event.deref().clone(),
|
|
|
+ )
|
|
|
+ .await?
|
|
|
+ {
|
|
|
+ connection.respond(
|
|
|
ROk {
|
|
|
id: event_id,
|
|
|
status: ROkStatus::Duplicate,
|
|
@@ -262,17 +308,17 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
return Ok(());
|
|
|
}
|
|
|
|
|
|
- if let Some(storage) = self.storage.as_ref() {
|
|
|
+ if let Some(storage) = storage.as_ref() {
|
|
|
let _ = storage.store_local_event(&event).await;
|
|
|
}
|
|
|
|
|
|
- if let Some(client_pool) = self.client_pool.as_ref() {
|
|
|
+ if let Some(client_pool) = client_pool.as_ref() {
|
|
|
// pass the event to the pool of clients, so this relayer can relay
|
|
|
// their local events to the clients in the network of relayers
|
|
|
let _ = client_pool.post(event).await;
|
|
|
}
|
|
|
|
|
|
- connection.send(
|
|
|
+ connection.respond(
|
|
|
ROk {
|
|
|
id: event_id,
|
|
|
status: ROkStatus::Ok,
|
|
@@ -281,7 +327,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
)?;
|
|
|
}
|
|
|
Request::Request(request) => {
|
|
|
- let foreign_subscription = if let Some(client_pool) = self.client_pool.as_ref() {
|
|
|
+ let foreign_subscription = if let Some(client_pool) = client_pool.as_ref() {
|
|
|
// If this relay is connected to other relays through the
|
|
|
// client pool, create the same subscription in them as
|
|
|
// well, with the main goal of fetching any foreign event
|
|
@@ -295,9 +341,9 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
.subscribe(request.filters.clone().into())
|
|
|
.await?;
|
|
|
|
|
|
- self.client_pool_subscriptions.write().await.insert(
|
|
|
+ client_pool_subscriptions.write().await.insert(
|
|
|
foreign_sub_id.clone(),
|
|
|
- (request.subscription_id.clone(), connection.get_conn_id()),
|
|
|
+ (request.subscription_id.clone(), connection_id),
|
|
|
);
|
|
|
|
|
|
Some(foreign_sub_id)
|
|
@@ -305,7 +351,12 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
None
|
|
|
};
|
|
|
|
|
|
- if let Some(storage) = self.storage.as_ref() {
|
|
|
+ let read_connections = connections.read().await;
|
|
|
+ let connection = read_connections
|
|
|
+ .get(&connection_id)
|
|
|
+ .ok_or(Error::UnknownConnection(connection_id))?;
|
|
|
+
|
|
|
+ if let Some(storage) = storage.as_ref() {
|
|
|
let mut sent = HashSet::new();
|
|
|
// Sent all events that match the filter that are stored in our database
|
|
|
for filter in request.filters.clone().into_iter() {
|
|
@@ -316,7 +367,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
continue;
|
|
|
}
|
|
|
sent.insert(event.id.clone());
|
|
|
- let _ = connection.send(
|
|
|
+ let _ = connection.respond(
|
|
|
relayer::Event {
|
|
|
subscription_id: request.subscription_id.clone(),
|
|
|
event,
|
|
@@ -330,8 +381,9 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
if foreign_subscription.is_none() {
|
|
|
// If there is a foreign subscription, we shouldn't send a
|
|
|
// EOS until we have got EOS from all foreign relays
|
|
|
- let _ = connection
|
|
|
- .send(relayer::EndOfStoredEvents(request.subscription_id.clone()).into());
|
|
|
+ let _ = connection.respond(
|
|
|
+ relayer::EndOfStoredEvents(request.subscription_id.clone()).into(),
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
connection
|
|
@@ -339,7 +391,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
request.subscription_id.clone(),
|
|
|
(
|
|
|
foreign_subscription,
|
|
|
- self.subscription_manager
|
|
|
+ subscription_manager
|
|
|
.subscribe(
|
|
|
(request.subscription_id, connection.get_conn_id()).into(),
|
|
|
request.filters,
|
|
@@ -351,7 +403,13 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
.await;
|
|
|
}
|
|
|
Request::Close(close) => {
|
|
|
- connection.unsubscribe(&close).await;
|
|
|
+ connections
|
|
|
+ .read()
|
|
|
+ .await
|
|
|
+ .get(&connection_id)
|
|
|
+ .ok_or(Error::UnknownConnection(connection_id))?
|
|
|
+ .unsubscribe(&close)
|
|
|
+ .await;
|
|
|
}
|
|
|
};
|
|
|
|
|
@@ -360,51 +418,24 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
|
|
|
#[inline]
|
|
|
/// A non-blocking version of broadcast
|
|
|
- #[allow(dead_code)]
|
|
|
- fn broadcast_and_forget(&self, event: Event) {
|
|
|
- let storage = self.storage.clone();
|
|
|
- let connections = self.connections.clone();
|
|
|
- let subscription_manager = self.subscription_manager.clone();
|
|
|
-
|
|
|
- tokio::spawn(async move {
|
|
|
- if let Some(storage) = storage.as_ref() {
|
|
|
- if !storage.store(&event).await.unwrap_or_default() {
|
|
|
- return;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- let connections = connections.read().await;
|
|
|
- for RelayerSubscriptionId((sub_id, conn_id)) in
|
|
|
- subscription_manager.get_subscribers(&event).await
|
|
|
- {
|
|
|
- if let Some(connection) = connections.get(&conn_id) {
|
|
|
- let _ = connection.send(
|
|
|
- relayer::Event {
|
|
|
- subscription_id: sub_id,
|
|
|
- event: event.clone(),
|
|
|
- }
|
|
|
- .into(),
|
|
|
- );
|
|
|
- }
|
|
|
- }
|
|
|
- });
|
|
|
- }
|
|
|
-
|
|
|
- #[inline]
|
|
|
- /// Broadcast a given event to all local subscribers
|
|
|
- pub async fn broadcast(&self, event: &Event) -> Result<bool, Error> {
|
|
|
- if let Some(storage) = self.storage.as_ref() {
|
|
|
- if !storage.store(event).await? {
|
|
|
+ pub async fn broadcast(
|
|
|
+ storage: Arc<Option<T>>,
|
|
|
+ subscription_manager: SubscriptionManager,
|
|
|
+ connections: Connections,
|
|
|
+ event: Event,
|
|
|
+ ) -> Result<bool, Error> {
|
|
|
+ if let Some(storage) = storage.as_ref() {
|
|
|
+ if !storage.store(&event).await? {
|
|
|
return Ok(false);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- let connections = self.connections.read().await;
|
|
|
+ let connections = connections.read().await;
|
|
|
for RelayerSubscriptionId((sub_id, conn_id)) in
|
|
|
- self.subscription_manager.get_subscribers(event).await
|
|
|
+ subscription_manager.get_subscribers(&event).await
|
|
|
{
|
|
|
if let Some(connection) = connections.get(&conn_id) {
|
|
|
- let _ = connection.send(
|
|
|
+ let _ = connection.respond(
|
|
|
relayer::Event {
|
|
|
subscription_id: sub_id,
|
|
|
event: event.clone(),
|
|
@@ -413,14 +444,13 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
|
|
|
);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
Ok(true)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
- use std::time::Duration;
|
|
|
-
|
|
|
use super::*;
|
|
|
use futures::future::join_all;
|
|
|
use nostr_rs_client::Url;
|
|
@@ -431,6 +461,7 @@ mod test {
|
|
|
Request,
|
|
|
};
|
|
|
use serde_json::json;
|
|
|
+ use std::time::Duration;
|
|
|
use tokio::time::sleep;
|
|
|
|
|
|
async fn dummy_server(port: u16, client_pool: Option<Pool>) -> (Url, JoinHandle<()>) {
|
|
@@ -517,8 +548,10 @@ mod test {
|
|
|
},
|
|
|
]))
|
|
|
.expect("valid object");
|
|
|
- let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
|
|
|
- let (connection, mut recv) = Connection::new_local_connection();
|
|
|
+ let relayer =
|
|
|
+ Arc::new(Relayer::new(Some(get_db(true).await), None).expect("valid relayer"));
|
|
|
+
|
|
|
+ let mut connection = relayer.create_new_local_connection().await;
|
|
|
|
|
|
let note = get_note_with_custom_tags(json!([["f", "foo"]]));
|
|
|
|
|
@@ -535,7 +568,8 @@ mod test {
|
|
|
// ev1
|
|
|
assert_eq!(
|
|
|
ROkStatus::Ok,
|
|
|
- recv.try_recv()
|
|
|
+ connection
|
|
|
+ .try_recv()
|
|
|
.expect("valid")
|
|
|
.as_ok()
|
|
|
.cloned()
|
|
@@ -546,17 +580,22 @@ mod test {
|
|
|
// ev1
|
|
|
assert_eq!(
|
|
|
note,
|
|
|
- recv.try_recv().expect("valid").as_event().unwrap().event
|
|
|
+ connection
|
|
|
+ .try_recv()
|
|
|
+ .expect("valid")
|
|
|
+ .as_event()
|
|
|
+ .unwrap()
|
|
|
+ .event
|
|
|
);
|
|
|
|
|
|
// eod
|
|
|
- assert!(recv
|
|
|
+ assert!(connection
|
|
|
.try_recv()
|
|
|
.expect("valid")
|
|
|
.as_end_of_stored_events()
|
|
|
.is_some());
|
|
|
|
|
|
- assert!(recv.try_recv().is_err());
|
|
|
+ assert!(connection.try_recv().is_none());
|
|
|
}
|
|
|
|
|
|
#[tokio::test]
|
|
@@ -613,15 +652,17 @@ mod test {
|
|
|
}
|
|
|
]))
|
|
|
.expect("valid object");
|
|
|
- let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
|
|
|
- let (connection, mut recv) = Connection::new_local_connection();
|
|
|
+ let relayer =
|
|
|
+ Arc::new(Relayer::new(Some(get_db(true).await), None).expect("valid relayer"));
|
|
|
+ let mut connection = relayer.create_new_local_connection().await;
|
|
|
let _ = relayer
|
|
|
.process_request_from_client(&connection, request)
|
|
|
.await;
|
|
|
// ev1
|
|
|
assert_eq!(
|
|
|
"9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42",
|
|
|
- recv.try_recv()
|
|
|
+ connection
|
|
|
+ .try_recv()
|
|
|
.expect("valid")
|
|
|
.as_event()
|
|
|
.expect("event")
|
|
@@ -631,7 +672,8 @@ mod test {
|
|
|
// ev3
|
|
|
assert_eq!(
|
|
|
"e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9",
|
|
|
- recv.try_recv()
|
|
|
+ connection
|
|
|
+ .try_recv()
|
|
|
.expect("valid")
|
|
|
.as_event()
|
|
|
.expect("event")
|
|
@@ -641,7 +683,8 @@ mod test {
|
|
|
// ev2
|
|
|
assert_eq!(
|
|
|
"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a",
|
|
|
- recv.try_recv()
|
|
|
+ connection
|
|
|
+ .try_recv()
|
|
|
.expect("valid")
|
|
|
.as_event()
|
|
|
.expect("event")
|
|
@@ -650,13 +693,13 @@ mod test {
|
|
|
);
|
|
|
|
|
|
// eod
|
|
|
- assert!(recv
|
|
|
+ assert!(connection
|
|
|
.try_recv()
|
|
|
.expect("valid")
|
|
|
.as_end_of_stored_events()
|
|
|
.is_some());
|
|
|
|
|
|
- assert!(recv.try_recv().is_err());
|
|
|
+ assert!(connection.try_recv().is_none());
|
|
|
}
|
|
|
|
|
|
#[tokio::test]
|
|
@@ -949,8 +992,9 @@ mod test {
|
|
|
|
|
|
#[tokio::test]
|
|
|
async fn posting_event_replies_ok() {
|
|
|
- let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
|
|
|
- let (connection, mut recv) = Connection::new_local_connection();
|
|
|
+ let relayer =
|
|
|
+ Arc::new(Relayer::new(Some(get_db(false).await), None).expect("valid relayer"));
|
|
|
+ let mut connection = relayer.create_new_local_connection().await;
|
|
|
|
|
|
let note = get_note();
|
|
|
let note_id = note.as_event().map(|x| x.id.clone()).unwrap();
|
|
@@ -970,7 +1014,7 @@ mod test {
|
|
|
}
|
|
|
.into()
|
|
|
),
|
|
|
- recv.try_recv().ok()
|
|
|
+ connection.try_recv()
|
|
|
);
|
|
|
}
|
|
|
|