Просмотр исходного кода

Buffer events and their writes

Cesar Rodas 1 год назад
Родитель
Сommit
6cf3a2d206
4 измененных файлов с 114 добавлено и 73 удалено
  1. 1 0
      Cargo.lock
  2. 1 0
      Cargo.toml
  3. 31 33
      crates/client/src/lib.rs
  4. 81 40
      src/main.rs

+ 1 - 0
Cargo.lock

@@ -756,6 +756,7 @@ dependencies = [
  "futures-util",
  "nostr-rs-client",
  "nostr-rs-types",
+ "parking_lot 0.12.1",
  "serde_json",
  "sqlx",
  "thiserror",

+ 1 - 0
Cargo.toml

@@ -18,3 +18,4 @@ serde_json = "1.0.94"
 sqlx = { version = "0.6.2", features = ["sqlite", "runtime-tokio-native-tls"] }
 futures-util = "0.3.27"
 thiserror = "1.0.40"
+parking_lot = "0.12.1"

+ 31 - 33
crates/client/src/lib.rs

@@ -31,31 +31,31 @@ pub enum Error {
 pub struct Client {
     pub url: String,
 
-    pub sender: mpsc::Sender<Request>,
+    pub send_to_socket: mpsc::Sender<Request>,
 
-    receiver: broadcast::Receiver<(String, Response)>,
-    stopper: oneshot::Sender<()>,
+    recv_from_socket: broadcast::Receiver<(Response, String)>,
+    stop_service: oneshot::Sender<()>,
 }
 
 impl Client {
     pub fn new(url: &str) -> Result<Self, Error> {
-        let (sender, receiver) = mpsc::channel(10_000);
-        let (receiver, stopper) = Self::spawn(receiver, url)?;
+        let (send_to_socket, receiver) = mpsc::channel(10_000);
+        let (recv_from_socket, stop_service) = Self::spawn(receiver, url)?;
 
         Ok(Self {
             url: url.to_owned(),
-            sender,
-            stopper,
-            receiver,
+            send_to_socket,
+            stop_service,
+            recv_from_socket,
         })
     }
 
     fn spawn(
         mut receiver: mpsc::Receiver<Request>,
         url_str: &str,
-    ) -> Result<(broadcast::Receiver<(String, Response)>, oneshot::Sender<()>), Error> {
-        let (response_sender, response_receiver) = broadcast::channel(10_000);
-        let (stopper_sender, mut stopper_recv) = oneshot::channel();
+    ) -> Result<(broadcast::Receiver<(Response, String)>, oneshot::Sender<()>), Error> {
+        let (publish_to_listener, recv_from_socket) = broadcast::channel(10_000);
+        let (stop_service, mut stopper_recv) = oneshot::channel();
 
         let url = url_str.to_owned();
         let url_parsed = Url::parse(&url)?;
@@ -70,7 +70,7 @@ impl Client {
                     x.0
                 } else {
                     println!("{}: Failed to connect", url);
-                    sleep(Duration::from_secs(1)).await;
+                    sleep(Duration::from_secs(5)).await;
                     continue;
                 };
 
@@ -112,7 +112,7 @@ impl Client {
                             let msg: Result<Response, _> = serde_json::from_str(&msg);
 
                             if let Ok(msg) = msg {
-                                if let Err(error) = response_sender.send((url.to_owned(), msg)) {
+                                if let Err(error) = publish_to_listener.send((msg, url.to_owned())) {
                                     println!("{}: Reconnecting client because of {}", url, error);
                                     break;
                                 }
@@ -128,48 +128,46 @@ impl Client {
             println!("{}: Disconnected", url);
         });
 
-        Ok((response_receiver, stopper_sender))
+        Ok((recv_from_socket, stop_service))
     }
 
     pub fn is_running(&self) -> bool {
-        !self.stopper.is_closed()
+        !self.stop_service.is_closed()
     }
 
-    pub fn subscribe(&self) -> broadcast::Receiver<(String, Response)> {
-        self.receiver.resubscribe()
+    pub fn subscribe(&self) -> broadcast::Receiver<(Response, String)> {
+        self.recv_from_socket.resubscribe()
     }
 
     pub async fn send(&self, request: Request) -> Result<(), Error> {
-        Ok(self.sender.send(request).await?)
+        Ok(self.send_to_socket.send(request).await?)
     }
 
     pub async fn stop(self) {
-        let _ = self.stopper.send(());
+        let _ = self.stop_service.send(());
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct Clients {
     clients: Arc<RwLock<HashMap<String, Client>>>,
-    subscriptions: Arc<RwLock<Vec<broadcast::Receiver<(String, Response)>>>>,
 }
 
 impl Default for Clients {
     fn default() -> Self {
         Self {
             clients: Arc::new(RwLock::new(HashMap::new())),
-            subscriptions: Arc::new(RwLock::new(vec![])),
         }
     }
 }
 
 impl Clients {
-    pub async fn recv(&self) -> Option<(String, Response)> {
+    pub async fn recv(&self) -> Option<(Response, String)> {
         let mut subscriptions = self
-            .subscriptions
+            .clients
             .read()
             .iter()
-            .map(|s| s.resubscribe())
+            .map(|(_, c)| c.subscribe())
             .collect::<Vec<_>>();
 
         let mut futures = FuturesUnordered::new();
@@ -184,8 +182,13 @@ impl Clients {
             None
         }
     }
-    pub fn try_recv(&self) -> Option<(String, Response)> {
-        let mut subscriptions = self.subscriptions.write();
+    pub fn try_recv(&self) -> Option<(Response, String)> {
+        let mut subscriptions = self
+            .clients
+            .read()
+            .iter()
+            .map(|(_, c)| c.subscribe())
+            .collect::<Vec<_>>();
         for sub in subscriptions.iter_mut() {
             if let Ok(msg) = sub.try_recv() {
                 return Some(msg);
@@ -196,7 +199,6 @@ impl Clients {
 
     pub fn check_active_connections(&self) -> usize {
         let mut clients = self.clients.write();
-        let mut subscriptions = self.subscriptions.write();
         let mut to_remove = vec![];
         for (url, client) in clients.iter() {
             if !client.is_running() {
@@ -208,7 +210,6 @@ impl Clients {
             clients.remove(url);
         }
 
-        subscriptions.retain(|s| s.len() > 0);
         clients.len()
     }
 
@@ -217,7 +218,7 @@ impl Clients {
             .clients
             .read()
             .iter()
-            .map(|(_, c)| c.sender.clone())
+            .map(|(_, c)| c.send_to_socket.clone())
             .collect::<Vec<mpsc::Sender<_>>>();
 
         for sender in senders.iter() {
@@ -231,10 +232,7 @@ impl Clients {
             false
         } else {
             println!("Connecting to {}", url);
-            let client = Client::new(url)?;
-            let mut subscriptions = self.subscriptions.write();
-            subscriptions.push(client.subscribe());
-            clients.insert(url.to_owned(), client);
+            clients.insert(url.to_owned(), Client::new(url)?);
             true
         })
     }

+ 81 - 40
src/main.rs

@@ -1,12 +1,12 @@
-use std::collections::HashMap;
-
 use nostr_rs_client::{Clients, Error as ClientError};
 use nostr_rs_types::{
     client::{Close, Subscribe},
     types::{Addr, Content, Event, Filter, Kind, SubscriptionId, Tag},
     Response,
 };
+use parking_lot::RwLock;
 use sqlx::{query, FromRow, Pool, Sqlite, SqlitePool};
+use std::{collections::HashMap, sync::Arc};
 use tokio::time::{sleep, Duration};
 
 #[derive(Clone, FromRow, Debug)]
@@ -220,7 +220,7 @@ async fn fetch_related_content(clients: Clients, conn: Pool<Sqlite>) -> Result<(
     tx.commit().await?;
 
     loop {
-        let data_to_fetch = sqlx::query_as::<_, ToFetch>(
+        let data_to_fetch = if let Ok(q) = sqlx::query_as::<_, ToFetch>(
             r#"
         SELECT
             id,
@@ -232,7 +232,14 @@ async fn fetch_related_content(clients: Clients, conn: Pool<Sqlite>) -> Result<(
         "#,
         )
         .fetch_all(&conn)
-        .await?;
+        .await
+        {
+            q
+        } else {
+            println!("Database locked, retrying...");
+            sleep(Duration::from_secs(1)).await;
+            continue;
+        };
 
         let mut public_keys = vec![];
         let mut ids = vec![];
@@ -308,9 +315,9 @@ async fn fetch_related_content(clients: Clients, conn: Pool<Sqlite>) -> Result<(
 
         let _ = clients.send(Close(subscription_id).into()).await;
 
-        sqlx::query("DELETE FROM to_fetch WHERE id IN (SELECT id FROM to_fetch ORDER BY refs DESC LIMIT 40)")
+        let _ = sqlx::query("DELETE FROM to_fetch WHERE id IN (SELECT id FROM to_fetch ORDER BY refs DESC LIMIT 40)")
             .execute(&conn)
-            .await?;
+            .await;
     }
 
     Ok(())
@@ -361,16 +368,18 @@ async fn main() {
     let clients_for_worker = clients.clone();
     let conn_for_worker = conn.clone();
     tokio::spawn(async move {
+        sleep(Duration::from_millis(35_000)).await;
         loop {
             let r1 =
                 fetch_related_content(clients_for_worker.clone(), conn_for_worker.clone()).await;
             println!("Fetch related content {:?}", r1);
-            sleep(Duration::from_millis(5_000)).await;
+            sleep(Duration::from_millis(5_000)).await
         }
     });
 
     let conn_for_worker = conn.clone();
     tokio::spawn(async move {
+        sleep(Duration::from_millis(55_000)).await;
         loop {
             let r1 = process_events(conn_for_worker.clone()).await;
             let r2 = discover_relayers(conn_for_worker.clone()).await;
@@ -390,42 +399,28 @@ async fn main() {
         }
     });
 
-    let relayers = sqlx::query_as::<_, Relayer>(
-        r#"select id, url from relayers where url like 'wss://%.%/' order by weight desc limit 30"#,
-    )
-    .fetch_all(&conn)
-    .await
-    .expect("query")
-    .iter()
-    .map(|r| (r.url.clone(), r.id))
-    .collect::<HashMap<String, i64>>();
+    let queue = Arc::new(RwLock::new(HashMap::<String, (Event, i64)>::new()));
+    let conn_for_queue = conn.clone();
+    let queue_for_worker = queue.clone();
 
-    for (relayer, _) in relayers.iter() {
-        let _ = clients.connect_to(&relayer).await;
-    }
+    tokio::spawn(async move {
+        loop {
+            let to_persist = { queue_for_worker.read().clone() };
 
-    loop {
-        if let Some((relayed_by, msg)) = clients.recv().await {
-            match msg {
-                Response::EndOfStoredEvents(x) => {
-                    let subscription_id = &*x;
-                    if &subscription_id[0..5] == "temp:" {
-                        // Remove listener, to avoid having too many requests at the same time
-                        let _ = clients.send(Close((*x).clone()).into()).await;
-                        println!("Remove listener: {0}", (*subscription_id).to_string());
-                    }
-                }
-                Response::Event(x) => {
-                    let event = x.event;
-                    let kind: u32 = event.inner.kind.try_into().expect("kind");
+            if let Ok(mut tx) = conn_for_queue.begin().await {
+                println!("Persisting {}", to_persist.len());
+                let mut skip: u64 = 0;
+                let mut processed: u64 = 0;
+                for (_, (event, relayer_id)) in to_persist.iter() {
                     if let Ok(Some(_)) = query(r#"SELECT id FROM events WHERE id = ?"#)
                         .bind(event.id.to_string())
-                        .fetch_optional(&conn)
+                        .fetch_optional(&mut tx)
                         .await
                     {
-                        //println!("Skip storing: {}", event.id.to_string(),);
+                        skip += 1;
                         continue;
                     }
+                    let kind: u32 = event.inner.kind.into();
                     let created_at = event.inner.created_at.to_rfc3339();
                     let _ = query(
                         r#"
@@ -437,15 +432,61 @@ async fn main() {
                     .bind(event.inner.public_key.to_string())
                     .bind(kind.to_string())
                     .bind(serde_json::to_string(&event).unwrap())
-                    .bind(relayers.get(&relayed_by).unwrap_or(&0))
+                    .bind(relayer_id)
                     .bind(created_at)
-                    .execute(&conn)
+                    .execute(&mut tx)
                     .await;
+                    processed += 1;
+                }
+                if let Ok(_) = tx.commit().await {
+                    let mut queue = queue_for_worker.write();
+                    for (id, _) in to_persist.iter() {
+                        queue.remove(id);
+                    }
+                    println!(
+                        "Persisted {}, skip = {} and (queue new size is {})",
+                        processed,
+                        skip,
+                        queue.len()
+                    );
+                }
+            }
+
+            sleep(Duration::from_millis(60_000)).await;
+        }
+    });
 
-                    //println!("Stored: {} (from {})", event.id.to_string(), relayed_by);
+    let relayers = sqlx::query_as::<_, Relayer>(
+        r#"select id, url from relayers where url like 'wss://%.%/' order by weight desc limit 30"#,
+    )
+    .fetch_all(&conn)
+    .await
+    .expect("query")
+    .iter()
+    .map(|r| (r.url.clone(), r.id))
+    .collect::<HashMap<String, i64>>();
+
+    for (relayer, _) in relayers.iter() {
+        let _ = clients.connect_to(&relayer).await;
+    }
+
+    loop {
+        if let Some((msg, relayed_by)) = clients.recv().await {
+            if let Response::Event(x) = msg {
+                let event = x.event;
+                let id = event.id.to_string();
+                let mut q = queue.write();
+
+                if !q.contains_key(&id) {
+                    q.insert(
+                        id,
+                        (
+                            event,
+                            relayers.get(&relayed_by).cloned().unwrap_or_default(),
+                        ),
+                    );
                 }
-                _ => {}
-            };
+            }
         }
     }
 }