Quellcode durchsuchen

Working on a better subscription

Cesar Rodas vor 3 Monaten
Ursprung
Commit
8a7d7009f4

+ 2 - 0
Cargo.lock

@@ -947,6 +947,7 @@ dependencies = [
  "nostr-rs-storage-base",
  "nostr-rs-types",
  "thiserror",
+ "tokio",
  "url",
 ]
 
@@ -954,6 +955,7 @@ dependencies = [
 name = "nostr-rs-relayer"
 version = "0.1.0"
 dependencies = [
+ "futures",
  "futures-util",
  "log",
  "nostr-rs-client",

+ 1 - 0
crates/personal-relayer/Cargo.toml

@@ -11,3 +11,4 @@ nostr-rs-relayer = { path = "../relayer" }
 thiserror = "1.0.39"
 url = { version = "2.5.2", features = ["serde"] }
 futures = "0.3.30"
+tokio = { version = "1.39.2", features = ["full"] }

+ 34 - 5
crates/personal-relayer/src/lib.rs

@@ -3,17 +3,39 @@ use nostr_rs_client::Pool;
 use nostr_rs_relayer::Relayer;
 use nostr_rs_storage_base::Storage;
 use nostr_rs_types::types::{Addr, Filter};
+use tokio::{net::TcpListener, task::JoinHandle};
 use url::Url;
 
-pub struct PersonalRelayer<T: Storage + Send + Sync + 'static> {
-    relayer: Relayer<T>,
-    accounts: Vec<Addr>,
+pub struct Stoppable(Option<Vec<JoinHandle<()>>>);
+
+impl From<Vec<JoinHandle<()>>> for Stoppable {
+    fn from(value: Vec<JoinHandle<()>>) -> Self {
+        Self(Some(value))
+    }
+}
+
+impl Drop for Stoppable {
+    fn drop(&mut self) {
+        if let Some(tasks) = self.0.take() {
+            for join_handle in tasks.into_iter() {
+                join_handle.abort();
+            }
+        }
+    }
 }
 
 #[derive(thiserror::Error, Debug)]
 pub enum Error {
     #[error("Relayer: {0}")]
     Relayer(#[from] nostr_rs_relayer::Error),
+
+    #[error("Client error: {0}")]
+    Client(#[from] nostr_rs_client::Error),
+}
+
+pub struct PersonalRelayer<T: Storage + Send + Sync + 'static> {
+    relayer: Relayer<T>,
+    accounts: Vec<Addr>,
 }
 
 impl<T: Storage + Send + Sync + 'static> PersonalRelayer<T> {
@@ -24,7 +46,7 @@ impl<T: Storage + Send + Sync + 'static> PersonalRelayer<T> {
     ) -> Result<Self, Error> {
         let pool = Pool::new_with_clients(client_urls);
 
-        let subscriptions = join_all(
+        join_all(
             accounts
                 .iter()
                 .map(|account| {
@@ -38,11 +60,18 @@ impl<T: Storage + Send + Sync + 'static> PersonalRelayer<T> {
                 })
                 .collect::<Vec<_>>(),
         )
-        .await;
+        .await
+        .into_iter()
+        .collect::<Result<Vec<_>, _>>()?;
 
         Ok(Self {
             relayer: Relayer::new(Some(storage), Some(pool))?,
             accounts,
         })
     }
+
+    pub fn main(self, server: TcpListener) -> Result<Stoppable, Error> {
+        let tasks = vec![self.relayer.main(server)?, tokio::spawn(async move {})];
+        Ok(tasks.into())
+    }
 }

+ 0 - 3
crates/personal-relayer/src/main.rs

@@ -1,3 +0,0 @@
-fn main() {
-    println!("Hello, world!");
-}

+ 1 - 0
crates/relayer/Cargo.toml

@@ -20,6 +20,7 @@ thiserror = "1.0.39"
 serde_json = "1.0.94"
 rand = "0.8.5"
 log = "0.4.17"
+futures = "0.3.30"
 
 [dev-dependencies]
 nostr-rs-memory = { path = "../storage/memory" }

+ 21 - 6
crates/relayer/src/connection.rs

@@ -5,7 +5,10 @@ use nostr_rs_types::{
     types::Addr,
     Request, Response,
 };
-use std::collections::HashMap;
+use std::{
+    collections::HashMap,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use tokio::{
     net::TcpStream,
     sync::{
@@ -17,6 +20,16 @@ use tokio::{
 #[allow(unused_imports)]
 use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
+pub struct ConnectionId(pub(crate) usize);
+
+impl Default for ConnectionId {
+    fn default() -> Self {
+        static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
+        Self(NEXT_ID.fetch_add(1, Ordering::SeqCst))
+    }
+}
+
 #[derive(Debug)]
 /// Relayer connection
 ///
@@ -24,9 +37,9 @@ use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
 /// upcoming messages form the client
 pub struct Connection {
     #[allow(unused)]
-    pub(crate) conn_id: u128,
+    pub(crate) conn_id: ConnectionId,
     sender: Sender<Response>,
-    subscriptions: RwLock<HashMap<String, u128>>,
+    subscriptions: RwLock<HashMap<String, ConnectionId>>,
     handler: Option<JoinHandle<()>>,
 }
 
@@ -46,7 +59,7 @@ impl Connection {
         let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
         (
             Self {
-                conn_id: 0,
+                conn_id: ConnectionId::default(),
                 sender,
                 subscriptions: RwLock::new(HashMap::new()),
                 handler: None,
@@ -156,9 +169,11 @@ impl Connection {
     }
 
     /// Create a subscription for this connection
-    pub async fn create_subscription(&self, id: String) -> (u128, Sender<Response>) {
+    pub async fn create_subscription(&self, id: String) -> (ConnectionId, Sender<Response>) {
         let mut subscriptions = self.subscriptions.write().await;
-        let internal_id = subscriptions.entry(id).or_insert_with(get_id);
+        let internal_id = subscriptions
+            .entry(id)
+            .or_insert_with(ConnectionId::default);
         (*internal_id, self.sender.clone())
     }
 }

+ 5 - 9
crates/relayer/src/relayer.rs

@@ -53,7 +53,7 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
     /// Client pool
     ///
     /// A relayer can optionally be connected to a pool of clients to get foreign events.
-    client_pool: Option<(RwLock<Pool>, JoinHandle<()>)>,
+    client_pool: Option<(Pool, JoinHandle<()>)>,
 }
 
 impl<T: Storage + Send + Sync + 'static> Drop for Relayer<T> {
@@ -143,7 +143,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
     fn handle_client_pool(
         client_pool: Pool,
         sender: Sender<(u128, Request)>,
-    ) -> Result<(RwLock<Pool>, JoinHandle<()>), ClientError> {
+    ) -> Result<(Pool, JoinHandle<()>), ClientError> {
         let (mut receiver, client_pool) = client_pool.split()?;
 
         let handle = tokio::spawn(async move {
@@ -161,7 +161,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
             }
         });
 
-        Ok((RwLock::new(client_pool), handle))
+        Ok((client_pool, handle))
     }
 
     /// Returns a reference to the internal database
@@ -203,18 +203,14 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                 if let Some((client_pool, _)) = self.client_pool.as_ref() {
                     // pass the event to the pool of clients, so this relayer can relay
                     // their local events to the clients in the network of relayers
-                    let _ = client_pool.write().await.post(event.clone().into()).await;
+                    let _ = client_pool.post(event.clone().into()).await;
                 }
             }
             Request::Request(request) => {
                 if let Some((client_pool, _)) = self.client_pool.as_ref() {
                     // pass the subscription request to the pool of clients, so this relayer
                     // can relay any unknown event to the clients through their subscriptions
-                    let _ = client_pool
-                        .write()
-                        .await
-                        .subscribe(request.filters.clone().into())
-                        .await;
+                    let _ = client_pool.subscribe(request.filters.clone().into()).await;
                 }
 
                 // Create subscription

+ 100 - 0
crates/relayer/src/subscription/manager.rs

@@ -0,0 +1,100 @@
+use crate::{connection::ConnectionId, subscription::Key};
+use futures::executor::block_on;
+use nostr_rs_types::{
+    relayer,
+    types::{Event, SubscriptionId},
+    Response,
+};
+use std::{collections::BTreeMap, sync::Arc};
+use tokio::sync::{mpsc::Sender, RwLock};
+
+type SubIdx = (Key, ConnectionId, SubscriptionId);
+
+/// Subscription for a connection
+///
+/// This object is responsible for keeping track of a subscription for a connection
+///
+/// When dropped their listener will be removed from the subscription manager automatically
+pub struct SubscriptionForConnection {
+    conn_id: ConnectionId,
+    name: SubscriptionId,
+    keys: Vec<Key>,
+    manager: Arc<SubscriptionManager>,
+}
+
+impl Drop for SubscriptionForConnection {
+    fn drop(&mut self) {
+        let keys = self
+            .keys
+            .drain(..)
+            .map(|key| (key, self.conn_id, self.name.clone()))
+            .collect::<Vec<_>>();
+
+        block_on(async {
+            self.manager.clone().unsubscribe(keys).await;
+        });
+    }
+}
+
+/// Subscription manager
+///
+/// This object is responsible for letting clients and processes subscribe to
+/// events,
+#[derive(Default)]
+pub struct SubscriptionManager {
+    subscriptions: Arc<RwLock<BTreeMap<SubIdx, Sender<Response>>>>,
+}
+
+impl SubscriptionManager {
+    pub async fn unsubscribe(&self, keys: Vec<SubIdx>) {
+        let mut subscriptions = self.subscriptions.write().await;
+        for sub in keys {
+            subscriptions.remove(&sub);
+        }
+    }
+
+    fn get_keys_from_event(event: &Event) -> Vec<Key> {
+        let mut subscriptions = vec![Key::Kind(event.kind().into())];
+
+        let author = event.author().as_ref().to_vec();
+        let id = event.id.as_ref().to_vec();
+
+        for i in 4..author.len() {
+            subscriptions.push(Key::Author(author[..author.len() - i].to_vec()));
+        }
+        for i in 4..id.len() {
+            subscriptions.push(Key::Id(author[..author.len() - i].to_vec()));
+        }
+
+        subscriptions
+    }
+
+    pub fn publish(&self, event: Event) {
+        let subscriptions = self.subscriptions.clone();
+        tokio::spawn(async move {
+            let subscriptions = subscriptions.read().await;
+            let subs = Self::get_keys_from_event(&event);
+
+            let empty_sub_id = SubscriptionId::default();
+
+            for sub in subs {
+                for ((sub_type, _, sub_id), sender) in
+                    subscriptions.range(&(sub.clone(), ConnectionId(0), empty_sub_id.clone())..)
+                {
+                    if sub_type != &sub {
+                        break;
+                    }
+
+                    let _ = sender
+                        .send(Response::Event(relayer::Event {
+                            subscription_id: sub_id.clone(),
+                            event: event.clone(),
+                        }))
+                        .await;
+                }
+            }
+
+            todo!()
+        });
+    }
+}

+ 13 - 0
crates/relayer/src/subscription.rs → crates/relayer/src/subscription/mod.rs

@@ -1,5 +1,18 @@
 use nostr_rs_types::types::{Event, Filter, Tag};
 
+mod manager;
+
+/// The subscription keys are used to quickly identify the subscriptions
+/// by one or more fields. Think of it like a database index
+#[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
+pub enum Key {
+    Author(Vec<u8>),
+    RefPublicKey(Vec<u8>),
+    RefId(Vec<u8>),
+    Id(Vec<u8>),
+    Kind(u32),
+}
+
 #[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash)]
 /// Client subscription
 pub struct Subscription {

+ 1 - 1
crates/types/src/types/subscription_id.rs

@@ -23,7 +23,7 @@ pub enum Error {
 /// The rules are simple, any UTF-8 valid string with fewer than 32 characters
 ///
 /// By default a random ID will be created if needed.
-#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+#[derive(Debug, Ord, PartialOrd, Clone, Hash, PartialEq, Eq)]
 pub struct SubscriptionId(String);
 
 impl Deref for SubscriptionId {