Prechádzať zdrojové kódy

Merge branch 'subscription-bug' of cesar/nostr-prototype into main

Cesar Rodas 2 mesiacov pred
rodič
commit
4c7f472ea6

+ 1 - 1
crates/client/src/client.rs

@@ -155,7 +155,7 @@ impl Client {
                     tokio::select! {
                         Some(msg) = send_to_socket.recv() => {
                             if let Request::Request(sub) = &msg {
-                                if subscriptions.get(&sub.subscription_id).is_some() {
+                                if subscriptions.contains_key(&sub.subscription_id) {
                                     log::warn!("{}: Already subscribed to {}", url, sub.subscription_id);
                                     continue;
                                 }

+ 50 - 0
crates/relayer/src/connection/local.rs

@@ -0,0 +1,50 @@
+//! Local connection
+//!
+//! Add types for adding local connections
+use crate::{connection::ConnectionId, Error};
+use nostr_rs_types::{Request, Response};
+use tokio::sync::mpsc::{Receiver, Sender};
+
+/// Local connection
+pub struct LocalConnection {
+    sender: Sender<(ConnectionId, Request)>,
+    receiver: Receiver<Response>,
+    conn_id: ConnectionId,
+}
+
+impl LocalConnection {
+    /// Receive a message from the relayer
+    pub async fn recv(&mut self) -> Option<Response> {
+        self.receiver.recv().await
+    }
+
+    /// Sends a request to the relayer
+    pub async fn send(&self, request: Request) -> Result<(), Error> {
+        self.sender
+            .send((self.conn_id, request))
+            .await
+            .map_err(|e| Error::LocalSendError(Box::new(e)))
+    }
+}
+
+impl
+    From<(
+        ConnectionId,
+        Receiver<Response>,
+        Sender<(ConnectionId, Request)>,
+    )> for LocalConnection
+{
+    fn from(
+        value: (
+            ConnectionId,
+            Receiver<Response>,
+            Sender<(ConnectionId, Request)>,
+        ),
+    ) -> Self {
+        LocalConnection {
+            conn_id: value.0,
+            receiver: value.1,
+            sender: value.2,
+        }
+    }
+}

+ 13 - 6
crates/relayer/src/connection.rs → crates/relayer/src/connection/mod.rs

@@ -24,6 +24,10 @@ use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
 pub struct ConnectionId(usize);
 
+mod local;
+
+pub use local::LocalConnection;
+
 impl Default for ConnectionId {
     fn default() -> Self {
         static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
@@ -47,6 +51,8 @@ impl ConnectionId {
     }
 }
 
+type CompoundSubcription = (Option<PoolSubscription>, Vec<ActiveSubscription>);
+
 #[derive(Debug)]
 /// Relayer connection
 ///
@@ -55,8 +61,7 @@ impl ConnectionId {
 pub struct Connection {
     conn_id: ConnectionId,
     sender: Sender<Response>,
-    subscriptions:
-        RwLock<HashMap<SubscriptionId, (Option<PoolSubscription>, Vec<ActiveSubscription>)>>,
+    subscriptions: RwLock<HashMap<SubscriptionId, CompoundSubcription>>,
     handler: Option<JoinHandle<()>>,
 }
 
@@ -65,14 +70,14 @@ const MAX_SUBSCRIPTIONS_BUFFER: usize = 100;
 impl Drop for Connection {
     fn drop(&mut self) {
         if let Some(handler) = self.handler.take() {
-            let _ = handler.abort();
+            handler.abort();
         }
     }
 }
 
 impl Connection {
-    #[cfg(test)]
-    pub fn new_for_test() -> (Self, Receiver<Response>) {
+    /// Create a new local connection
+    pub fn new_local_connection() -> (Self, Receiver<Response>) {
         let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
         (
             Self {
@@ -136,7 +141,9 @@ impl Connection {
                             let msg: Result<Request, _> = serde_json::from_str(&msg);
                             match msg {
                                 Ok(msg) => {
-                                    let _ = send_message_to_relayer.send((conn_id, msg)).await;
+                                    if let Err(err) = send_message_to_relayer.send((conn_id, msg)).await {
+                                        log::error!("Error sending message to relayer: {}", err);
+                                    }
                                 },
                                 Err(err) => {
                                     log::error!("Error parsing message from client: {}", err);

+ 6 - 1
crates/relayer/src/error.rs

@@ -1,4 +1,5 @@
-use nostr_rs_types::Response;
+use crate::connection::ConnectionId;
+use nostr_rs_types::{Request, Response};
 
 #[derive(Debug, thiserror::Error)]
 /// Relayer error
@@ -19,6 +20,10 @@ pub enum Error {
     #[error("TrySendError: {0}")]
     TrySendError(#[from] Box<tokio::sync::mpsc::error::TrySendError<Response>>),
 
+    /// Tokio channel's error
+    #[error("LocalTrySendError: {0}")]
+    LocalSendError(#[from] Box<tokio::sync::mpsc::error::SendError<(ConnectionId, Request)>>),
+
     /// Client related errors
     #[error("Nostr client error: {0}")]
     Client(#[from] nostr_rs_client::Error),

+ 5 - 1
crates/relayer/src/lib.rs

@@ -15,4 +15,8 @@ mod error;
 mod relayer;
 mod subscription;
 
-pub use self::{connection::Connection, error::Error, relayer::Relayer};
+pub use self::{
+    connection::{Connection, LocalConnection},
+    error::Error,
+    relayer::Relayer,
+};

+ 114 - 20
crates/relayer/src/relayer.rs

@@ -1,4 +1,8 @@
-use crate::{connection::ConnectionId, subscription::SubscriptionManager, Connection, Error};
+use crate::{
+    connection::{ConnectionId, LocalConnection},
+    subscription::SubscriptionManager,
+    Connection, Error,
+};
 use futures_util::StreamExt;
 use nostr_rs_client::{Error as ClientError, Pool};
 use nostr_rs_storage_base::Storage;
@@ -21,9 +25,10 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
     /// relayer just a dumb proxy (that can be useful for privacy) but it won't
     /// be able to perform any optimization like prefetching content while offline
     storage: Option<T>,
-    /// x
+    /// Subscription manager
     subscriptions: Arc<SubscriptionManager>,
-    clients: RwLock<HashMap<ConnectionId, Connection>>,
+    /// List of all active connections
+    connections: RwLock<HashMap<ConnectionId, Connection>>,
     /// This Sender can be used to send requests from anywhere to the relayer.
     send_to_relayer: Sender<(ConnectionId, Request)>,
     /// This Receiver is the relayer the way the relayer receives messages
@@ -37,7 +42,7 @@ pub struct Relayer<T: Storage + Send + Sync + 'static> {
 impl<T: Storage + Send + Sync + 'static> Drop for Relayer<T> {
     fn drop(&mut self) {
         if let Some((_, handle)) = self.client_pool.take() {
-            let _ = handle.abort();
+            handle.abort();
         }
     }
 }
@@ -58,7 +63,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
             subscriptions: Default::default(),
             send_to_relayer: sender.clone(),
             relayer_receiver: Some(receiver),
-            clients: Default::default(),
+            connections: Default::default(),
             client_pool: if let Some(client_pool) = client_pool {
                 Some(Self::handle_client_pool(client_pool, sender)?)
             } else {
@@ -99,12 +104,12 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                                 if let Some(storage) = this.storage.as_ref() {
                                     let _ = storage.store_local_event(&event).await;
                                 }
-                                this.broadcast(&event.deref()).await;
+                                this.broadcast(event.deref()).await;
                             }
                             continue;
                         }
 
-                        let connections = this.clients.read().await;
+                        let connections = this.connections.read().await;
                         let connection = if let Some(connection) = connections.get(&conn_id) {
                             connection
                         } else {
@@ -160,6 +165,15 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
         &self.storage
     }
 
+    /// Adds a new local connection to the list of active connections.
+    pub async fn create_new_local_connection(&self) -> LocalConnection {
+        let (conn, receiver) = Connection::new_local_connection();
+        let conn_id = conn.get_conn_id();
+        self.connections.write().await.insert(conn_id, conn);
+
+        (conn_id, receiver, self.send_to_relayer.clone()).into()
+    }
+
     /// Adds a new TpStream and adds it to the list of active connections.
     ///
     /// This function will spawn the client's loop to receive incoming messages and send those messages
@@ -168,11 +182,11 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
         disconnection_notify: Option<mpsc::Sender<ConnectionId>>,
         stream: TcpStream,
     ) -> Result<ConnectionId, Error> {
-        let client =
+        let conn =
             Connection::new_connection(self.send_to_relayer.clone(), disconnection_notify, stream)
                 .await?;
-        let id = client.get_conn_id();
-        self.clients.write().await.insert(id, client);
+        let id = conn.get_conn_id();
+        self.connections.write().await.insert(id, conn);
 
         Ok(id)
     }
@@ -195,7 +209,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                 if let Some((client_pool, _)) = self.client_pool.as_ref() {
                     // pass the event to the pool of clients, so this relayer can relay
                     // their local events to the clients in the network of relayers
-                    let _ = client_pool.post(event.clone().into()).await;
+                    let _ = client_pool.post(event.clone()).await;
                 }
             }
             Request::Request(request) => {
@@ -249,7 +263,7 @@ impl<T: Storage + Send + Sync + 'static> Relayer<T> {
                     .await;
             }
             Request::Close(close) => {
-                connection.unsubscribe(&*close).await;
+                connection.unsubscribe(close).await;
             }
         };
 
@@ -320,7 +334,7 @@ mod test {
         if prefill {
             let events = include_str!("../tests/events.json")
                 .lines()
-                .map(|line| serde_json::from_str(&line).expect("valid"))
+                .map(|line| serde_json::from_str(line).expect("valid"))
                 .collect::<Vec<Event>>();
 
             for event in events {
@@ -389,7 +403,7 @@ mod test {
         ]))
         .expect("valid object");
         let relayer = Relayer::new(Some(get_db(true).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_for_test();
+        let (connection, mut recv) = Connection::new_local_connection();
         let _ = relayer
             .process_request_from_client(&connection, request)
             .await;
@@ -461,6 +475,86 @@ mod test {
     }
 
     #[tokio::test]
+    async fn server_listener_real_time_single_argument() {
+        let request: Request = serde_json::from_value(json!(
+            [
+                "REQ",
+                "1298169700973717",
+                {
+                    "authors": ["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                    "since":1681939304
+                },
+                {
+                    "#p":["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                    "since":1681939304
+                },
+                {
+                    "#p":["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                },
+                {
+                    "authors":["39a7d06e824c0c2523bedb93f0cef84245e4401fee03b6257a1c6dfd18b57efb"],
+                },
+                {
+                    "#e":[
+                        "2e72250d80e9b3fd30230b3db3ed7d22f15d266ed345c36700b01ec153c9e28a",
+                        "a5e3369c43daf2675ecbce18831e5f4e07db0d4dde0ef4f5698e645e4c46eed1"
+                    ],
+                }
+        ]))
+        .expect("valid object");
+        let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
+        let (connection, mut recv) = Connection::new_local_connection();
+
+        assert_eq!(relayer.total_subscribers(), 0);
+        let _ = relayer
+            .process_request_from_client(&connection, request)
+            .await;
+
+        assert_eq!(relayer.total_subscribers(), 5);
+
+        // eod
+        assert!(recv
+            .try_recv()
+            .expect("valid")
+            .as_end_of_stored_events()
+            .is_some());
+
+        // It is empty
+        assert!(recv.try_recv().is_err());
+
+        relayer
+            .process_request_from_client(&connection, get_note())
+            .await
+            .expect("process event");
+
+        sleep(Duration::from_millis(100)).await;
+
+        // It is not empty
+        let msg = recv.try_recv();
+        assert!(msg.is_ok());
+        assert_eq!(
+            msg.expect("is ok")
+                .as_event()
+                .expect("valid")
+                .subscription_id
+                .to_string(),
+            "1298169700973717".to_owned()
+        );
+
+        // it must be deliverd at most once
+        assert!(recv.try_recv().is_err());
+        assert_eq!(relayer.total_subscribers(), 5);
+
+        // when client is dropped, the subscription is removed
+        // automatically
+        drop(connection);
+
+        sleep(Duration::from_millis(10)).await;
+
+        assert_eq!(relayer.total_subscribers(), 0);
+    }
+
+    #[tokio::test]
     async fn server_listener_real_time() {
         let request: Request = serde_json::from_value(json!(
             [
@@ -493,7 +587,7 @@ mod test {
         ]))
         .expect("valid object");
         let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_for_test();
+        let (connection, mut recv) = Connection::new_local_connection();
 
         assert_eq!(relayer.total_subscribers(), 0);
         let _ = relayer
@@ -557,7 +651,7 @@ mod test {
         .expect("valid object");
 
         let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_for_test();
+        let (connection, mut recv) = Connection::new_local_connection();
 
         assert_eq!(relayer.total_subscribers(), 0);
         let _ = relayer
@@ -620,14 +714,14 @@ mod test {
         .expect("valid object");
 
         let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
-        let (publisher, _) = Connection::new_for_test();
+        let (publisher, _) = Connection::new_local_connection();
 
         let mut set1 = (0..1000)
-            .map(|_| Connection::new_for_test())
+            .map(|_| Connection::new_local_connection())
             .collect::<Vec<_>>();
 
         let mut set2 = (0..100)
-            .map(|_| Connection::new_for_test())
+            .map(|_| Connection::new_local_connection())
             .collect::<Vec<_>>();
 
         let subscribe1 = set1
@@ -712,7 +806,7 @@ mod test {
             serde_json::from_value(json!(["REQ", "1298169700973717", {}])).expect("valid object");
 
         let relayer = Relayer::new(Some(get_db(false).await), None).expect("valid relayer");
-        let (connection, mut recv) = Connection::new_for_test();
+        let (connection, mut recv) = Connection::new_local_connection();
 
         assert_eq!(relayer.total_subscribers(), 0);
         let _ = relayer

+ 10 - 12
crates/relayer/src/subscription/filter.rs

@@ -4,6 +4,7 @@ use std::collections::BTreeSet;
 /// The subscription keys are used to quickly identify the subscriptions
 /// by one or more fields. Think of it like a database index
 #[derive(Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash)]
+#[allow(clippy::enum_variant_names)]
 pub(crate) enum Key {
     /// Key for the author field
     Author(Vec<u8>, Option<Kind>),
@@ -81,27 +82,27 @@ impl SortedFilter {
         let authors = self
             .authors
             .as_ref()
-            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+            .map_or_else(std::vec::Vec::new, |x| x.iter().cloned().collect());
 
         let ids = self
             .ids
             .as_ref()
-            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+            .map_or_else(std::vec::Vec::new, |x| x.iter().cloned().collect());
 
         let references_to_event = self
             .references_to_event
             .as_ref()
-            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+            .map_or_else(std::vec::Vec::new, |x| x.iter().cloned().collect());
 
         let references_to_public_key = self
             .references_to_public_key
             .as_ref()
-            .map_or_else(|| vec![], |x| x.iter().map(|x| x.clone()).collect());
+            .map_or_else(std::vec::Vec::new, |x| x.iter().cloned().collect());
 
         let kinds = self
             .kinds
             .as_ref()
-            .map_or_else(|| vec![], |x| x.iter().map(|x| *x).collect());
+            .map_or_else(std::vec::Vec::new, |x| x.iter().copied().collect());
 
         let kind_option = if kinds.is_empty() {
             vec![None]
@@ -109,7 +110,7 @@ impl SortedFilter {
             kinds.clone().into_iter().map(Some).collect()
         };
 
-        let keys = vec![
+        let keys = [
             authors
                 .into_iter()
                 .map(|author| {
@@ -137,11 +138,8 @@ impl SortedFilter {
                         .collect::<Vec<_>>()
                 })
                 .collect::<Vec<_>>(),
-            vec![kinds
-                .into_iter()
-                .map(|kind| Key::Kind(kind))
-                .collect::<Vec<_>>()],
-            vec![ids.into_iter().map(|id| Key::Id(id)).collect::<Vec<_>>()],
+            vec![kinds.into_iter().map(Key::Kind).collect::<Vec<_>>()],
+            vec![ids.into_iter().map(Key::Id).collect::<Vec<_>>()],
         ]
         .concat()
         .concat();
@@ -198,7 +196,7 @@ impl SortedFilter {
             .as_ref()
             .map_or(true, |kinds| kinds.contains(&event.kind()))
             && Self::has_key_or_partial_match(&event.id, &self.ids)
-            && Self::has_key_or_partial_match(&event.author(), &self.authors)
+            && Self::has_key_or_partial_match(event.author(), &self.authors)
             && Self::has_any_key_or_partial_match(
                 event.tags().iter().filter_map(|f| match f {
                     Tag::Event(x) => Some(x.id.as_ref().to_vec()),

+ 19 - 25
crates/relayer/src/subscription/manager.rs

@@ -16,7 +16,7 @@ use tokio::sync::{mpsc::Sender, RwLock};
 
 type SubIdx = (Key, ConnectionId, SubscriptionId);
 
-pub const MIN_PREFIX_MATCH_LEN: usize = 4;
+pub const MIN_PREFIX_MATCH_LEN: usize = 2;
 
 /// Subscription for a connection
 ///
@@ -65,6 +65,8 @@ impl Drop for ActiveSubscription {
     }
 }
 
+type SubscriptionValue = (Arc<SortedFilter>, Sender<Response>);
+
 /// Subscription manager
 ///
 /// This object is responsible for letting clients and processes subscribe to
@@ -75,7 +77,7 @@ pub struct SubscriptionManager {
     ///
     /// A single request may be converted to multiple subscriptions entry as
     /// they are sorted by their index/key
-    subscriptions: RwLock<BTreeMap<SubIdx, (Arc<SortedFilter>, Sender<Response>)>>,
+    subscriptions: RwLock<BTreeMap<SubIdx, SubscriptionValue>>,
     /// Total number of subscribers
     /// A single REQ may have multiple subscriptions
     total_subscribers: AtomicUsize,
@@ -108,44 +110,36 @@ impl SubscriptionManager {
         let author = event.author().as_ref().to_vec();
         let id = event.id.as_ref().to_vec();
 
-        let len = author.len();
-        for i in min_prefix_match_len..len - min_prefix_match_len {
-            subscriptions.push(Key::Author(author[..len - i].to_vec(), None));
-            subscriptions.push(Key::Author(author[..len - i].to_vec(), Some(event.kind())));
+        for i in min_prefix_match_len..=author.len() {
+            subscriptions.push(Key::Author(author[..i].to_vec(), None));
+            subscriptions.push(Key::Author(author[..i].to_vec(), Some(event.kind())));
         }
 
         for t in event.tags() {
             match t {
                 nostr_rs_types::types::Tag::Event(ref_event) => {
-                    let len = ref_event.id.len();
-                    for i in min_prefix_match_len..ref_event.id.len() - min_prefix_match_len {
-                        subscriptions.push(Key::RefId(ref_event.id[..len - i].to_vec(), None));
-                        subscriptions.push(Key::RefId(
-                            ref_event.id[..len - i].to_vec(),
-                            Some(event.kind()),
-                        ));
+                    for i in min_prefix_match_len..=ref_event.id.len() {
+                        subscriptions.push(Key::RefId(ref_event.id[..i].to_vec(), None));
+                        subscriptions
+                            .push(Key::RefId(ref_event.id[..i].to_vec(), Some(event.kind())));
                     }
                 }
                 nostr_rs_types::types::Tag::PubKey(ref_pub_key) => {
-                    let len = ref_pub_key.id.len();
-                    for i in min_prefix_match_len..len - min_prefix_match_len {
-                        subscriptions.push(Key::RefId(ref_pub_key.id[..len - i].to_vec(), None));
-                        subscriptions.push(Key::RefId(
-                            ref_pub_key.id[..len - i].to_vec(),
-                            Some(event.kind()),
-                        ));
+                    for i in min_prefix_match_len..=ref_pub_key.id.len() {
+                        subscriptions.push(Key::RefId(ref_pub_key.id[..i].to_vec(), None));
+                        subscriptions
+                            .push(Key::RefId(ref_pub_key.id[..i].to_vec(), Some(event.kind())));
                     }
                 }
                 _ => {}
             }
         }
 
-        let len = id.len();
-        for i in min_prefix_match_len..len - min_prefix_match_len {
-            subscriptions.push(Key::Id(id[..len - i].to_vec()));
+        for i in min_prefix_match_len..=id.len() {
+            subscriptions.push(Key::Id(id[..i].to_vec()));
         }
 
-        subscriptions.push(Key::Kind(event.kind().into()));
+        subscriptions.push(Key::Kind(event.kind()));
         subscriptions.push(Key::AllUpdates);
 
         subscriptions
@@ -214,7 +208,7 @@ impl SubscriptionManager {
                     }
 
                     let _ = sender.try_send(Response::Event((name, &event).into()));
-                    deliverded.insert(client.clone());
+                    deliverded.insert(*client);
                 }
             }
         });

+ 2 - 3
crates/storage/memory/src/lib.rs

@@ -7,7 +7,6 @@ use std::{
         atomic::{AtomicUsize, Ordering},
         Arc,
     },
-    u64,
 };
 use tokio::sync::RwLock;
 
@@ -80,7 +79,7 @@ impl Storage for Memory {
                 .ids_by_time
                 .write()
                 .await
-                .insert(secondary_index.index_by(&time_desc), event_id.clone());
+                .insert(secondary_index.index_by(time_desc), event_id.clone());
 
             indexes
                 .author
@@ -133,7 +132,7 @@ impl Storage for Memory {
     async fn set_local_event(&self, event: &Event) -> Result<(), Error> {
         let mut local_events = self.indexes.local_events.write().await;
         local_events.insert(
-            SecondaryIndex::new(&event.id, event.created_at()).index_by(&[]),
+            SecondaryIndex::new(&event.id, event.created_at()).index_by([]),
             event.id.0.to_vec(),
         );
         Ok(())

+ 1 - 1
crates/types/src/types/signature.rs

@@ -29,7 +29,7 @@ pub struct Signature(pub [u8; 64]);
 
 impl From<secp256k1::schnorr::Signature> for Signature {
     fn from(signature: secp256k1::schnorr::Signature) -> Self {
-        Self(signature.as_ref().clone())
+        Self(*signature.as_ref())
     }
 }