소스 검색

Merge branch 'relayer-refactor' of cesar/nostr-prototype into main

Cesar Rodas 3 달 전
부모
커밋
c8d7902c9e

+ 18 - 3
Cargo.lock

@@ -899,6 +899,7 @@ dependencies = [
  "futures",
  "futures-util",
  "log",
+ "nostr-rs-memory",
  "nostr-rs-relayer",
  "nostr-rs-types",
  "serde_json",
@@ -938,16 +939,30 @@ dependencies = [
 ]
 
 [[package]]
+name = "nostr-rs-personal-relayer"
+version = "0.1.0"
+dependencies = [
+ "futures",
+ "nostr-rs-client",
+ "nostr-rs-relayer",
+ "nostr-rs-storage-base",
+ "nostr-rs-types",
+ "thiserror",
+ "tokio",
+ "url",
+]
+
+[[package]]
 name = "nostr-rs-relayer"
 version = "0.1.0"
 dependencies = [
+ "futures",
  "futures-util",
  "log",
- "nostr-rs-rocksdb",
+ "nostr-rs-client",
+ "nostr-rs-memory",
  "nostr-rs-storage-base",
  "nostr-rs-types",
- "parking_lot",
- "rand",
  "serde_json",
  "thiserror",
  "tokio",

+ 1 - 1
Cargo.toml

@@ -9,7 +9,7 @@ members = [
     "crates/client",
     "crates/relayer",
     "crates/storage/base",
-    "crates/storage/rocksdb", "crates/dump", "crates/storage/memory",
+    "crates/storage/rocksdb", "crates/dump", "crates/storage/memory", "crates/personal-relayer",
 ]
 
 [dependencies]

+ 1 - 0
crates/client/Cargo.toml

@@ -20,3 +20,4 @@ futures = "0.3.28"
 
 [dev-dependencies]
 nostr-rs-relayer = { path = "../relayer" }
+nostr-rs-memory = { path = "../storage/memory" }

+ 15 - 6
crates/client/src/client.rs

@@ -1,5 +1,4 @@
 use crate::Error;
-use futures::executor::block_on;
 use futures_util::{SinkExt, StreamExt};
 use nostr_rs_types::{
     client::{self, subscribe},
@@ -32,11 +31,13 @@ pub struct ActiveSubscription {
 
 impl Drop for ActiveSubscription {
     fn drop(&mut self) {
-        block_on(async move {
-            self.subscriptions.write().await.remove(&self.id);
-            let _ = self
-                .send_to_socket
-                .send(nostr_rs_types::client::Close(self.id.clone()).into())
+        let subscriptions = self.subscriptions.clone();
+        let id = self.id.clone();
+        let send_to_socket = self.send_to_socket.clone();
+        tokio::spawn(async move {
+            subscriptions.write().await.remove(&id);
+            let _ = send_to_socket
+                .send(nostr_rs_types::client::Close(id).into())
                 .await;
         });
     }
@@ -51,10 +52,13 @@ pub struct Client {
     /// relayer
     pub send_to_socket: mpsc::Sender<Request>,
 
+    /// List of active subscriptions for this nostr client
     subscriptions: Subscriptions,
 
+    /// Background task / thread that is doing the actual connection
     worker: JoinHandle<()>,
 
+    /// Wether the background worker is connected or not
     is_connected: Arc<AtomicBool>,
 }
 
@@ -91,6 +95,11 @@ impl Client {
         }
     }
 
+    /// Spawns a background client that connects to the relayer
+    /// and sends messages to the listener
+    ///
+    /// This function will return a JoinHandle that can be used to
+    /// wait for the background client to finish or to cancel it.
     fn spawn_background_client(
         send_message_to_listener: mpsc::Sender<(Response, Url)>,
         mut send_to_socket: mpsc::Receiver<Request>,

+ 4 - 0
crates/client/src/error.rs

@@ -21,4 +21,8 @@ pub enum Error {
     /// The client has no connection to any relayer
     #[error("There is no connection")]
     Disconnected,
+
+    /// The pool was already splitted
+    #[error("The pool was already splitted")]
+    AlreadySplitted,
 }

+ 2 - 0
crates/client/src/lib.rs

@@ -12,4 +12,6 @@ mod client;
 mod error;
 mod pool;
 
+pub use url::Url;
+
 pub use self::{client::Client, error::Error, pool::Pool};

+ 120 - 24
crates/client/src/pool.rs

@@ -8,10 +8,12 @@ use nostr_rs_types::{
     types::SubscriptionId,
     Response,
 };
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
 use tokio::sync::{mpsc, RwLock};
 use url::Url;
 
+type Subscriptions =
+    Arc<RwLock<HashMap<SubscriptionId, (subscribe::Subscribe, Vec<ActiveSubscription>)>>>;
 /// Clients
 ///
 /// This is a set of outgoing connections to relayers. This struct can connect
@@ -21,30 +23,42 @@ use url::Url;
 pub struct Pool {
     clients: RwLock<HashMap<Url, Client>>,
     sender: mpsc::Sender<(Response, Url)>,
-    receiver: mpsc::Receiver<(Response, Url)>,
-    subscriptions: RwLock<HashMap<SubscriptionId, Vec<ActiveSubscription>>>,
+    receiver: Option<mpsc::Receiver<(Response, Url)>>,
+    subscriptions: Subscriptions,
 }
 
+/// Default channel buffer size for the pool
+pub const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 1_000;
+
 impl Default for Pool {
     fn default() -> Self {
-        Self::new()
-    }
-}
-
-const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 10_000;
-
-impl Pool {
-    /// Creates a new Relayers object
-    pub fn new() -> Self {
         let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
         Self {
             clients: Default::default(),
+            receiver: Some(receiver),
             subscriptions: Default::default(),
-            receiver,
             sender,
         }
     }
+}
+
+/// Return a subscription that will be removed when dropped
+pub struct PoolSubscription {
+    subscription_id: SubscriptionId,
+    subscriptions: Subscriptions,
+}
 
+impl Drop for PoolSubscription {
+    fn drop(&mut self) {
+        let subscriptions = self.subscriptions.clone();
+        let subscription_id = self.subscription_id.clone();
+        tokio::spawn(async move {
+            subscriptions.write().await.remove(&subscription_id);
+        });
+    }
+}
+
+impl Pool {
     /// Creates a new instance with a list of urls
     pub fn new_with_clients(clients: Vec<Url>) -> Self {
         let (sender, receiver) = mpsc::channel(DEFAULT_CHANNEL_BUFFER_SIZE);
@@ -56,23 +70,36 @@ impl Pool {
         Self {
             clients: RwLock::new(clients),
             subscriptions: Default::default(),
-            receiver,
+            receiver: Some(receiver),
             sender,
         }
     }
 
+    /// Splits the pool removing the receiver to be used in a different context
+    pub fn split(mut self) -> Result<(mpsc::Receiver<(Response, Url)>, Self), Error> {
+        Ok((self.receiver.take().ok_or(Error::AlreadySplitted)?, self))
+    }
+
     /// Tries to receive a message from any of the connected relayers
     pub fn try_recv(&mut self) -> Option<(Response, Url)> {
-        self.receiver.try_recv().ok()
+        self.receiver.as_mut()?.try_recv().ok()
     }
 
     /// Receives a message from any of the connected relayers
     pub async fn recv(&mut self) -> Option<(Response, Url)> {
-        self.receiver.recv().await
+        self.receiver.as_mut()?.recv().await
+    }
+
+    /// Returns the number of active subscriptions
+    pub async fn active_subscriptions(&self) -> usize {
+        self.subscriptions.read().await.keys().len()
     }
 
     /// Subscribes to all the connected relayers
-    pub async fn subscribe(&self, subscription: subscribe::Subscribe) -> Result<(), Error> {
+    pub async fn subscribe(
+        &self,
+        subscription: subscribe::Subscribe,
+    ) -> Result<PoolSubscription, Error> {
         let clients = self.clients.read().await;
 
         let wait_all = clients
@@ -80,15 +107,23 @@ impl Pool {
             .map(|sender| sender.subscribe(subscription.clone()))
             .collect::<Vec<_>>();
 
+        let subscription_id = subscription.subscription_id.clone();
+
         self.subscriptions.write().await.insert(
-            subscription.subscription_id,
-            join_all(wait_all)
-                .await
-                .into_iter()
-                .collect::<Result<Vec<_>, _>>()?,
+            subscription_id.clone(),
+            (
+                subscription,
+                join_all(wait_all)
+                    .await
+                    .into_iter()
+                    .collect::<Result<Vec<_>, _>>()?,
+            ),
         );
 
-        Ok(())
+        Ok(PoolSubscription {
+            subscription_id,
+            subscriptions: self.subscriptions.clone(),
+        })
     }
 
     /// Sends a request to all the connected relayers
@@ -120,10 +155,71 @@ impl Pool {
     /// already exists false will be returned
     pub async fn connect_to(&self, url: Url) {
         let mut clients = self.clients.write().await;
+        let mut subscriptions = self.subscriptions.write().await;
 
         if !clients.contains_key(&url) {
             log::warn!("Connecting to {}", url);
-            clients.insert(url.clone(), Client::new(self.sender.clone(), url));
+            let client = Client::new(self.sender.clone(), url.clone());
+
+            for (filter, sub) in subscriptions.values_mut() {
+                let _ = client.subscribe(filter.clone()).await.map(|subscription| {
+                    sub.push(subscription);
+                });
+            }
+
+            clients.insert(url.clone(), client);
         }
     }
 }
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use nostr_rs_memory::Memory;
+    use nostr_rs_relayer::Relayer;
+    use std::time::Duration;
+    use tokio::{net::TcpListener, task::JoinHandle, time::sleep};
+
+    async fn dummy_server() -> (Url, JoinHandle<()>) {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let local_addr = listener.local_addr().expect("addr");
+
+        let relayer = Relayer::new(Some(Memory::default()), None).expect("valid dummy server");
+        let stopper = relayer.main(listener).expect("valid main loop");
+        (
+            Url::parse(&format!("ws://{}", local_addr.to_string())).expect("valid url"),
+            stopper,
+        )
+    }
+
+    #[tokio::test]
+    async fn droppable_subscription() {
+        let pool = Pool::default();
+        let subscription = pool
+            .subscribe(Default::default())
+            .await
+            .expect("valid subscription");
+
+        assert_eq!(pool.active_subscriptions().await, 1);
+        drop(subscription);
+        sleep(Duration::from_millis(10)).await;
+        assert_eq!(pool.active_subscriptions().await, 0);
+    }
+
+    #[tokio::test]
+    async fn connect_to_dummy_server() {
+        let (addr, stopper) = dummy_server().await;
+        let pool = Pool::new_with_clients(vec![addr]);
+
+        assert_eq!(0, pool.check_active_connections().await);
+
+        sleep(Duration::from_millis(1000)).await;
+        assert_eq!(1, pool.check_active_connections().await);
+
+        // stop dummy server
+        stopper.abort();
+
+        sleep(Duration::from_millis(100)).await;
+        assert_eq!(0, pool.check_active_connections().await);
+    }
+}

+ 9 - 13
crates/dump/src/main.rs

@@ -1,4 +1,4 @@
-use nostr_rs_client::{Error as ClientError, Pool};
+use nostr_rs_client::{Error as ClientError, Pool, Url};
 use nostr_rs_types::{client::Subscribe, Response};
 
 #[derive(Debug, thiserror::Error)]
@@ -13,18 +13,14 @@ pub enum Error {
 #[tokio::main]
 async fn main() {
     env_logger::init();
-    let mut clients = vec![
-        "wss://relay.damus.io/",
-        "wss://brb.io",
-        "wss://nos.lol",
-        "wss://relay.current.fyi",
-        "wss://eden.nostr.land",
-        "wss://relay.snort.social",
-    ]
-    .into_iter()
-    .fold(Pool::new(), |clients, host| {
-        clients.connect_to(host.parse().expect("valid url"))
-    });
+    let mut clients = Pool::new_with_clients(vec![
+        Url::parse("wss://relay.damus.io/").expect("valid url"),
+        Url::parse("wss://brb.io").expect("valid url"),
+        Url::parse("wss://nos.lol").expect("valid url"),
+        Url::parse("wss://relay.current.fyi").expect("valid url"),
+        Url::parse("wss://eden.nostr.land").expect("valid url"),
+        Url::parse("wss://relay.snort.social").expect("valid url"),
+    ]);
 
     let _ = clients.subscribe(Subscribe::default().into()).await;
 

+ 14 - 0
crates/personal-relayer/Cargo.toml

@@ -0,0 +1,14 @@
+[package]
+name = "nostr-rs-personal-relayer"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+nostr-rs-types = { path = "../types" }
+nostr-rs-storage-base = { path = "../storage/base" }
+nostr-rs-client = { path = "../client" }
+nostr-rs-relayer = { path = "../relayer" }
+thiserror = "1.0.39"
+url = { version = "2.5.2", features = ["serde"] }
+futures = "0.3.30"
+tokio = { version = "1.39.2", features = ["full"] }

+ 77 - 0
crates/personal-relayer/src/lib.rs

@@ -0,0 +1,77 @@
+use futures::future::join_all;
+use nostr_rs_client::Pool;
+use nostr_rs_relayer::Relayer;
+use nostr_rs_storage_base::Storage;
+use nostr_rs_types::types::{Addr, Filter};
+use tokio::{net::TcpListener, task::JoinHandle};
+use url::Url;
+
+pub struct Stoppable(Option<Vec<JoinHandle<()>>>);
+
+impl From<Vec<JoinHandle<()>>> for Stoppable {
+    fn from(value: Vec<JoinHandle<()>>) -> Self {
+        Self(Some(value))
+    }
+}
+
+impl Drop for Stoppable {
+    fn drop(&mut self) {
+        if let Some(tasks) = self.0.take() {
+            for join_handle in tasks.into_iter() {
+                join_handle.abort();
+            }
+        }
+    }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum Error {
+    #[error("Relayer: {0}")]
+    Relayer(#[from] nostr_rs_relayer::Error),
+
+    #[error("Client error: {0}")]
+    Client(#[from] nostr_rs_client::Error),
+}
+
+pub struct PersonalRelayer<T: Storage + Send + Sync + 'static> {
+    relayer: Relayer<T>,
+    accounts: Vec<Addr>,
+}
+
+impl<T: Storage + Send + Sync + 'static> PersonalRelayer<T> {
+    pub async fn new(
+        storage: T,
+        accounts: Vec<Addr>,
+        client_urls: Vec<Url>,
+    ) -> Result<Self, Error> {
+        let pool = Pool::new_with_clients(client_urls);
+
+        join_all(
+            accounts
+                .iter()
+                .map(|account| {
+                    pool.subscribe(
+                        Filter {
+                            authors: vec![account.clone()],
+                            ..Default::default()
+                        }
+                        .into(),
+                    )
+                })
+                .collect::<Vec<_>>(),
+        )
+        .await
+        .into_iter()
+        .collect::<Result<Vec<_>, _>>()?;
+
+        Ok(Self {
+            relayer: Relayer::new(Some(storage), Some(pool))?,
+            accounts,
+        })
+    }
+
+    pub fn main(self, server: TcpListener) -> Result<Stoppable, Error> {
+        let tasks = vec![self.relayer.main(server)?, tokio::spawn(async move {})];
+        Ok(tasks.into())
+    }
+}

+ 3 - 3
crates/relayer/Cargo.toml

@@ -8,8 +8,8 @@ edition = "2021"
 [dependencies]
 nostr-rs-types = { path = "../types" }
 nostr-rs-storage-base = { path = "../storage/base" }
+nostr-rs-client = { path = "../client" }
 futures-util = "0.3.27"
-parking_lot = "0.12.1"
 tokio = { version = "1.26.0", features = ["sync", "macros", "rt", "time"] }
 tokio-tungstenite = { version = "0.18.0", features = [
     "rustls",
@@ -18,8 +18,8 @@ tokio-tungstenite = { version = "0.18.0", features = [
 ] }
 thiserror = "1.0.39"
 serde_json = "1.0.94"
-rand = "0.8.5"
 log = "0.4.17"
+futures = "0.3.30"
 
 [dev-dependencies]
-nostr-rs-rocksdb = { path = "../storage/rocksdb" }
+nostr-rs-memory = { path = "../storage/memory" }

+ 84 - 35
crates/relayer/src/connection.rs

@@ -1,74 +1,119 @@
-use crate::{get_id, Error};
+use crate::{subscription::ActiveSubscription, Error};
 use futures_util::{SinkExt, StreamExt};
 use nostr_rs_types::{
     relayer::{Auth, ROk},
-    types::Addr,
+    types::{Addr, SubscriptionId},
     Request, Response,
 };
-use parking_lot::RwLock;
-use std::collections::HashMap;
+use std::{
+    collections::HashMap,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use tokio::{
     net::TcpStream,
-    sync::mpsc::{channel, Receiver, Sender},
+    sync::{
+        mpsc::{channel, Receiver, Sender},
+        RwLock,
+    },
+    task::JoinHandle,
 };
 #[allow(unused_imports)]
 use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
+pub struct ConnectionId(usize);
+
+impl Default for ConnectionId {
+    fn default() -> Self {
+        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
+        Self(NEXT_ID.fetch_add(1, Ordering::SeqCst))
+    }
+}
+
+impl ConnectionId {
+    /// Connection ID for messages from the client pool or any special empty ConnectionId
+    #[inline]
+    pub fn new_empty() -> Self {
+        Self(0)
+    }
+
+    /// Check if the connection id is empty
+    ///
+    /// Empty connection id is used for messages from Client pool to the relayer
+    #[inline]
+    pub fn is_empty(&self) -> bool {
+        self.0 == 0
+    }
+}
+
 #[derive(Debug)]
+/// Relayer connection
+///
+/// The new connection struct. This struct spawn's a new worker that handles
+/// upcoming messages form the client
 pub struct Connection {
-    #[allow(unused)]
-    pub(crate) conn_id: u128,
+    conn_id: ConnectionId,
     sender: Sender<Response>,
-    subscriptions: RwLock<HashMap<String, u128>>,
+    subscriptions: RwLock<HashMap<SubscriptionId, Vec<ActiveSubscription>>>,
+    handler: Option<JoinHandle<()>>,
 }
 
 const MAX_SUBSCRIPTIONS_BUFFER: usize = 100;
 
+impl Drop for Connection {
+    fn drop(&mut self) {
+        if let Some(handler) = self.handler.take() {
+            let _ = handler.abort();
+        }
+    }
+}
+
 impl Connection {
     #[cfg(test)]
     pub fn new_for_test() -> (Self, Receiver<Response>) {
         let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
         (
             Self {
-                conn_id: 0,
+                conn_id: ConnectionId::default(),
                 sender,
                 subscriptions: RwLock::new(HashMap::new()),
+                handler: None,
             },
             receiver,
         )
     }
 
-    pub async fn new(
-        broadcast_request: Sender<(u128, Request)>,
-        disconnection_notify: Option<Sender<u128>>,
+    /// Create new connection
+    pub async fn new_connection(
+        send_message_to_relayer: Sender<(ConnectionId, Request)>,
+        disconnection_notify: Option<Sender<ConnectionId>>,
         stream: TcpStream,
     ) -> Result<Self, Error> {
         let websocket = accept_async(stream).await?;
-        let conn_id = get_id();
+        let conn_id = Default::default();
         let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
-        Self::spawn(
-            broadcast_request,
-            websocket,
-            receiver,
-            disconnection_notify,
-            conn_id,
-        );
         let _ = sender.send(Auth::default().into()).await;
         Ok(Self {
             conn_id,
             sender,
             subscriptions: RwLock::new(HashMap::new()),
+            handler: Some(Self::spawn(
+                send_message_to_relayer,
+                websocket,
+                receiver,
+                disconnection_notify,
+                conn_id,
+            )),
         })
     }
 
-    #[allow(unused)]
     fn spawn(
-        broadcast_request: Sender<(u128, Request)>,
+        send_message_to_relayer: Sender<(ConnectionId, Request)>,
         websocket: WebSocketStream<TcpStream>,
         mut receiver: Receiver<Response>,
-        disconnection_notify: Option<Sender<u128>>,
-        conn_id: u128,
-    ) {
+        disconnection_notify: Option<Sender<ConnectionId>>,
+        conn_id: ConnectionId,
+    ) -> JoinHandle<()> {
         tokio::spawn(async move {
             let mut _subscriptions: HashMap<String, (u128, Receiver<Response>)> = HashMap::new();
             let (mut writer, mut reader) = websocket.split();
@@ -90,7 +135,7 @@ impl Connection {
                             let msg: Result<Request, _> = serde_json::from_str(&msg);
                             match msg {
                                 Ok(msg) => {
-                                    let _ = broadcast_request.send((conn_id, msg)).await;
+                                    let _ = send_message_to_relayer.send((conn_id, msg)).await;
                                 },
                                 Err(err) => {
                                     log::error!("Error parsing message from client: {}", err);
@@ -121,29 +166,33 @@ impl Connection {
             if let Some(disconnection_notify) = disconnection_notify {
                 let _ = disconnection_notify.try_send(conn_id);
             }
-        });
+        })
     }
 
     #[inline]
+    /// Sends a message to this connection's websocket
     pub fn send(&self, response: Response) -> Result<(), Error> {
         self.sender
             .try_send(response)
             .map_err(|e| Error::TrySendError(Box::new(e)))
     }
 
-    #[inline]
+    /// Get the sender for this connection
     pub fn get_sender(&self) -> Sender<Response> {
         self.sender.clone()
     }
 
-    pub fn get_subscription_id(&self, id: &str) -> Option<u128> {
-        let subscriptions = self.subscriptions.read();
-        subscriptions.get(id).copied()
+    /// Get the connection id for this connection
+    pub fn get_conn_id(&self) -> ConnectionId {
+        self.conn_id
     }
 
-    pub fn create_subscription(&self, id: String) -> (u128, Sender<Response>) {
-        let mut subscriptions = self.subscriptions.write();
-        let internal_id = subscriptions.entry(id).or_insert_with(get_id);
-        (*internal_id, self.sender.clone())
+    /// Create a subscription for this connection
+    pub async fn keep_track_subscription(
+        &self,
+        id: SubscriptionId,
+        subscriptions: Vec<ActiveSubscription>,
+    ) {
+        self.subscriptions.write().await.insert(id, subscriptions);
     }
 }

+ 17 - 6
crates/relayer/src/error.rs

@@ -1,22 +1,33 @@
 use nostr_rs_types::Response;
 
 #[derive(Debug, thiserror::Error)]
+/// Relayer error
 pub enum Error {
-    #[error("The identifier {0} is already in use")]
-    IdentifierAlreadyUsed(String),
-
+    /// Database/ Storage error
     #[error("Internal/DB: {0}")]
-    Db(#[from] nostr_rs_storage_base::Error),
+    Storage(#[from] nostr_rs_storage_base::Error),
 
+    /// Web-socket related errors
     #[error("WebSocket error: {0}")]
     WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
 
+    /// Serialization related errors
     #[error("Serialization: {0}")]
     Serde(#[from] serde_json::Error),
 
+    /// Tokio channel's error
+    #[error("TrySendError: {0}")]
+    TrySendError(#[from] Box<tokio::sync::mpsc::error::TrySendError<Response>>),
+
+    /// Client related errors
+    #[error("Nostr client error: {0}")]
+    Client(#[from] nostr_rs_client::Error),
+
+    /// Unknown connections
     #[error("Unknown connection: {0}")]
     UnknownConnection(u128),
 
-    #[error("TrySendError: {0}")]
-    TrySendError(#[from] Box<tokio::sync::mpsc::error::TrySendError<Response>>),
+    /// The relayer is already splitten
+    #[error("Relayer already splitten")]
+    AlreadySplitted,
 }

+ 12 - 18
crates/relayer/src/lib.rs

@@ -1,24 +1,18 @@
-use rand::Rng;
+//! Nostr relayer
+//!
+//! This is a generic relayer crate for the Nostr protocol.
+//!
+//! This relayer has the ability to integrate optionally with a client Pool to
+//! repost notes and to subscribe to their notes.
+//!
+//! The relayer can also be integrate with a nostr storage crate, to persist
+//! notes and serve subscribers from the storage, not only from real time
+//! updates.
+#![deny(missing_docs, warnings)]
 
 mod connection;
 mod error;
 mod relayer;
 mod subscription;
 
-pub use self::{
-    connection::Connection, error::Error, relayer::Relayer, subscription::Subscription,
-};
-
-// Get current nanoseconds and use the last 3 digits as a random number (because
-// sometimes it comes as 0)
-pub(crate) fn get_id() -> u128 {
-    let mut rng = rand::thread_rng();
-    let random_number = rng.gen_range(0..999);
-
-    let ts = std::time::SystemTime::now()
-        .duration_since(std::time::UNIX_EPOCH)
-        .expect("time")
-        .as_nanos();
-
-    ts.checked_add(random_number).unwrap_or(ts)
-}
+pub use self::{connection::Connection, error::Error, relayer::Relayer};

+ 577 - 175
crates/relayer/src/relayer.rs

@@ -1,61 +1,154 @@
-use crate::{Connection, Error, Subscription};
+use crate::{connection::ConnectionId, subscription::SubscriptionManager, Connection, Error};
 use futures_util::StreamExt;
+use nostr_rs_client::{Error as ClientError, Pool};
 use nostr_rs_storage_base::Storage;
-use nostr_rs_types::{
-    relayer,
-    types::{Event, SubscriptionId},
-    Request, Response,
-};
-use parking_lot::{RwLock, RwLockReadGuard};
+use nostr_rs_types::{relayer, types::Event, Request, Response};
 use std::{collections::HashMap, ops::Deref, sync::Arc};
-use tokio::sync::mpsc;
-#[allow(unused_imports)]
 use tokio::{
-    net::TcpStream,
+    net::{TcpListener, TcpStream},
     sync::mpsc::{channel, Receiver, Sender},
 };
+use tokio::{
+    sync::{mpsc, RwLock},
+    task::JoinHandle,
+};
 
-type SubId = u128;
-
-type Subscriptions = HashMap<SubId, (SubscriptionId, Sender<Response>)>;
-
-pub struct Relayer<T: Storage> {
+/// Relayer struct
+///
+pub struct Relayer<T: Storage + Send + Sync + 'static> {
     /// Storage engine, if provided the services are going to persisted in disk,
     /// otherwise all the messages are going to be ephemeral, making this
     /// relayer just a dumb proxy (that can be useful for privacy) but it won't
     /// be able to perform any optimization like prefetching content while offline
     storage: Option<T>,
-    /// Keeps a map between the internal subscription ID and the subscription
-    /// type. One subscription ID may have multiple subscription types.
-    ///
-    /// Each connection keeps a list of the subscription ID provided by the user
-    /// (String) and the internal, globally recognized subscription ID which is
-    /// internal (SubId)
-    subscriptions_ids_index: RwLock<HashMap<SubId, Vec<Subscription>>>,
-    /// Each subscription type that is active has a list of subscriptions.
+    /// x
+    subscriptions: Arc<SubscriptionManager>,
+    clients: RwLock<HashMap<ConnectionId, Connection>>,
+    /// This Sender can be used to send requests from anywhere to the relayer.
+    send_to_relayer: Sender<(ConnectionId, Request)>,
+    /// This Receiver is the relayer the way the relayer receives messages
+    relayer_receiver: Option<Receiver<(ConnectionId, Request)>>,
+    /// Client pool
     ///
-    /// A single REQ can be subscribed to multiple subscription types, specially
-    /// when it is translated in OR filters. It is designed this way to allow a
-    /// fast iteration and match quickly filters.
-    subscriptions: RwLock<HashMap<Subscription, RwLock<Subscriptions>>>,
-    clients: RwLock<HashMap<u128, Connection>>,
-    #[allow(dead_code)]
-    sender: Sender<(u128, Request)>,
+    /// A relayer can optionally be connected to a pool of clients to get foreign events.
+    client_pool: Option<(Pool, JoinHandle<()>)>,
 }
 
-impl<T: Storage> Relayer<T> {
-    pub fn new(storage: Option<T>) -> (Arc<Self>, Receiver<(u128, Request)>) {
+impl<T: Storage + Send + Sync + 'static> Drop for Relayer<T> {
+    fn drop(&mut self) {
+        if let Some((_, handle)) = self.client_pool.take() {
+            let _ = handle.abort();
+        }
+    }
+}
+
+impl<T: Storage + Send + Sync + 'static> Relayer<T> {
+    /// Creates a new relayer instance
+    ///
+    /// If the storage is given, it will be used to persist events, as well to
+    /// server past events when a new subscription is added.
+    ///
+    /// If the client_pool is given it will be used to connect to those relayers
+    /// and create a network of relayers, reposting events to them and
+    /// subscribing to their events.`gqq`
+    pub fn new(storage: Option<T>, client_pool: Option<Pool>) -> Result<Self, Error> {
         let (sender, receiver) = channel(100_000);
-        (
-            Arc::new(Self {
-                storage,
-                subscriptions: RwLock::new(HashMap::new()),
-                subscriptions_ids_index: RwLock::new(HashMap::new()),
-                clients: RwLock::new(HashMap::new()),
-                sender,
-            }),
-            receiver,
-        )
+        Ok(Self {
+            storage,
+            subscriptions: Default::default(),
+            send_to_relayer: sender.clone(),
+            relayer_receiver: Some(receiver),
+            clients: Default::default(),
+            client_pool: if let Some(client_pool) = client_pool {
+                Some(Self::handle_client_pool(client_pool, sender)?)
+            } else {
+                None
+            },
+        })
+    }
+
+    /// Total number of subscribers requests that actively listening for new events
+    pub fn total_subscribers(&self) -> usize {
+        self.subscriptions.total_subscribers()
+    }
+
+    /// Splits the relayer object and extract their receiver.
+    pub fn split(mut self) -> Result<(Self, Receiver<(ConnectionId, Request)>), Error> {
+        let receiver = self.relayer_receiver.take().ok_or(Error::AlreadySplitted)?;
+        Ok((self, receiver))
+    }
+
+    /// Runs the relayer main loop in a tokio task and returns it.
+    ///
+    /// This function consumes the object and takes the ownership. The returned
+    /// JoinHandle() can be used to stop the main loop
+    pub fn main(self, server: TcpListener) -> Result<JoinHandle<()>, Error> {
+        let (this, mut receiver) = self.split()?;
+        Ok(tokio::spawn(async move {
+            loop {
+                tokio::select! {
+                    Ok((stream, _)) = server.accept() => {
+                        // accept new external connections
+                        let _ = this.add_connection(None, stream).await;
+                    },
+                    Some((conn_id, request)) = receiver.recv() => {
+                        // receive messages from the connection pool
+                        if conn_id.is_empty() {
+                            // connection pool
+                            if let Request::Event(event) = request {
+                                if let Some(storage) = this.storage.as_ref() {
+                                    let _ = storage.store_local_event(&event).await;
+                                }
+                                this.broadcast(&event.deref()).await;
+                            }
+                            continue;
+                        }
+
+                        let connections = this.clients.read().await;
+                        let connection = if let Some(connection) = connections.get(&conn_id) {
+                            connection
+                        } else {
+                            continue;
+                        };
+
+                        // receive messages from clients
+                        let _ = this.process_request_from_client(connection, request).await;
+                        drop(connections);
+                    }
+                    else => {
+                    }
+                }
+            }
+        }))
+    }
+
+    fn handle_client_pool(
+        client_pool: Pool,
+        sender: Sender<(ConnectionId, Request)>,
+    ) -> Result<(Pool, JoinHandle<()>), ClientError> {
+        let (mut receiver, client_pool) = client_pool.split()?;
+
+        let handle = tokio::spawn(async move {
+            loop {
+                if let Some((response, _)) = receiver.recv().await {
+                    match response {
+                        Response::Event(event) => {
+                            let _ = sender
+                                .send((
+                                    ConnectionId::new_empty(),
+                                    Request::Event(event.event.into()),
+                                ))
+                                .await;
+                        }
+                        x => {
+                            println!("x => {:?}", x);
+                        }
+                    }
+                }
+            }
+        });
+
+        Ok((client_pool, handle))
     }
 
     /// Returns a reference to the internal database
@@ -63,62 +156,50 @@ impl<T: Storage> Relayer<T> {
         &self.storage
     }
 
+    /// Adds a new TpStream and adds it to the list of active connections.
+    ///
+    /// This function will spawn the client's loop to receive incoming messages and send those messages
     pub async fn add_connection(
         &self,
-        disconnection_notify: Option<mpsc::Sender<u128>>,
+        disconnection_notify: Option<mpsc::Sender<ConnectionId>>,
         stream: TcpStream,
-    ) -> Result<u128, Error> {
-        let client = Connection::new(self.sender.clone(), disconnection_notify, stream).await?;
-        let id = client.conn_id;
-        let mut clients = self.clients.write();
-        clients.insert(id, client);
+    ) -> Result<ConnectionId, Error> {
+        let client =
+            Connection::new_connection(self.send_to_relayer.clone(), disconnection_notify, stream)
+                .await?;
+        let id = client.get_conn_id();
+        self.clients.write().await.insert(id, client);
 
         Ok(id)
     }
 
-    async fn recv_request_from_client(
+    /// Process a request from a connected client
+    async fn process_request_from_client(
         &self,
         connection: &Connection,
         request: Request,
     ) -> Result<Option<Request>, Error> {
         match &request {
             Request::Event(event) => {
-                self.store_and_broadcast_local_event(event.deref()).await;
+                if let Some(storage) = self.storage.as_ref() {
+                    let _ = storage.store(event).await;
+                    let _ = storage.store_local_event(event).await;
+                }
+
+                self.broadcast(event).await;
+
+                if let Some((client_pool, _)) = self.client_pool.as_ref() {
+                    // pass the event to the pool of clients, so this relayer can relay
+                    // their local events to the clients in the network of relayers
+                    let _ = client_pool.post(event.clone().into()).await;
+                }
             }
             Request::Request(request) => {
-                // Create subscription
-                let (sub_id, receiver) =
-                    connection.create_subscription(request.subscription_id.deref().to_owned());
-                let mut sub_index = self.subscriptions_ids_index.write();
-                let mut subscriptions = self.subscriptions.write();
-                if let Some(prev_subs) = sub_index.remove(&sub_id) {
-                    // remove any previous subscriptions
-                    prev_subs.iter().for_each(|index| {
-                        if let Some(subscriptions) = subscriptions.get_mut(index) {
-                            subscriptions.write().remove(&sub_id);
-                        }
-                    });
+                if let Some((client_pool, _)) = self.client_pool.as_ref() {
+                    // pass the subscription request to the pool of clients, so this relayer
+                    // can relay any unknown event to the clients through their subscriptions
+                    let _ = client_pool.subscribe(request.filters.clone().into()).await;
                 }
-                sub_index.insert(
-                    sub_id,
-                    Subscription::from_filters(&request.filters)
-                        .into_iter()
-                        .map(|index| {
-                            subscriptions
-                                .entry(index.clone())
-                                .or_insert_with(|| RwLock::new(HashMap::new()))
-                                .write()
-                                .insert(
-                                    sub_id,
-                                    (request.subscription_id.clone(), receiver.clone()),
-                                );
-                            index
-                        })
-                        .collect::<Vec<_>>(),
-                );
-
-                drop(subscriptions);
-                drop(sub_index);
 
                 if let Some(storage) = self.storage.as_ref() {
                     // Sent all events that match the filter that are stored in our database
@@ -139,102 +220,73 @@ impl<T: Storage> Relayer<T> {
 
                 let _ = connection
                     .send(relayer::EndOfStoredEvents(request.subscription_id.clone()).into());
+
+                connection
+                    .keep_track_subscription(
+                        request.subscription_id.clone(),
+                        self.subscriptions
+                            .subscribe(
+                                connection.get_conn_id(),
+                                connection.get_sender(),
+                                request.clone(),
+                            )
+                            .await,
+                    )
+                    .await;
             }
-            Request::Close(close) => {
-                if let Some(id) = connection.get_subscription_id(&close.0) {
-                    let mut subscriptions = self.subscriptions_ids_index.write();
-                    if let Some(indexes) = subscriptions.remove(&id) {
-                        let mut subscriptions = self.subscriptions.write();
-                        indexes.iter().for_each(|index| {
-                            if let Some(subscriptions) = subscriptions.get_mut(index) {
-                                subscriptions.write().remove(&id);
-                            }
-                        });
-                    }
-                }
+            Request::Close(_close) => {
+                todo!()
             }
         };
 
         Ok(Some(request))
     }
 
-    pub async fn recv(
-        &self,
-        receiver: &mut Receiver<(u128, Request)>,
-    ) -> Result<Option<Request>, Error> {
-        let (conn_id, request) = if let Some(request) = receiver.recv().await {
-            request
-        } else {
-            return Ok(None);
-        };
-        let connections = self.clients.read();
-        let connection = connections
-            .get(&conn_id)
-            .ok_or(Error::UnknownConnection(conn_id))?;
-
-        self.recv_request_from_client(connection, request).await
-    }
-
-    pub fn send_to_conn(&self, conn_id: u128, response: Response) -> Result<(), Error> {
-        let connections = self.clients.read();
-        let connection = connections
-            .get(&conn_id)
-            .ok_or(Error::UnknownConnection(conn_id))?;
-
-        connection.send(response)
-    }
-
-    #[inline]
-    fn broadcast_to_subscribers(subscriptions: RwLockReadGuard<Subscriptions>, event: &Event) {
-        for (_, receiver) in subscriptions.iter() {
-            let _ = receiver.1.try_send(
-                relayer::Event {
-                    subscription_id: receiver.0.clone(),
-                    event: event.clone(),
-                }
-                .into(),
-            );
-        }
-    }
-
-    #[inline]
-    pub async fn store_and_broadcast_local_event(&self, event: &Event) {
-        if let Some(storage) = self.storage.as_ref() {
-            let _ = storage.store_local_event(event).await;
-        }
-        let subscriptions = self.subscriptions.read();
-
-        for subscription_type in Subscription::from_event(event) {
-            if let Some(subscribers) = subscriptions.get(&subscription_type) {
-                Self::broadcast_to_subscribers(subscribers.read(), event);
-            }
-        }
-    }
-
     #[inline]
-    pub fn store_and_broadcast(&self, event: &Event) {
+    /// Broadcast a given event to all local subscribers
+    pub async fn broadcast(&self, event: &Event) {
         if let Some(storage) = self.storage.as_ref() {
-            let _ = storage.store(event);
+            let _ = storage.store(event).await;
         }
-        let subscriptions = self.subscriptions.read();
 
-        for subscription_type in Subscription::from_event(event) {
-            if let Some(subscribers) = subscriptions.get(&subscription_type) {
-                Self::broadcast_to_subscribers(subscribers.read(), event);
-            }
-        }
+        self.subscriptions.broadcast(event.clone());
     }
 }
 
 #[cfg(test)]
 mod test {
+    use std::time::Duration;
+
     use super::*;
-    use crate::get_id;
-    use nostr_rs_rocksdb::RocksDb;
+    use futures::future::join_all;
+    use nostr_rs_memory::Memory;
     use nostr_rs_types::Request;
+    use serde_json::json;
+    use tokio::time::sleep;
+
+    fn get_note() -> Request {
+        serde_json::from_value(json!(
+            [
+                "EVENT",
+                {
+                    "kind":1,
+                    "content":"Pong",
+                    "tags":[
+                        ["e","9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42","","root"],
+                        ["e","2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a","","reply"],
+                        ["p","39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                        ["p","ee7202ad91459e013bfef263c59e47deb0163a5e7651b026673765488bfaf102"]
+                    ],
+                    "created_at":1681938616,
+                    "pubkey":"a42007e33cfa25673b26f46f39df039aa6003258a68dc88f1f1e0447607aedb3",
+                    "id":"e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9",
+                    "sig":"9036150a6c8a32933cffcc42aec4d2109a22e9f10d1c3860c0435a925e6386babb7df5c95fcf68c8ed6a9726a1f07225af663d0b068eb555014130aad21674fc",
+                }
+        ])).expect("value")
+    }
 
-    async fn get_db(prefill: bool) -> RocksDb {
-        let db = RocksDb::new(format!("/tmp/db/{}", get_id())).expect("db");
+    async fn get_db(prefill: bool) -> Memory {
+        let db = Memory::default();
         if prefill {
             let events = include_str!("../tests/events.json")
                 .lines()
@@ -244,23 +296,73 @@ mod test {
             for event in events {
                 assert!(db.store(&event).await.expect("valid"));
             }
+
+            while db.is_flushing() {
+                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
+            }
         }
         db
     }
 
     #[tokio::test]
     async fn serve_listener_from_local_db() {
-        let request: Request = serde_json::from_str("[
-                \"REQ\",\"1298169700973717\",
-                {\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681928304},
-                {\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681928304},
-                {\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},
-                {\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},
-                {\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}
-            ]").expect("valid object");
-        let (relayer, _) = Relayer::new(Some(get_db(true).await));
+        let request = serde_json::from_value(json!([
+          "REQ",
+          "1298169700973717",
+          {
+            "authors": [
+              "39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"
+            ],
+            "since": 1681928304
+          },
+          {
+            "#p": [
+              "39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"
+            ],
+            "kinds": [
+              1,
+              3,
+              6,
+              7,
+              9735
+            ],
+            "since": 1681928304
+          },
+          {
+            "#p": [
+              "39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"
+            ],
+            "kinds": [
+              4
+            ]
+          },
+          {
+            "authors": [
+              "39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"
+            ],
+            "kinds": [
+              4
+            ]
+          },
+          {
+            "#e": [
+              "2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a",
+              "a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1"
+            ],
+            "kinds": [
+              1,
+              6,
+              7,
+              9735
+            ]
+          }
+        ]))
+        .expect("valid object");
+        let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
         let (connection, mut recv) = Connection::new_for_test();
-        let _ = relayer.recv_request_from_client(&connection, request).await;
+        let _ = relayer
+            .process_request_from_client(&connection, request)
+            .await;
         // ev1
         assert_eq!(
             "9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42",
@@ -294,6 +396,30 @@ mod test {
                 .id
                 .to_string()
         );
+
+        // ev3 (again)
+        assert_eq!(
+            "e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9",
+            recv.try_recv()
+                .expect("valid")
+                .as_event()
+                .expect("event")
+                .event
+                .id
+                .to_string()
+        );
+        // ev2 (again)
+        assert_eq!(
+            "2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a",
+            recv.try_recv()
+                .expect("valid")
+                .as_event()
+                .expect("event")
+                .event
+                .id
+                .to_string()
+        );
+
         // eod
         assert!(recv
             .try_recv()
@@ -306,10 +432,46 @@ mod test {
 
     #[tokio::test]
     async fn server_listener_real_time() {
-        let request: Request = serde_json::from_str("[\"REQ\",\"1298169700973717\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid object");
-        let (relayer, _) = Relayer::new(Some(get_db(false).await));
+        let request: Request = serde_json::from_value(json!(
+            [
+                "REQ",
+                "1298169700973717",
+                {
+                    "authors": ["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                    "since":1681939304
+                },
+                {
+                    "#p":["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                    "kinds":[1,3,6,7,9735],
+                    "since":1681939304
+                },
+                {
+                    "#p":["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                    "kinds":[4]
+                },
+                {
+                    "authors":["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                    "kinds":[4]
+                },
+                {
+                    "#e":[
+                        "2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a",
+                        "a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1"
+                    ],
+                    "kinds":[1,6,7,9735]
+                }
+        ]))
+        .expect("valid object");
+        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
         let (connection, mut recv) = Connection::new_for_test();
-        let _ = relayer.recv_request_from_client(&connection, request).await;
+
+        assert_eq!(relayer.total_subscribers(), 0);
+        let _ = relayer
+            .process_request_from_client(&connection, request)
+            .await;
+
+        assert_eq!(relayer.total_subscribers(), 5);
+
         // eod
         assert!(recv
             .try_recv()
@@ -320,14 +482,254 @@ mod test {
         // It is empty
         assert!(recv.try_recv().is_err());
 
-        let new_event: Request = serde_json::from_str(r#"["EVENT", {"kind":1,"content":"Pong","tags":[["e","9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42","","root"],["e","2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a","","reply"],["p","39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],["p","ee7202ad91459e013bfef263c59e47deb0163a5e7651b026673765488bfaf102"]],"created_at":1681938616,"pubkey":"a42007e33cfa25673b26f46f39df039aa6003258a68dc88f1f1e0447607aedb3","id":"e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9","sig":"9036150a6c8a32933cffcc42aec4d2109a22e9f10d1c3860c0435a925e6386babb7df5c95fcf68c8ed6a9726a1f07225af663d0b068eb555014130aad21674fc","meta":{"revision":0,"created":1681939266488,"version":0},"$loki":108}]"#).expect("value");
+        relayer
+            .process_request_from_client(&connection, get_note())
+            .await
+            .expect("process event");
+
+        sleep(Duration::from_millis(100)).await;
+
+        // It is not empty
+        let msg = recv.try_recv();
+        assert!(msg.is_ok());
+        assert_eq!(
+            msg.expect("is ok")
+                .as_event()
+                .expect("valid")
+                .subscription_id
+                .to_string(),
+            "1298169700973717".to_owned()
+        );
+
+        // it must be deliverd at most once
+        assert!(recv.try_recv().is_err());
+        assert_eq!(relayer.total_subscribers(), 5);
+
+        // when client is dropped, the subscription is removed
+        // automatically
+        drop(connection);
+
+        sleep(Duration::from_millis(10)).await;
+
+        assert_eq!(relayer.total_subscribers(), 0);
+    }
+
+    #[tokio::test]
+    async fn subscribe_partial_key() {
+        let request: Request = serde_json::from_value(json!([
+            "REQ",
+            "1298169700973717",
+            {
+                "authors":["a42007e33c"],
+                "since":1681939304
+            }
+        ]))
+        .expect("valid object");
+
+        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
+        let (connection, mut recv) = Connection::new_for_test();
+
+        assert_eq!(relayer.total_subscribers(), 0);
+        let _ = relayer
+            .process_request_from_client(&connection, request)
+            .await;
+
+        assert_eq!(relayer.total_subscribers(), 1);
+
+        // eod
+        assert!(recv
+            .try_recv()
+            .expect("valid")
+            .as_end_of_stored_events()
+            .is_some());
+
+        // It is empty
+        assert!(recv.try_recv().is_err());
 
         relayer
-            .recv_request_from_client(&connection, new_event)
+            .process_request_from_client(&connection, get_note())
             .await
             .expect("process event");
 
+        sleep(Duration::from_millis(100)).await;
+
         // It is not empty
-        assert!(recv.try_recv().is_ok());
+        let msg = recv.try_recv();
+        assert!(msg.is_ok());
+        assert_eq!(
+            msg.expect("is ok")
+                .as_event()
+                .expect("valid")
+                .subscription_id
+                .to_string(),
+            "1298169700973717".to_owned()
+        );
+
+        // it must be deliverd at most once
+        assert!(recv.try_recv().is_err());
+        assert_eq!(relayer.total_subscribers(), 1);
+
+        // when client is dropped, the subscription is removed
+        // automatically
+        drop(connection);
+
+        sleep(Duration::from_millis(10)).await;
+
+        assert_eq!(relayer.total_subscribers(), 0);
+    }
+
+    #[tokio::test]
+    async fn multiple_subcribers() {
+        let req1: Request = serde_json::from_value(json!(["REQ", "1298169700973717", {
+            "authors":["c42007e33c"],
+        }]))
+        .expect("valid object");
+        let req2: Request = serde_json::from_value(json!(["REQ", "1298169700973717", {
+           "authors":["a42007e33c"]
+        }]))
+        .expect("valid object");
+
+        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
+        let (publisher, _) = Connection::new_for_test();
+
+        let mut set1 = (0..1000)
+            .map(|_| Connection::new_for_test())
+            .collect::<Vec<_>>();
+
+        let mut set2 = (0..100)
+            .map(|_| Connection::new_for_test())
+            .collect::<Vec<_>>();
+
+        let subscribe1 = set1
+            .iter()
+            .map(|(connection, _)| relayer.process_request_from_client(connection, req1.clone()))
+            .collect::<Vec<_>>();
+
+        let subscribe2 = set2
+            .iter()
+            .map(|(connection, _)| relayer.process_request_from_client(connection, req2.clone()))
+            .collect::<Vec<_>>();
+
+        assert_eq!(relayer.total_subscribers(), 0);
+
+        join_all(subscribe1)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<_>, _>>()
+            .expect("valid calls");
+        join_all(subscribe2)
+            .await
+            .into_iter()
+            .collect::<Result<Vec<_>, _>>()
+            .expect("valid calls");
+
+        for (_, recv) in set1.iter_mut() {
+            assert!(recv
+                .try_recv()
+                .expect("end of stored events")
+                .as_end_of_stored_events()
+                .is_some());
+        }
+
+        for (_, recv) in set2.iter_mut() {
+            assert!(recv
+                .try_recv()
+                .expect("end of stored events")
+                .as_end_of_stored_events()
+                .is_some());
+        }
+
+        assert_eq!(relayer.total_subscribers(), 1100);
+
+        relayer
+            .process_request_from_client(&publisher, get_note())
+            .await
+            .expect("process event");
+
+        sleep(Duration::from_millis(10)).await;
+
+        for (_, recv) in set1.iter_mut() {
+            assert!(recv.try_recv().is_err());
+        }
+
+        for (_, recv) in set2.iter_mut() {
+            let msg = recv.try_recv();
+            assert!(msg.is_ok());
+            let msg = msg.expect("msg");
+
+            assert_eq!(
+                msg.as_event().expect("valid").subscription_id.to_string(),
+                "1298169700973717".to_owned()
+            );
+
+            assert!(recv.try_recv().is_err());
+        }
+
+        drop(set1);
+        sleep(Duration::from_millis(10)).await;
+        assert_eq!(relayer.total_subscribers(), 100);
+
+        drop(set2);
+        sleep(Duration::from_millis(10)).await;
+        assert_eq!(relayer.total_subscribers(), 0);
+
+        drop(relayer);
+    }
+
+    #[tokio::test]
+    async fn subscribe_to_all() {
+        let request: Request =
+            serde_json::from_value(json!(["REQ", "1298169700973717", {}])).expect("valid object");
+
+        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
+        let (connection, mut recv) = Connection::new_for_test();
+
+        assert_eq!(relayer.total_subscribers(), 0);
+        let _ = relayer
+            .process_request_from_client(&connection, request)
+            .await;
+
+        assert_eq!(relayer.total_subscribers(), 1);
+
+        // eod
+        assert!(recv
+            .try_recv()
+            .expect("valid")
+            .as_end_of_stored_events()
+            .is_some());
+
+        // It is empty
+        assert!(recv.try_recv().is_err());
+
+        relayer
+            .process_request_from_client(&connection, get_note())
+            .await
+            .expect("process event");
+
+        sleep(Duration::from_millis(10)).await;
+
+        // It is not empty
+        let msg = recv.try_recv();
+        assert!(msg.is_ok());
+        assert_eq!(
+            msg.expect("is ok")
+                .as_event()
+                .expect("valid")
+                .subscription_id
+                .to_string(),
+            "1298169700973717".to_owned()
+        );
+
+        // it must be deliverd at most once
+        assert!(recv.try_recv().is_err());
+        assert_eq!(relayer.total_subscribers(), 1);
+
+        // when client is dropped, the subscription is removed
+        // automatically
+        drop(connection);
+
+        sleep(Duration::from_millis(10)).await;
+
+        assert_eq!(relayer.total_subscribers(), 0);
     }
 }

+ 0 - 222
crates/relayer/src/subscription.rs

@@ -1,222 +0,0 @@
-use nostr_rs_types::types::{Event, Filter, Tag};
-
-#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
-pub struct Subscription {
-    author: Option<Vec<u8>>,
-    ref_public_key: Option<Vec<u8>>,
-    ref_id: Option<Vec<u8>>,
-    id: Option<Vec<u8>>,
-    kind: Option<u32>,
-}
-
-impl Subscription {
-    #[inline]
-    pub fn from_filters(filters: &[Filter]) -> Vec<Subscription> {
-        let mut subs = vec![];
-        filters.iter().for_each(|filter| {
-            let authors: Vec<Option<Vec<u8>>> = if filter.authors.is_empty() {
-                vec![None]
-            } else {
-                filter
-                    .authors
-                    .iter()
-                    .map(|author| Some(author.to_vec()))
-                    .collect()
-            };
-            let ref_public_keys = if filter.references_to_public_key.is_empty() {
-                vec![None]
-            } else {
-                filter
-                    .references_to_public_key
-                    .iter()
-                    .map(|public_key| Some((*public_key).to_vec()))
-                    .collect()
-            };
-            let ref_ids = if filter.references_to_event.is_empty() {
-                vec![None]
-            } else {
-                filter
-                    .references_to_event
-                    .iter()
-                    .map(|id| Some((*id).to_vec()))
-                    .collect()
-            };
-            let kind = if filter.kinds.is_empty() {
-                vec![None]
-            } else {
-                filter
-                    .kinds
-                    .iter()
-                    .map(|kind| Some((*kind).into()))
-                    .collect()
-            };
-
-            let ids = if filter.ids.is_empty() {
-                vec![None]
-            } else {
-                filter.ids.iter().map(|id| Some((*id).to_vec())).collect()
-            };
-
-            for author in authors.iter() {
-                for id in ids.iter() {
-                    for ref_public_key in ref_public_keys.iter() {
-                        for ref_id in ref_ids.iter() {
-                            for kind in kind.iter() {
-                                subs.push(Subscription {
-                                    id: id.clone(),
-                                    ref_public_key: ref_public_key.clone(),
-                                    author: author.clone(),
-                                    ref_id: ref_id.clone(),
-                                    kind: *kind,
-                                });
-                            }
-                        }
-                    }
-                }
-            }
-        });
-
-        subs
-    }
-
-    #[inline]
-    pub fn from_event(event: &Event) -> Vec<Subscription> {
-        let kind = event.kind().into();
-        let public_keys = vec![None, Some(event.author().as_ref().to_vec())];
-        let id = vec![None, Some(event.id.as_ref().to_vec())];
-        let kind = [None, Some(kind)];
-        let mut ref_public_keys = vec![None];
-        let mut ref_ids = vec![None];
-
-        event.tags().iter().for_each(|tag| match tag {
-            Tag::Event(x) => {
-                ref_ids.push(Some(x.id.as_ref().to_vec()));
-            }
-            Tag::PubKey(x) => {
-                ref_public_keys.push(Some(x.id.as_ref().to_vec()));
-            }
-            _ => {}
-        });
-
-        let mut subs = vec![];
-
-        for ref_public_key in ref_public_keys.iter() {
-            for ref_id in ref_ids.iter() {
-                for public_key in public_keys.iter() {
-                    for id in id.iter() {
-                        for kind in kind.iter() {
-                            subs.push(Subscription {
-                                ref_id: ref_id.clone(),
-                                ref_public_key: ref_public_key.clone(),
-                                author: public_key.clone(),
-                                id: id.clone(),
-                                kind: *kind,
-                            });
-                        }
-                    }
-                }
-            }
-        }
-
-        subs
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use super::*;
-    use nostr_rs_types::{types::Addr, Request};
-
-    #[test]
-    fn test_no_listen_to_all() {
-        let request: Request = serde_json::from_str("[\"REQ\",\"6440d5279e350\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681967101},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681967101},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid");
-        let mut subscriptions =
-            Subscription::from_filters(&request.as_request().expect("req").filters);
-        subscriptions.sort();
-
-        assert!(subscriptions
-            .binary_search(&Subscription::default())
-            .is_err())
-    }
-
-    #[test]
-    fn from_filters() {
-        let request: Request = serde_json::from_str("[\"REQ\",\"1298169700973717\",{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[1,3,6,7,9735],\"since\":1681939304},{\"#p\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"authors\":[\"39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb\"],\"kinds\":[4]},{\"#e\":[\"2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a\",\"a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1\"],\"kinds\":[1,6,7,9735]}]").expect("valid object");
-        let mut subscriptions =
-            Subscription::from_filters(&request.as_request().expect("req").filters);
-        subscriptions.sort();
-
-        assert!(subscriptions
-            .binary_search(&Subscription::default())
-            .is_err())
-    }
-
-    #[test]
-    fn from_event() {
-        let new_event: Request = serde_json::from_str(r#"["EVENT", {"kind":1,"content":"Pong","tags":[["e","9508850d7ddc8ef58c8b392236c49d472dc23fa11f4e73eb5475dfb099ddff42","","root"],["e","2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a","","reply"],["p","39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],["p","ee7202ad91459e013bfef263c59e47deb0163a5e7651b026673765488bfaf102"]],"created_at":1681938616,"pubkey":"a42007e33cfa25673b26f46f39df039aa6003258a68dc88f1f1e0447607aedb3","id":"e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9","sig":"9036150a6c8a32933cffcc42aec4d2109a22e9f10d1c3860c0435a925e6386babb7df5c95fcf68c8ed6a9726a1f07225af663d0b068eb555014130aad21674fc","meta":{"revision":0,"created":1681939266488,"version":0},"$loki":108}]"#).expect("value");
-        let mut subscriptions = Subscription::from_event(&new_event.as_event().expect("event"));
-        subscriptions.sort();
-
-        let pk: Addr = "39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"
-            .try_into()
-            .expect("id");
-
-        let id: Addr = "e862fe23daf52ab09b36a37fa91ca3743e0c323e630e8627891212ca147c2da9"
-            .try_into()
-            .expect("id");
-
-        let ref_id: Addr = "2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a"
-            .try_into()
-            .expect("id");
-
-        let author: Addr = "a42007e33cfa25673b26f46f39df039aa6003258a68dc88f1f1e0447607aedb3"
-            .try_into()
-            .expect("id");
-
-        let expected = vec![
-            Subscription {
-                ref_public_key: Some(pk.as_ref().to_vec()),
-                ..Subscription::default()
-            },
-            Subscription {
-                id: Some(id.as_ref().to_vec()),
-                ref_public_key: Some(pk.as_ref().to_vec()),
-                ..Subscription::default()
-            },
-            Subscription {
-                id: Some(id.as_ref().to_vec()),
-                ref_id: Some(ref_id.as_ref().to_vec()),
-                ref_public_key: Some(pk.as_ref().to_vec()),
-                ..Subscription::default()
-            },
-            Subscription {
-                author: Some(author.as_ref().to_vec()),
-                kind: Some(1),
-                ..Subscription::default()
-            },
-            Subscription {
-                author: Some(author.as_ref().to_vec()),
-                ..Subscription::default()
-            },
-            Subscription {
-                id: Some(id.as_ref().to_vec()),
-                ref_id: Some(ref_id.as_ref().to_vec()),
-                author: Some(author.as_ref().to_vec()),
-                ..Subscription::default()
-            },
-        ];
-
-        assert_eq!(subscriptions.len(), 72);
-        expected.iter().enumerate().for_each(|(i, sub)| {
-            assert!(
-                subscriptions.binary_search(sub).is_ok(),
-                "{} -> {:?}",
-                i,
-                sub
-            );
-        });
-        assert!(subscriptions
-            .binary_search(&Subscription::default())
-            .is_ok());
-    }
-}

+ 217 - 0
crates/relayer/src/subscription/filter.rs

@@ -0,0 +1,217 @@
+use nostr_rs_types::types::{Event, Filter, Kind, Tag};
+use std::collections::BTreeSet;
+
+/// The subscription keys are used to quickly identify the subscriptions
+/// by one or more fields. Think of it like a database index
+#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
+pub(crate) enum Key {
+    /// Key for the author field
+    Author(Vec<u8>, Option<Kind>),
+    /// Key for the reference to an event
+    RefId(Vec<u8>, Option<Kind>),
+    /// Key for the reference to a public key
+    RefPublicKey(Vec<u8>, Option<Kind>),
+    /// Key for the kind field
+    Id(Vec<u8>),
+    /// Key for the kind field
+    Kind(Kind),
+    /// Any value, for a catch all
+    AllUpdates,
+}
+
+type SortedSet<T> = Option<BTreeSet<T>>;
+
+/// Sorted filter
+///
+/// This is a pre-processed filter that is optimized for fast lookups.
+#[derive(Debug, Clone)]
+pub struct SortedFilter {
+    ids: SortedSet<Vec<u8>>,
+    authors: SortedSet<Vec<u8>>,
+    kinds: SortedSet<Kind>,
+    references_to_event: SortedSet<Vec<u8>>,
+    references_to_public_key: SortedSet<Vec<u8>>,
+}
+
+impl From<Filter> for SortedFilter {
+    /// Converts a Filter into a SortedFilter.
+    ///
+    /// SortedFilters have SortedSets that are optimized for fast lookups.
+    fn from(filter: Filter) -> Self {
+        Self {
+            ids: (!filter.ids.is_empty()).then(|| {
+                filter
+                    .ids
+                    .into_iter()
+                    .map(|x| x.as_ref().to_vec())
+                    .collect()
+            }),
+            authors: (!filter.authors.is_empty()).then(|| {
+                filter
+                    .authors
+                    .into_iter()
+                    .map(|x| x.as_ref().to_vec())
+                    .collect()
+            }),
+            kinds: (!filter.kinds.is_empty()).then(|| filter.kinds.into_iter().collect()),
+            references_to_event: (!filter.references_to_event.is_empty()).then(|| {
+                filter
+                    .references_to_event
+                    .into_iter()
+                    .map(|x| x.as_ref().to_vec())
+                    .collect()
+            }),
+            references_to_public_key: (!filter.references_to_public_key.is_empty()).then(|| {
+                filter
+                    .references_to_public_key
+                    .into_iter()
+                    .map(|x| x.as_ref().to_vec())
+                    .collect()
+            }),
+        }
+    }
+}
+
+impl SortedFilter {
+    /// Get the keys for the filter
+    ///
+    /// Get the keys or indexes for the filter. This is used to quickly identify potential matches
+    /// from a set of event listeners.
+    pub fn keys(&self) -> Vec<Key> {
+        let authors = self
+            .authors
+            .as_ref()
+            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+
+        let ids = self
+            .ids
+            .as_ref()
+            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+
+        let references_to_event = self
+            .references_to_event
+            .as_ref()
+            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+
+        let references_to_public_key = self
+            .references_to_public_key
+            .as_ref()
+            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+
+        let kinds = self
+            .kinds
+            .as_ref()
+            .map_or_else(|| vec![], |x| x.iter().map(|x| *x).collect());
+
+        let kind_option = if kinds.is_empty() {
+            vec![None]
+        } else {
+            kinds.clone().into_iter().map(Some).collect()
+        };
+
+        let keys = vec![
+            authors
+                .into_iter()
+                .map(|author| {
+                    kind_option
+                        .iter()
+                        .map(|kind| Key::Author(author.clone(), *kind))
+                        .collect::<Vec<_>>()
+                })
+                .collect::<Vec<_>>(),
+            references_to_event
+                .into_iter()
+                .map(|event| {
+                    kind_option
+                        .iter()
+                        .map(|kind| Key::RefId(event.clone(), *kind))
+                        .collect::<Vec<_>>()
+                })
+                .collect::<Vec<_>>(),
+            references_to_public_key
+                .into_iter()
+                .map(|pub_key| {
+                    kind_option
+                        .iter()
+                        .map(|kind| Key::RefPublicKey(pub_key.clone(), *kind))
+                        .collect::<Vec<_>>()
+                })
+                .collect::<Vec<_>>(),
+            vec![kinds
+                .into_iter()
+                .map(|kind| Key::Kind(kind))
+                .collect::<Vec<_>>()],
+            vec![ids.into_iter().map(|id| Key::Id(id)).collect::<Vec<_>>()],
+        ]
+        .concat()
+        .concat();
+
+        if keys.is_empty() {
+            vec![Key::AllUpdates]
+        } else {
+            keys
+        }
+    }
+
+    /// Checks if a given key exists in the sorted set, either as a whole or as a partial match
+    #[inline]
+    fn has_key_or_partial_match<T: AsRef<[u8]>>(id: T, ids: &SortedSet<Vec<u8>>) -> bool {
+        if let Some(ids) = ids.as_ref() {
+            let id = id.as_ref().to_vec();
+            if ids.contains(&id) {
+                return true;
+            }
+
+            for len in 4..=id.len() {
+                let prev_id = &id[..len].to_vec();
+                if ids.contains(prev_id) {
+                    return true;
+                }
+            }
+            false
+        } else {
+            true
+        }
+    }
+
+    /// Checks if any of key given set exists in the sorted set, either as a whole or as a partial match
+    #[inline]
+    fn has_any_key_or_partial_match<T: Iterator<Item = Vec<u8>>>(
+        id_set: T,
+        ids: &SortedSet<Vec<u8>>,
+    ) -> bool {
+        if ids.is_some() {
+            for id in id_set {
+                if Self::has_key_or_partial_match(id, ids) {
+                    return true;
+                }
+            }
+            false
+        } else {
+            true
+        }
+    }
+
+    /// Checks if the event matches the filter
+    pub fn check(&self, event: &Event) -> bool {
+        self.kinds
+            .as_ref()
+            .map_or(true, |kinds| kinds.contains(&event.kind()))
+            && Self::has_key_or_partial_match(&event.id, &self.ids)
+            && Self::has_key_or_partial_match(&event.author(), &self.authors)
+            && Self::has_any_key_or_partial_match(
+                event.tags().iter().filter_map(|f| match f {
+                    Tag::Event(x) => Some(x.id.as_ref().to_vec()),
+                    _ => None,
+                }),
+                &self.references_to_event,
+            )
+            && Self::has_any_key_or_partial_match(
+                event.tags().iter().filter_map(|f| match f {
+                    Tag::PubKey(x) => Some(x.id.as_ref().to_vec()),
+                    _ => None,
+                }),
+                &self.references_to_public_key,
+            )
+    }
+}

+ 226 - 0
crates/relayer/src/subscription/manager.rs

@@ -0,0 +1,226 @@
+use super::filter::{Key, SortedFilter};
+use crate::connection::ConnectionId;
+use nostr_rs_types::{
+    client::Subscribe,
+    types::{Event, SubscriptionId},
+    Response,
+};
+use std::{
+    collections::{BTreeMap, BTreeSet},
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+};
+use tokio::sync::{mpsc::Sender, RwLock};
+
+type SubIdx = (Key, ConnectionId, SubscriptionId);
+
+pub const MIN_PREFIX_MATCH_LEN: usize = 4;
+
+/// Subscription for a connection
+///
+/// This object is responsible for keeping track of a subscription for a connection
+///
+/// When dropped their listener will be removed from the subscription manager automatically
+#[derive(Clone, Debug)]
+pub struct ActiveSubscription {
+    conn_id: ConnectionId,
+    name: SubscriptionId,
+    keys: Vec<Key>,
+    manager: Arc<SubscriptionManager>,
+}
+
+impl ActiveSubscription {
+    fn new(
+        conn_id: ConnectionId,
+        name: SubscriptionId,
+        keys: Vec<Key>,
+        manager: Arc<SubscriptionManager>,
+    ) -> Self {
+        Self {
+            conn_id,
+            name,
+            keys,
+            manager,
+        }
+    }
+}
+
+impl Drop for ActiveSubscription {
+    /// When the subscription is dropped, it will remove the listener from the
+    /// subscription manager
+    fn drop(&mut self) {
+        let keys = self
+            .keys
+            .drain(..)
+            .map(|key| (key, self.conn_id, self.name.clone()))
+            .collect::<Vec<_>>();
+
+        let manager = self.manager.clone();
+
+        tokio::spawn(async move {
+            manager.unsubscribe(keys).await;
+        });
+    }
+}
+
+/// Subscription manager
+///
+/// This object is responsible for letting clients and processes subscribe to
+/// events,
+#[derive(Debug)]
+pub struct SubscriptionManager {
+    /// List of subscriptions with their filters and their index.
+    ///
+    /// A single request may be converted to multiple subscriptions entry as
+    /// they are sorted by their index/key
+    subscriptions: RwLock<BTreeMap<SubIdx, (Arc<SortedFilter>, Sender<Response>)>>,
+    /// Total number of subscribers
+    /// A single REQ may have multiple subscriptions
+    total_subscribers: AtomicUsize,
+    /// Minimum prefix match length
+    min_prefix_match_len: usize,
+}
+
+impl Default for SubscriptionManager {
+    fn default() -> Self {
+        Self {
+            subscriptions: Default::default(),
+            total_subscribers: Default::default(),
+            min_prefix_match_len: MIN_PREFIX_MATCH_LEN,
+        }
+    }
+}
+
+impl SubscriptionManager {
+    async fn unsubscribe(self: Arc<Self>, keys: Vec<SubIdx>) {
+        println!("block");
+        let mut subscriptions = self.subscriptions.write().await;
+        println!("\tblocked");
+        for sub in keys {
+            subscriptions.remove(&sub);
+        }
+        self.total_subscribers.fetch_sub(1, Ordering::Relaxed);
+        println!("released");
+    }
+
+    fn get_keys_from_event(event: &Event, min_prefix_match_len: usize) -> Vec<Key> {
+        let mut subscriptions = vec![];
+
+        let author = event.author().as_ref().to_vec();
+        let id = event.id.as_ref().to_vec();
+
+        let len = author.len();
+        for i in min_prefix_match_len..len - min_prefix_match_len {
+            subscriptions.push(Key::Author(author[..len - i].to_vec(), None));
+            subscriptions.push(Key::Author(author[..len - i].to_vec(), Some(event.kind())));
+        }
+
+        for t in event.tags() {
+            match t {
+                nostr_rs_types::types::Tag::Event(ref_event) => {
+                    let len = ref_event.id.len();
+                    for i in min_prefix_match_len..ref_event.id.len() - min_prefix_match_len {
+                        subscriptions.push(Key::RefId(ref_event.id[..len - i].to_vec(), None));
+                        subscriptions.push(Key::RefId(
+                            ref_event.id[..len - i].to_vec(),
+                            Some(event.kind()),
+                        ));
+                    }
+                }
+                nostr_rs_types::types::Tag::PubKey(ref_pub_key) => {
+                    let len = ref_pub_key.id.len();
+                    for i in min_prefix_match_len..len - min_prefix_match_len {
+                        subscriptions.push(Key::RefId(ref_pub_key.id[..len - i].to_vec(), None));
+                        subscriptions.push(Key::RefId(
+                            ref_pub_key.id[..len - i].to_vec(),
+                            Some(event.kind()),
+                        ));
+                    }
+                }
+                _ => {}
+            }
+        }
+
+        let len = id.len();
+        for i in min_prefix_match_len..len - min_prefix_match_len {
+            subscriptions.push(Key::Id(id[..len - i].to_vec()));
+        }
+
+        subscriptions.push(Key::Kind(event.kind().into()));
+        subscriptions.push(Key::AllUpdates);
+
+        subscriptions
+    }
+
+    /// Get the number of subscribers
+    pub fn total_subscribers(self: &Arc<Self>) -> usize {
+        self.total_subscribers.load(Ordering::Relaxed)
+    }
+
+    /// Subscribe to a future events
+    ///
+    /// This will add a new subscription to the subscription manager with a
+    /// given conn_id, sender and a vector of filters.
+    pub async fn subscribe(
+        self: &Arc<Self>,
+        conn_id: ConnectionId,
+        sender: Sender<Response>,
+        request: Subscribe,
+    ) -> Vec<ActiveSubscription> {
+        let name = request.subscription_id;
+        let mut subscriptions = self.subscriptions.write().await;
+        let subscriptions = request
+            .filters
+            .into_iter()
+            .map(|filter| {
+                let filter = Arc::new(SortedFilter::from(filter));
+                let subscription =
+                    ActiveSubscription::new(conn_id, name.clone(), filter.keys(), self.clone());
+                for key in subscription.keys.iter() {
+                    subscriptions.insert(
+                        (key.clone(), conn_id, name.clone()),
+                        (filter.clone(), sender.clone()),
+                    );
+                }
+                subscription
+            })
+            .collect::<Vec<_>>();
+        self.total_subscribers
+            .fetch_add(subscriptions.len(), Ordering::Relaxed);
+        subscriptions
+    }
+
+    /// Publish an event to all subscribers
+    pub fn broadcast(self: &Arc<Self>, event: Event) {
+        let this = self.clone();
+        tokio::spawn(async move {
+            let subscriptions = this.subscriptions.read().await;
+            let subs = Self::get_keys_from_event(&event, this.min_prefix_match_len);
+            let mut deliverded = BTreeSet::new();
+
+            for sub in subs {
+                for ((sub_type, client, name), (filter, sender)) in subscriptions.range(
+                    &(
+                        sub.clone(),
+                        ConnectionId::new_empty(),
+                        SubscriptionId::empty(),
+                    )..,
+                ) {
+                    if sub_type != &sub {
+                        break;
+                    }
+
+                    if deliverded.contains(client) || !filter.check(&event) {
+                        continue;
+                    }
+                    println!("send");
+
+                    let _ = sender.try_send(Response::Event((name, &event).into()));
+                    deliverded.insert(client.clone());
+                }
+            }
+        });
+    }
+}

+ 4 - 0
crates/relayer/src/subscription/mod.rs

@@ -0,0 +1,4 @@
+mod filter;
+mod manager;
+
+pub use self::manager::{ActiveSubscription, SubscriptionManager};

+ 1 - 5
crates/storage/base/src/lib.rs

@@ -7,7 +7,6 @@
 pub mod cursor;
 mod error;
 mod event_filter;
-mod notification;
 mod secondary_index;
 mod storage;
 
@@ -17,10 +16,7 @@ pub mod test;
 #[cfg(feature = "test")]
 pub use tokio;
 
-pub use crate::{
-    error::Error, event_filter::*, notification::Subscription, secondary_index::SecondaryIndex,
-    storage::Storage,
-};
+pub use crate::{error::Error, event_filter::*, secondary_index::SecondaryIndex, storage::Storage};
 
 #[macro_export]
 /// This macro creates the

+ 0 - 120
crates/storage/base/src/notification.rs

@@ -1,120 +0,0 @@
-use crate::{Error, Storage};
-use futures::Stream;
-use nostr_rs_types::types::{Addr, Event, Filter, Kind};
-use std::{
-    collections::HashMap,
-    pin::Pin,
-    sync::atomic::AtomicUsize,
-    task::{Context, Poll},
-};
-use tokio::sync::{mpsc::Sender, RwLock};
-
-#[allow(dead_code)]
-struct SubscriptionEntry {
-    pub id: usize,
-    pub filter: Filter,
-    pub sender: Sender<(usize, Event)>,
-}
-
-#[allow(dead_code)]
-enum SubscriptionListenerType {
-    Id(Addr),
-    Author(Addr),
-    Kind(Kind),
-    ReferenceToEvent(Addr),
-    ReferenceToPublicKey(Addr),
-}
-
-/// Subscription
-pub struct Subscription<T>
-where
-    T: Storage,
-{
-    db: T,
-    subscriptions: RwLock<HashMap<usize, SubscriptionEntry>>,
-    subscription_listener: RwLock<HashMap<SubscriptionListenerType, Vec<usize>>>,
-    last_id: AtomicUsize,
-}
-
-pub struct SubscriptionResultFromDb<I>
-where
-    I: Stream<Item = Result<Event, Error>>,
-{
-    iterator: I,
-}
-
-impl<I> Stream for SubscriptionResultFromDb<I>
-where
-    I: Stream<Item = Result<Event, Error>>,
-{
-    type Item = Result<Event, Error>;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        // Safety: it's safe to use Pin::map_unchecked_mut because the iterator field
-        // is pinned as part of the SubscriptionResultFromDb struct
-        let iterator = unsafe { self.map_unchecked_mut(|s| &mut s.iterator) };
-        iterator.poll_next(cx)
-    }
-
-    fn size_hint(&self) -> (usize, Option<usize>) {
-        self.iterator.size_hint()
-    }
-}
-
-impl<T> Subscription<T>
-where
-    T: Storage,
-{
-    /// Wraps the storage layer to by pass the subscription/notification wrapper
-    pub fn new(db: T) -> Self {
-        Self {
-            db,
-            subscriptions: RwLock::new(HashMap::new()),
-            subscription_listener: RwLock::new(HashMap::new()),
-            last_id: AtomicUsize::new(0),
-        }
-    }
-
-    /// Gets an event from the wrapped storage
-    pub async fn get_event<T1: AsRef<[u8]> + Send + Sync>(
-        &self,
-        id: T1,
-    ) -> Result<Option<Event>, Error> {
-        self.db.get_event(id).await
-    }
-
-    /// Removes a subscription from the listener
-    pub async fn unsubscribe(self, subscription_id: usize) -> Result<(), Error> {
-        let mut subscribers = self.subscriptions.write().await;
-        let _ = subscribers.remove(&subscription_id);
-        Ok(())
-    }
-
-    /// Subscribes to a filter. The first streamed bytes will be reads from the
-    /// database.
-    pub async fn subscribe(
-        &self,
-        filter: Filter,
-        sender: Sender<(usize, Event)>,
-    ) -> Result<(usize, SubscriptionResultFromDb<T::Cursor<'_>>), Error> {
-        let mut subscribers = self.subscriptions.write().await;
-        let mut _subscription_listener = self.subscription_listener.write().await;
-        let id = self
-            .last_id
-            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
-        subscribers.insert(
-            id,
-            SubscriptionEntry {
-                id,
-                filter: filter.clone(),
-                sender,
-            },
-        );
-        Ok((
-            id,
-            SubscriptionResultFromDb {
-                iterator: self.db.get_by_filter(filter).await?,
-            },
-        ))
-    }
-}

+ 1 - 1
crates/storage/base/src/storage.rs

@@ -6,7 +6,7 @@ use nostr_rs_types::types::{Event, Filter};
 #[async_trait::async_trait]
 pub trait Storage: Send + Sync {
     /// Result iterators
-    type Cursor<'a>: Stream<Item = Result<Event, Error>> + Unpin
+    type Cursor<'a>: Stream<Item = Result<Event, Error>> + Unpin + Send
     where
         Self: 'a;
 

+ 1 - 1
crates/storage/base/src/test.rs

@@ -24,7 +24,7 @@ where
     }
 
     while db.is_flushing() {
-        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+        tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
     }
 }
 

+ 9 - 6
crates/storage/rocksdb/src/cursor.rs

@@ -1,16 +1,15 @@
 //! Rocks DB implementation of the storage layer
-use crate::RocksDb;
+use crate::{ReferenceType, RocksDb};
 use futures::Stream;
 use nostr_rs_storage_base::{
     cursor::{check_future_call, FutureResult, FutureValue},
     Error, EventFilter, Storage,
 };
 use nostr_rs_types::types::Event;
-use rocksdb::{BoundColumnFamily, DBIteratorWithThreadMode, DB};
+use rocksdb::{DBIteratorWithThreadMode, DB};
 use std::{
     collections::VecDeque,
     pin::Pin,
-    sync::Arc,
     task::{Context, Poll},
 };
 
@@ -29,7 +28,7 @@ pub struct Cursor<'a> {
     /// Reference to the namespace to use to query the secondary index. If none
     /// is given the secondary_index_iterator must be constructed outside this
     /// wrapper.
-    index: Option<Arc<BoundColumnFamily<'a>>>,
+    index: Option<ReferenceType>,
     /// The current secondary index iterator. If none is given the iterator will
     /// try to create one using the namespace property and the first prefix from
     /// prefixes (it will also be copied to current_prefix)
@@ -49,7 +48,7 @@ pub struct Cursor<'a> {
 impl<'a> Cursor<'a> {
     pub fn new(
         db: &'a RocksDb,
-        index: Option<Arc<BoundColumnFamily<'a>>>,
+        index: Option<ReferenceType>,
         prefixes: Vec<Vec<u8>>,
         filter: Option<EventFilter>,
         secondary_index_iterator: Option<DBIteratorWithThreadMode<'a, DB>>,
@@ -73,10 +72,14 @@ impl<'a> Cursor<'a> {
     fn select_next_prefix_using_secondary_index(&mut self) -> Option<()> {
         self.index_iterator = None;
         let prefix = self.index_keys.pop_front()?;
+        let index = self
+            .index
+            .map(|index| self.db.reference_to_cf_handle(index).ok())?;
+
         self.index_iterator = Some(
             self.db
                 .db
-                .prefix_iterator_cf(self.index.as_ref()?, prefix.clone()),
+                .prefix_iterator_cf(index.as_ref()?, prefix.clone()),
         );
         self.current_index_key = prefix;
         Some(())

+ 6 - 11
crates/storage/rocksdb/src/lib.rs

@@ -11,7 +11,8 @@ use std::{collections::VecDeque, ops::Deref, path::Path, sync::Arc};
 mod cursor;
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
-enum ReferenceType {
+/// Internal index name
+pub enum ReferenceType {
     Events,
     Author,
     RefPublicKey,
@@ -223,21 +224,17 @@ impl Storage for RocksDb {
         };
 
         let (index, secondary_index_iterator, prefixes) = if !query.references_to_event.is_empty() {
-            let ns: Arc<BoundColumnFamily<'_>> =
-                self.reference_to_cf_handle(ReferenceType::RefEvent)?;
-
             let keys = std::mem::take(&mut query.references_to_event)
                 .into_iter()
                 .map(|c| c.take())
                 .collect();
-            (Some(ns), None, keys)
+            (Some(ReferenceType::RefEvent), None, keys)
         } else if !query.references_to_public_key.is_empty() {
-            let ns = self.reference_to_cf_handle(ReferenceType::RefEvent)?;
             let keys = std::mem::take(&mut query.references_to_public_key)
                 .into_iter()
                 .map(|c| c.take())
                 .collect();
-            (Some(ns), None, keys)
+            (Some(ReferenceType::RefPublicKey), None, keys)
         } else if !query.ids.is_empty() {
             let keys = std::mem::take(&mut query.ids)
                 .into_iter()
@@ -245,14 +242,12 @@ impl Storage for RocksDb {
                 .collect();
             (None, None, keys)
         } else if !query.authors.is_empty() {
-            let ns = self.reference_to_cf_handle(ReferenceType::Author)?;
             let keys = std::mem::take(&mut query.authors)
                 .into_iter()
                 .map(|c| c.take())
                 .collect();
-            (Some(ns), None, keys)
+            (Some(ReferenceType::Author), None, keys)
         } else if !query.kinds.is_empty() {
-            let ns = self.reference_to_cf_handle(ReferenceType::Kind)?;
             let keys = std::mem::take(&mut query.kinds)
                 .into_iter()
                 .map(|kind| {
@@ -260,7 +255,7 @@ impl Storage for RocksDb {
                     kind.to_be_bytes().to_vec()
                 })
                 .collect();
-            (Some(ns), None, keys)
+            (Some(ReferenceType::Kind), None, keys)
         } else {
             let cf_handle = self.reference_to_cf_handle(ReferenceType::Stream)?;
             (

+ 9 - 0
crates/types/src/relayer/event.rs

@@ -15,6 +15,15 @@ pub struct Event {
     pub event: types::Event,
 }
 
+impl From<(&SubscriptionId, &types::Event)> for Event {
+    fn from((subscription_id, event): (&SubscriptionId, &types::Event)) -> Self {
+        Self {
+            subscription_id: subscription_id.clone(),
+            event: event.clone(),
+        }
+    }
+}
+
 impl SerializeDeserialize for Event {
     fn get_tag() -> &'static str {
         "EVENT"

+ 2 - 2
crates/types/src/types/addr.rs

@@ -39,7 +39,7 @@ pub enum Error {
 /// Human Readable Part
 ///
 /// Which HDR has been used to encode this address with Bech32
-#[derive(Debug, Clone, PartialEq, Eq, Copy)]
+#[derive(Debug, Clone, Ord, PartialOrd, PartialEq, Eq, Copy)]
 pub enum HumanReadablePart {
     /// Public Key / Account Address
     NPub,
@@ -67,7 +67,7 @@ impl Display for HumanReadablePart {
 ///
 /// Clients may want to use the Bech32 encoded address *but* the protocol only
 /// cares about hex-encoded binary data.
-#[derive(Debug, Default, Clone, Eq)]
+#[derive(Debug, Default, Ord, PartialOrd, Clone, Eq)]
 pub struct Addr {
     /// Bytes (up to 32 bytes)
     pub bytes: Vec<u8>,

+ 1 - 1
crates/types/src/types/kind.rs

@@ -13,7 +13,7 @@ use serde::{
 /// Any unsupported Kind will be wrapped under the Unknown type
 ///
 /// The Kind is represented as a u32 on the wire
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 pub enum Kind {
     /// Metadata
     ///

+ 8 - 1
crates/types/src/types/subscription_id.rs

@@ -23,9 +23,16 @@ pub enum Error {
 /// The rules are simple, any UTF-8 valid string with fewer than 32 characters
 ///
 /// By default a random ID will be created if needed.
-#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+#[derive(Debug, Ord, PartialOrd, Clone, Hash, PartialEq, Eq)]
 pub struct SubscriptionId(String);
 
+impl SubscriptionId {
+    /// Create a new subscription ID
+    pub fn empty() -> Self {
+        Self("".to_owned())
+    }
+}
+
 impl Deref for SubscriptionId {
     type Target = str;
 

+ 13 - 8
src/main.rs

@@ -1,13 +1,10 @@
 use futures::Future;
 use nostr_rs_client::Pool;
+use nostr_rs_relayer::Relayer;
 use nostr_rs_rocksdb::RocksDb;
-use nostr_rs_types::{relayer, types::Filter, Request, Response};
-use std::{collections::HashMap, env, fs, pin::Pin, sync::Arc};
-use tokio::{
-    net::TcpListener,
-    sync::mpsc,
-    time::{sleep, Duration},
-};
+use nostr_rs_types::{types::Filter, Request, Response};
+use std::{env, fs, pin::Pin, sync::Arc};
+use tokio::{net::TcpListener, sync::mpsc};
 use url::Url;
 
 mod config;
@@ -86,6 +83,15 @@ async fn main() {
     ];
 
     let _ = client_pool.subscribe(initial_subscription.into()).await;
+
+    let relayer = Relayer::new(Some(db), Some(client_pool)).expect("relayer");
+
+    let addr = "127.0.0.1:3000";
+    let listener: TcpListener = TcpListener::bind(&addr).await.unwrap();
+
+    let _ = relayer.main(listener).expect("valid main").await;
+
+    /*
     loop {
         tokio::select! {
             Some((event, url)) = client_pool.recv() => {
@@ -94,7 +100,6 @@ async fn main() {
         }
     }
 
-    /*
     let db = RocksDb::new("./relayer-db").expect("db");
     let (relayer, mut server_receiver) = nostr_rs_relayer::Relayer::new(Some(db));
     let mut clients = nostr_rs_client::Pool::new()