Răsfoiți Sursa

Introduce async storage traits

Having sync traits is suboptimal. Replace the iterator for a stream.
Cesar Rodas 3 luni în urmă
părinte
comite
47739c54e6

+ 1 - 0
Cargo.lock

@@ -968,6 +968,7 @@ version = "0.1.0"
 dependencies = [
  "async-trait",
  "chrono",
+ "futures",
  "nostr-rs-storage-base",
  "nostr-rs-types",
  "rocksdb",

+ 12 - 11
crates/relayer/src/relayer.rs

@@ -1,4 +1,5 @@
 use crate::{Connection, Error, Subscription};
+use futures_util::StreamExt;
 use nostr_rs_storage_base::Storage;
 use nostr_rs_types::{
     relayer,
@@ -122,17 +123,17 @@ impl<T: Storage> Relayer<T> {
                 if let Some(storage) = self.storage.as_ref() {
                     // Sent all events that match the filter that are stored in our database
                     for filter in request.filters.clone().into_iter() {
-                        storage.get_by_filter(filter).await?.for_each(|event| {
-                            if let Ok(event) = event {
-                                let _ = connection.send(
-                                    relayer::Event {
-                                        subscription_id: request.subscription_id.clone(),
-                                        event,
-                                    }
-                                    .into(),
-                                );
-                            }
-                        });
+                        let mut result = storage.get_by_filter(filter).await?;
+
+                        while let Some(Ok(event)) = result.next().await {
+                            let _ = connection.send(
+                                relayer::Event {
+                                    subscription_id: request.subscription_id.clone(),
+                                    event,
+                                }
+                                .into(),
+                            );
+                        }
                     }
                 }
 

+ 2 - 6
crates/storage/base/src/error.rs

@@ -6,7 +6,7 @@ use std::num::TryFromIntError;
 pub enum Error {
     /// Internal database error
     #[error("Unknown: {0}")]
-    Unknown(String),
+    Internal(String),
 
     /// Serialization error
     #[error("Serde: {0}")]
@@ -14,13 +14,9 @@ pub enum Error {
 
     /// Internal error while converting types to integer
     #[error("Internal error: {0}")]
-    Internal(#[from] TryFromIntError),
+    IntErr(#[from] TryFromIntError),
 
     /// Transaction error
     #[error("Tx: {0}")]
     Tx(String),
-
-    /// Internal error
-    #[error("Unknown family column")]
-    InvalidColumnFamily,
 }

+ 1 - 1
crates/storage/base/src/storage.rs

@@ -6,7 +6,7 @@ use nostr_rs_types::types::{Event, Filter};
 #[async_trait::async_trait]
 pub trait Storage: Send + Sync {
     /// Result iterators
-    type Stream<'a>: Stream<Item = Result<Event, Error>>
+    type Stream<'a>: Stream<Item = Result<Event, Error>> + Unpin
     where
         Self: 'a;
 

+ 41 - 26
crates/storage/base/src/test.rs

@@ -3,6 +3,7 @@
 //! This crate will storage events into a database. It will also build index to
 //! find events by their tags, kind and references.
 use super::*;
+use futures::{StreamExt, TryStreamExt};
 use nostr_rs_types::types::{Addr, Event, Filter, Kind};
 use std::{
     fs::File,
@@ -47,7 +48,7 @@ where
         .try_into()
         .expect("pk");
 
-    let vec = db
+    let vec: Vec<Event> = db
         .get_by_filter(Filter {
             authors: vec![pk],
             limit: 10,
@@ -55,7 +56,8 @@ where
         })
         .await
         .expect("set of results")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
 
     let dates = vec.iter().map(|e| e.created_at()).collect::<Vec<_>>();
@@ -72,7 +74,7 @@ where
 {
     setup_db(db).await;
 
-    let related_events = db
+    let related_events: Vec<Event> = db
         .get_by_filter(Filter {
             references_to_event: vec![
                 "f513f1422ee5dbf30f57118b6cc34e788746e589a9b07be767664a164c57b9b1"
@@ -88,7 +90,8 @@ where
         })
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(related_events.len(), 1);
 }
@@ -99,7 +102,7 @@ where
 {
     setup_db(db);
 
-    let related_events = db
+    let related_events: Vec<Event> = db
         .get_by_filter(Filter {
             references_to_event: vec![
                 "42224859763652914db53052103f0b744df79dfc4efef7e950fc0802fc3df3c5"
@@ -115,7 +118,8 @@ where
         })
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(related_events.len(), 0);
 }
@@ -126,7 +130,7 @@ where
 {
     setup_db(db).await;
 
-    let related_events = db
+    let related_events: Vec<Event> = db
         .get_by_filter(Filter {
             kinds: vec![Kind::Reaction, Kind::ShortTextNote],
             references_to_event: vec![
@@ -138,7 +142,8 @@ where
         })
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(related_events.len(), 3);
 }
@@ -153,26 +158,28 @@ where
         .try_into()
         .expect("pk");
 
-    let events = db
+    let events: Vec<Event> = db
         .get_by_filter(Filter {
             ids: vec![id.clone()],
             ..Default::default()
         })
         .await
         .expect("events")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
 
     assert_eq!(events.len(), 1);
 
-    let related_events = db
+    let related_events: Vec<Event> = db
         .get_by_filter(Filter {
             references_to_event: vec![id],
             ..Default::default()
         })
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(related_events.len(), 2_538);
 
@@ -200,11 +207,12 @@ where
         ],
         ..Default::default()
     };
-    let records = db
+    let records: Vec<Event> = db
         .get_by_filter(query)
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(records.len(), 27);
 }
@@ -222,11 +230,12 @@ where
         ],
         ..Default::default()
     };
-    let records = db
+    let records: Vec<Event> = db
         .get_by_filter(query)
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(records.len(), 3);
 }
@@ -245,11 +254,12 @@ where
         kinds: vec![Kind::ShortTextNote, Kind::Reaction],
         ..Default::default()
     };
-    let records = db
+    let records: Vec<Event> = db
         .get_by_filter(query)
         .await
         .expect("iterator")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
     assert_eq!(records.len(), 2);
 }
@@ -263,12 +273,14 @@ where
         kinds: vec![Kind::ShortTextNote],
         ..Default::default()
     };
-    let records = db
+    let records: Vec<Event> = db
         .get_by_filter(query)
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
+
     assert_eq!(records.len(), 1_511);
     records
         .iter()
@@ -282,27 +294,30 @@ where
 {
     setup_db(db).await;
 
-    let ids = db
+    let events_from_filter: Vec<Event> = db
         .get_by_filter(Filter {
             limit: 10,
             ..Default::default()
         })
         .await
         .expect("valid")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
 
-    for event in ids.iter() {
+    for event in events_from_filter.iter() {
         db.set_local_event(event).await.expect("valid");
     }
 
-    assert_eq!(10, ids.len());
+    assert_eq!(10, events_from_filter.len());
 
-    let records = db
+    let records: Vec<Event> = db
         .get_local_events(None)
         .await
         .expect("valid iterator")
-        .collect::<Result<Vec<_>, _>>()
+        .try_collect()
+        .await
         .expect("valid");
+
     assert_eq!(10, records.len())
 }

+ 1 - 0
crates/storage/rocksdb/Cargo.toml

@@ -16,6 +16,7 @@ rocksdb = { version = "0.20.1", features = [
 chrono = "0.4.26"
 serde_json = "1.0"
 async-trait = "0.1.81"
+futures = "0.3.30"
 
 [dev-dependencies]
 nostr-rs-storage-base = { path = "../base", features = ["test"] }

+ 76 - 44
crates/storage/rocksdb/src/iterator.rs

@@ -1,9 +1,24 @@
 //! Rocks DB implementation of the storage layer
 use crate::{event_filter::EventFilter, RocksDb};
+use futures::{Future, FutureExt, Stream};
 use nostr_rs_storage_base::{Error, Storage};
 use nostr_rs_types::types::Event;
 use rocksdb::{BoundColumnFamily, DBIteratorWithThreadMode, DB};
-use std::{collections::VecDeque, sync::Arc};
+use std::{
+    collections::VecDeque,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+type CurrentEventByPrefixFuture<'a> = Pin<
+    Box<
+        dyn Future<
+                Output = Result<Option<nostr_rs_types::types::Event>, nostr_rs_storage_base::Error>,
+            > + Send
+            + 'a,
+    >,
+>;
 
 pub struct WrapperIterator<'a> {
     /// Reference to the rocks db database. This is useful to load the event
@@ -29,6 +44,8 @@ pub struct WrapperIterator<'a> {
     pub prefixes: VecDeque<Vec<u8>>,
     pub limit: Option<usize>,
     pub returned: usize,
+
+    pub current_event_by_prefix: Option<CurrentEventByPrefixFuture<'a>>,
 }
 
 impl<'a> WrapperIterator<'a> {
@@ -47,64 +64,79 @@ impl<'a> WrapperIterator<'a> {
     }
 }
 
-impl<'a> Iterator for WrapperIterator<'a> {
+impl<'a> Stream for WrapperIterator<'a> {
     type Item = Result<Event, Error>;
 
-    fn next(&mut self) -> Option<Self::Item> {
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (0, None)
+    }
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         if Some(self.returned) == self.limit {
-            return None;
+            return Poll::Ready(None);
         }
-        if self.secondary_index_iterator.is_none() {
-            if self.namespace.is_some() {
-                self.select_next_prefix_using_secondary_index()?;
-            } else {
-                // No secondary index is used to query, this means the query is
-                // using the ID filter, so it is more efficient to use the
-                // primary index to prefetch events that may satisfy the query
-                loop {
-                    let prefix = self.prefixes.pop_front()?;
-                    if let Ok(Some(event)) = self.db.get_event(prefix).await {
-                        if let Some(filter) = &self.filter {
-                            if filter.check_event(&event) {
-                                self.returned += 1;
-                                return Some(Ok(event));
-                            }
-                        } else {
-                            self.returned += 1;
-                            return Some(Ok(event));
+
+        let this = Pin::into_inner(self);
+        let db = this.db;
+
+        if let Some(mut current_event_filter) = this.current_event_by_prefix.take() {
+            match current_event_filter.poll_unpin(cx) {
+                Poll::Ready(Ok(Some(event))) => {
+                    // event is ready, apply the neccesary filters
+                    if let Some(filter) = &this.filter {
+                        if filter.check_event(&event) {
+                            this.returned += 1;
+                            return Poll::Ready(Some(Ok(event)));
                         }
+                    } else {
+                        this.returned += 1;
+                        return Poll::Ready(Some(Ok(event)));
                     }
                 }
+                Poll::Ready(Err(x)) => return Poll::Ready(Some(Err(x))),
+                Poll::Pending => {
+                    // add it back
+                    this.current_event_by_prefix = Some(current_event_filter);
+                    return Poll::Pending;
+                }
+                _ => {}
             }
         }
+        let secondary_index = if let Some(iterator) = this.secondary_index_iterator.as_mut() {
+            iterator
+        } else {
+            return Poll::Ready(None);
+        };
 
-        loop {
-            loop {
-                let secondary_index = self.secondary_index_iterator.as_mut()?;
-                let (key, value) = match secondary_index.next() {
-                    Some(Ok((k, v))) => (k, v),
-                    _ => {
-                        // break this loop to select next available prefix
-                        break;
+        match secondary_index.next() {
+            Some(Ok((key, value))) => {
+                if !key.starts_with(&this.current_prefix) {
+                    if this.select_next_prefix_using_secondary_index().is_none() {
+                        return Poll::Ready(None);
                     }
-                };
-                if !key.starts_with(&self.current_prefix) {
-                    break;
+                } else {
+                    // query the database to get the record
+                    this.current_event_by_prefix = Some(db.get_event(value));
                 }
-                if let Ok(Some(event)) = self.db.get_event(value) {
-                    if let Some(filter) = &self.filter {
-                        if filter.check_event(&event) {
-                            self.returned += 1;
-                            return Some(Ok(event));
-                        }
-                    } else {
-                        self.returned += 1;
-                        return Some(Ok(event));
+
+                Poll::Pending
+            }
+            Some(Err(err)) => Poll::Ready(Some(Err(Error::Internal(err.to_string())))),
+            None => {
+                if this.namespace.is_some() {
+                    if this.select_next_prefix_using_secondary_index().is_none() {
+                        return Poll::Ready(None);
                     }
+                } else {
+                    // No secondary index is used to query, this means the query is
+                    // using the ID filter, so it is more efficient to use the
+                    // primary index to prefetch events that may satisfy the query
+                    let current_event_by_prefix =
+                        this.prefixes.pop_front().map(|prefix| db.get_event(prefix));
+                    this.current_event_by_prefix = current_event_by_prefix;
                 }
+                Poll::Pending
             }
-            // Select next prefix if available, or exists
-            self.select_next_prefix_using_secondary_index()?;
         }
     }
 }

+ 8 - 5
crates/storage/rocksdb/src/lib.rs

@@ -59,7 +59,7 @@ impl RocksDb {
                 ColumnFamilyDescriptor::new(ReferenceType::Stream.as_str(), options.clone()),
             ],
         )
-        .map_err(|e| Error::Unknown(e.to_string()))?;
+        .map_err(|e| Error::Internal(e.to_string()))?;
         Ok(Self { db })
     }
 
@@ -85,13 +85,14 @@ impl RocksDb {
     ) -> Result<Arc<BoundColumnFamily>, Error> {
         self.db
             .cf_handle(namespace.as_str())
-            .ok_or(Error::InvalidColumnFamily)
+            .ok_or(Error::Internal("Unknown db-family".to_owned()))
     }
 }
 
 #[async_trait::async_trait]
 impl Storage for RocksDb {
     type Stream<'a> = WrapperIterator<'a>;
+
     async fn get_local_events(&self, limit: Option<usize>) -> Result<WrapperIterator<'_>, Error> {
         let cf_handle = self.reference_to_cf_handle(ReferenceType::LocalEvents)?;
         Ok(WrapperIterator {
@@ -103,6 +104,7 @@ impl Storage for RocksDb {
             prefixes: VecDeque::new(),
             limit,
             returned: 0,
+            current_event_by_prefix: None,
         })
     }
 
@@ -115,7 +117,7 @@ impl Storage for RocksDb {
                 secondary_index.index_by([]),
                 event_id.deref(),
             )
-            .map_err(|e| Error::Unknown(e.to_string()))?;
+            .map_err(|e| Error::Internal(e.to_string()))?;
         Ok(())
     }
 
@@ -196,7 +198,7 @@ impl Storage for RocksDb {
 
         self.db
             .write(buffer)
-            .map_err(|e| Error::Unknown(e.to_string()))?;
+            .map_err(|e| Error::Internal(e.to_string()))?;
 
         Ok(true)
     }
@@ -205,7 +207,7 @@ impl Storage for RocksDb {
         Ok(self
             .db
             .get_cf(&self.reference_to_cf_handle(ReferenceType::Events)?, id)
-            .map_err(|e| Error::Unknown(e.to_string()))?
+            .map_err(|e| Error::Internal(e.to_string()))?
             .map(|event| serde_json::from_slice(&event))
             .transpose()?)
     }
@@ -276,6 +278,7 @@ impl Storage for RocksDb {
             prefixes,
             returned: 0,
             limit,
+            current_event_by_prefix: None,
         })
 
         //load_events_and_filter(self, query, event_ids, for_each)