
Make client resilient to errors and reconnections

Cesar Rodas 2 years ago
parent
commit
650c0ce68f
2 changed files with 177 additions and 73 deletions
  1. crates/client/src/lib.rs (+87, -31)
  2. src/main.rs (+90, -42)

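The change wraps the whole websocket session in an outer reconnect loop: every failed connect or broken stream sends the task back to the top, up to a retry budget, with a one-second pause between attempts, while an external stop signal breaks out of both loops. A minimal standalone sketch of that shape, assuming only tokio (the `connect_and_run` step is a placeholder for connecting, subscribing, and pumping messages, not an API from this crate):

```rust
use tokio::time::{sleep, Duration};

// Placeholder for connect + subscribe + message pump; Err means "try again".
async fn connect_and_run(url: &str) -> Result<(), String> {
    Err(format!("{url}: connection dropped"))
}

#[tokio::main]
async fn main() {
    let url = "wss://relay.example.invalid/";
    let mut retries: usize = 0;

    while retries <= 10 {
        retries += 1;
        println!("{url}: connect attempt {retries}");
        match connect_and_run(url).await {
            // A clean shutdown (e.g. a stop signal fired) ends the loop for good.
            Ok(()) => break,
            Err(err) => {
                println!("{url}: attempt failed: {err}");
                sleep(Duration::from_secs(1)).await;
            }
        }
    }
    println!("{url}: disconnected");
}
```
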
+ 87 - 31
crates/client/src/lib.rs

@@ -1,11 +1,14 @@
 use futures_util::{stream::FuturesUnordered, SinkExt, StreamExt};
-use nostr_rs_types::{Request, Response};
+use nostr_rs_types::{client::Subscribe, Request, Response};
 use parking_lot::RwLock;
 use std::{collections::HashMap, sync::Arc};
-use tokio::sync::{
-    broadcast,
-    mpsc::{self, error::SendError},
-    oneshot,
+use tokio::{
+    sync::{
+        broadcast,
+        mpsc::{self, error::SendError},
+        oneshot,
+    },
+    time::{sleep, Duration},
 };
 use tokio_tungstenite::{
     connect_async, tungstenite::error::Error as TungsteniteError, tungstenite::Message,
@@ -51,52 +54,87 @@ impl Client {
         mut receiver: mpsc::Receiver<Request>,
         url_str: &str,
     ) -> Result<(broadcast::Receiver<(String, Response)>, oneshot::Sender<()>), Error> {
-        let url = Url::parse(url_str)?;
         let (response_sender, response_receiver) = broadcast::channel(10_000);
         let (stopper_sender, mut stopper_recv) = oneshot::channel();
 
-        let url_str = url_str.to_owned();
+        let url = url_str.to_owned();
+        let url_parsed = Url::parse(&url)?;
 
         tokio::spawn(async move {
-            let (mut socket, _) = connect_async(url).await.expect("valid connection");
-            loop {
-                tokio::select! {
-                    Ok(()) = &mut stopper_recv => {
-                        println!("Breaking client");
-                        break;
-                    },
-                    Some(msg) = receiver.recv() => {
-                        if let Ok(json) = serde_json::to_string(&msg) {
-                            socket.send(Message::Text(json)).await.unwrap();
-                        }
+            let mut reconnect = true;
+            let mut retries: usize = 0;
+            while reconnect && retries <= 10 {
+                println!("{}: Connect attempt {}", url, retries);
+                retries += 1;
+                let mut socket = if let Ok(x) = connect_async(url_parsed.clone()).await {
+                    x.0
+                } else {
+                    println!("{}: Failed to connect", url);
+                    sleep(Duration::from_secs(1)).await;
+                    continue;
+                };
+
+                let request: Request = Subscribe::default().into();
+                if let Ok(json) = serde_json::to_string(&request) {
+                    if let Err(err) = socket.send(Message::Text(json)).await {
+                        println!("{}: Failed to send request {}", url, err);
+                        continue;
                     }
-                    Some(Ok(msg)) = socket.next() => {
-                        let msg =if let Ok(msg) = msg.into_text() {
-                                msg
-                            } else {
-                                continue;
-                            };
+                }
 
-                        if msg.is_empty() {
-                            continue;
+                loop {
+                    tokio::select! {
+                        Ok(()) = &mut stopper_recv => {
+                            println!("{}: Breaking client due to external signal", url);
+                            reconnect = false;
+                            break;
+                        },
+                        Some(msg) = receiver.recv() => {
+                            if let Ok(json) = serde_json::to_string(&msg) {
+                                if let Err(x) = socket.send(Message::Text(json)).await {
+                                    println!("{}: Reconnecting due to {}", url, x);
+                                    break;
+                                }
+                            }
                         }
+                        Some(Ok(msg)) = socket.next() => {
+                            let msg = if let Ok(msg) = msg.into_text() {
+                                msg
+                            } else {
+                                continue;
+                            };
+
+                            if msg.is_empty() {
+                                continue;
+                            }
 
-                        let msg: Result<Response, _> = serde_json::from_str(&msg);
+                            let msg: Result<Response, _> = serde_json::from_str(&msg);
 
-                        if let Ok(msg) = msg {
-                            if let Err(error) = response_sender.send((url_str.to_owned(), msg)) {
-                                println!("Disconnecting client because of {}", error);
-                                break;
+                            if let Ok(msg) = msg {
+                                if let Err(error) = response_sender.send((url.to_owned(), msg)) {
+                                    println!("{}: Reconnecting client because of {}", url, error);
+                                    break;
+                                }
                             }
                         }
+                        else => {
+                            println!("{}: All select! branches are disabled", url);
+                        }
                     }
                 }
             }
+
+            println!("{}: Disconnected", url);
         });
 
         Ok((response_receiver, stopper_sender))
     }
 
+    pub fn is_running(&self) -> bool {
+        !self.stopper.is_closed()
+    }
+
     pub fn subscribe(&self) -> broadcast::Receiver<(String, Response)> {
         self.receiver.resubscribe()
     }
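
The new `is_running()` works because the spawned task owns the `oneshot::Receiver` half of the stopper: once the task exits for any reason the receiver is dropped, and tokio's `oneshot::Sender::is_closed()` then reports true. A tiny sketch of that mechanism in isolation:

```rust
use tokio::sync::oneshot;

#[tokio::main]
async fn main() {
    let (stopper, stopper_recv) = oneshot::channel::<()>();

    // The task owns the receiver; finishing drops it and closes the channel.
    let task = tokio::spawn(async move {
        let _stopper_recv = stopper_recv;
        // ... connection loop would live here ...
    });

    task.await.unwrap();
    assert!(stopper.is_closed()); // receiver gone, so "not running"
}
```
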
@@ -156,6 +194,24 @@ impl Clients {
         None
     }
 
+    pub fn check_active_connections(&self) -> usize {
+        let mut clients = self.clients.write();
+        let mut subscriptions = self.subscriptions.write();
+        let mut to_remove = vec![];
+        for (url, client) in clients.iter() {
+            if !client.is_running() {
+                to_remove.push(url.to_owned());
+            }
+        }
+
+        for url in to_remove.iter() {
+            clients.remove(url);
+        }
+
+        subscriptions.retain(|s| s.len() > 0);
+        clients.len()
+    }
+
     pub async fn send(&self, request: Request) {
         let senders = self
             .clients

+ 90 - 42
src/main.rs

@@ -4,7 +4,7 @@ use nostr_rs_client::{Clients, Error as ClientError};
 use nostr_rs_types::{
     client::{Close, Subscribe},
     types::{Addr, Content, Event, Filter, Kind, SubscriptionId, Tag},
-    Request, Response,
+    Response,
 };
 use sqlx::{query, FromRow, Pool, Sqlite, SqlitePool};
 use tokio::time::{sleep, Duration};
@@ -16,6 +16,7 @@ struct PublicKeys {
 
 #[derive(Clone, FromRow, Debug)]
 struct Relayer {
+    pub id: i64,
     pub url: String,
 }
 
@@ -45,7 +46,7 @@ async fn discover_relayers(conn: Pool<Sqlite>) -> Result<(), Error> {
     .fetch_all(&mut tx)
     .await?;
 
-    let mut relayers = HashMap::new();
+    let mut relayers: HashMap<String, u64> = HashMap::new();
 
     for value in values {
         let event: Result<Event, _> = serde_json::from_str(&value.event);
@@ -54,7 +55,14 @@ async fn discover_relayers(conn: Pool<Sqlite>) -> Result<(), Error> {
                 match event.content() {
                     Content::Contacts(x) => x
                         .iter()
-                        .map(|(relayer, _)| relayers.insert(relayer.to_ascii_lowercase(), 1))
+                        .map(|(relayer, _)| {
+                            let relayer = relayer.to_ascii_lowercase();
+                            if !relayer.is_empty() {
+                                if relayers.get_mut(&relayer).map(|x| *x += 1).is_none() {
+                                    relayers.insert(relayer, 1);
+                                }
+                            }
+                        })
                         .for_each(drop),
                     _ => {}
                 }
@@ -62,15 +70,21 @@ async fn discover_relayers(conn: Pool<Sqlite>) -> Result<(), Error> {
                     match tag {
                         Tag::PubKey(pubkey) => {
                             if let Some(relayer) = pubkey.relayer_url.as_ref() {
+                                let relayer = relayer.to_ascii_lowercase();
                                 if !relayer.is_empty() {
-                                    relayers.insert(relayer.to_ascii_lowercase(), 1);
+                                    if relayers.get_mut(&relayer).map(|x| *x += 1).is_none() {
+                                        relayers.insert(relayer, 1);
+                                    }
                                 }
                             }
                         }
                         Tag::Event(tag) => {
                             if let Some(relayer) = tag.relayer_url.as_ref() {
+                                let relayer = relayer.to_ascii_lowercase();
                                 if !relayer.is_empty() {
-                                    relayers.insert(relayer.to_ascii_lowercase(), 1);
+                                    if relayers.get_mut(&relayer).map(|x| *x += 1).is_none() {
+                                        relayers.insert(relayer, 1);
+                                    }
                                 }
                             }
                         }
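
The three copies of the get-or-insert counting above can also be written with the standard `HashMap::entry` API; a small self-contained sketch of the same tally (the input URLs here are made up):

```rust
use std::collections::HashMap;

fn main() {
    let seen = ["wss://a.example/", "WSS://A.example/", "wss://b.example/", ""];
    let mut relayers: HashMap<String, u64> = HashMap::new();

    for url in seen {
        let relayer = url.to_ascii_lowercase();
        if !relayer.is_empty() {
            // Inserts 0 for a new key, then increments either way.
            *relayers.entry(relayer).or_insert(0) += 1;
        }
    }

    assert_eq!(relayers["wss://a.example/"], 2);
    assert_eq!(relayers["wss://b.example/"], 1);
}
```
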
@@ -87,11 +101,25 @@ async fn discover_relayers(conn: Pool<Sqlite>) -> Result<(), Error> {
             .await?;
     }
 
-    for relayer in relayers.keys() {
-        let _ = sqlx::query("INSERT INTO relayers (url) values(?)")
-            .bind(relayer)
-            .execute(&mut tx)
-            .await;
+    for (url, weight) in relayers.iter() {
+        let _ = sqlx::query(
+            r#"
+            INSERT INTO relayers (url, weight) values(?, 0)
+            "#,
+        )
+        .bind(url)
+        .execute(&mut tx)
+        .await;
+
+        let _ = sqlx::query(
+            r#"
+            UPDATE relayers SET weight = weight + ? WHERE url = ?
+            "#,
+        )
+        .bind(weight.to_string())
+        .bind(url)
+        .execute(&mut tx)
+        .await;
     }
 
     tx.commit().await?;
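
The weight accumulation above issues an ignored INSERT followed by an UPDATE per relayer. Because the schema later in this commit declares a unique index on `relayers(url)`, the same effect fits into a single SQLite upsert; a sketch under that assumption (the function name and signature are illustrative, not code from the repository):

```rust
use std::collections::HashMap;
use sqlx::{Pool, Sqlite};

// One statement per relayer: insert the row, or add to its weight if it exists.
async fn store_relayer_weights(
    conn: &Pool<Sqlite>,
    relayers: &HashMap<String, u64>,
) -> Result<(), sqlx::Error> {
    for (url, weight) in relayers {
        sqlx::query(
            r#"
            INSERT INTO relayers (url, weight) VALUES (?, ?)
            ON CONFLICT(url) DO UPDATE SET weight = weight + excluded.weight
            "#,
        )
        .bind(url)
        .bind(*weight as i64)
        .execute(conn)
        .await?;
    }
    Ok(())
}
```
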
@@ -143,21 +171,21 @@ async fn process_events(conn: Pool<Sqlite>) -> Result<(), Error> {
     Ok(())
 }
 
-async fn request_profiles_from_db(clients: Clients, conn: Pool<Sqlite>) -> Result<(), Error> {
+async fn request_profiles_from_db(
+    clients: Clients,
+    conn: Pool<Sqlite>,
+    skip: usize,
+) -> Result<usize, Error> {
     let public_keys = sqlx::query_as::<_, PublicKeys>(
         r#"
         SELECT
             distinct public_key
         FROM
             events
-        WHERE public_key NOT IN (
-            SELECT public_key
-            FROM events
-            WHERE kind = 0
-        )
-        LIMIT 50
+        LIMIT ?, 50
         "#,
     )
+    .bind(skip.to_string())
     .fetch_all(&conn)
     .await?
     .iter()
@@ -165,8 +193,9 @@ async fn request_profiles_from_db(clients: Clients, conn: Pool<Sqlite>) -> Resul
     .collect::<Result<Vec<Addr>, _>>()?;
 
     let subscription_id: SubscriptionId = "fetch_profiles".try_into().unwrap();
+    let len = public_keys.len();
 
-    println!("Fetching {} profiles", public_keys.len());
+    println!("Fetching {} profiles (skip = {})", len, skip);
     clients
         .send(
             Subscribe {
@@ -193,7 +222,7 @@ async fn request_profiles_from_db(clients: Clients, conn: Pool<Sqlite>) -> Resul
     let _ = clients.send(Close(subscription_id).into()).await;
     println!("Remove listener");
 
-    Ok(())
+    Ok(len)
 }
 
 #[tokio::main]
@@ -202,11 +231,6 @@ async fn main() {
     let conn = SqlitePool::connect("sqlite://./db.sqlite").await.unwrap();
     let clients = Clients::default();
 
-    clients
-        .connect_to("wss://relay.damus.io")
-        .await
-        .expect("register");
-
     let _ = query(
         r#"
     CREATE TABLE events(
@@ -214,9 +238,14 @@ async fn main() {
         public_key varchar(64) not null,
         kind int,
         event text,
+        discovered_at datetime,
+        discovered_by int,
+        created_at datetime,
         processed INT DEFAULT 0
     );
     CREATE INDEX events_processed_index ON events (processed);
+    CREATE INDEX events_public_key_index ON events (public_key);
+    CREATE INDEX events_kind_index ON events (kind);
     CREATE TABLE relationships (
         id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
         source_id VARCHAR(64) NOT NULL,
@@ -226,9 +255,11 @@ async fn main() {
     CREATE INDEX relationships_source_id_index ON relationships (source_id);
     CREATE TABLE relayers (
         id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
-        url VARCHAR(64) NOT NULL
+        url VARCHAR(64) NOT NULL,
+        weight INT NOT NULL DEFAULT '0'
     );
     CREATE UNIQUE INDEX url ON relayers (url);
+    CREATE INDEX relayers_weight_index ON relayers (weight);
     "#,
     )
     .execute(&conn)
@@ -237,42 +268,53 @@ async fn main() {
     let clients_for_worker = clients.clone();
     let conn_for_worker = conn.clone();
     tokio::spawn(async move {
+        let mut i = 0;
         loop {
             let _ =
-                request_profiles_from_db(clients_for_worker.clone(), conn_for_worker.clone()).await;
+                request_profiles_from_db(clients_for_worker.clone(), conn_for_worker.clone(), i)
+                    .await
+                    .map(|count| i += count);
         }
     });
 
     let conn_for_worker = conn.clone();
     tokio::spawn(async move {
         loop {
-            let _ = process_events(conn_for_worker.clone()).await;
-            let _ = discover_relayers(conn_for_worker.clone()).await;
+            let r1 = process_events(conn_for_worker.clone()).await;
+            let r2 = discover_relayers(conn_for_worker.clone()).await;
+            println!("Processed events {:?} {:?}", r1, r2);
             sleep(Duration::from_millis(5_000)).await;
         }
     });
 
+    let clients_for_worker = clients.clone();
+    tokio::spawn(async move {
+        loop {
+            sleep(Duration::from_millis(5_000)).await;
+            println!(
+                "Active connections: {}",
+                clients_for_worker.check_active_connections()
+            );
+        }
+    });
+
     let relayers = sqlx::query_as::<_, Relayer>(
-        r#"select url from relayers where url like 'wss://%.%/' limit 20"#,
+        r#"select id, url from relayers where url like 'wss://%.%/' order by weight desc limit 30"#,
     )
     .fetch_all(&conn)
     .await
-    .expect("query");
+    .expect("query")
+    .iter()
+    .map(|r| (r.url.clone(), r.id))
+    .collect::<HashMap<String, i64>>();
 
-    for relayer in relayers {
-        let _ = clients.connect_to(&relayer.url).await;
+    for (relayer, _) in relayers.iter() {
+        let _ = clients.connect_to(&relayer).await;
     }
 
-    let request: Request = Subscribe::default().into();
-
-    clients.send(request).await;
-
     loop {
-        if let Some((hostname, msg)) = clients.recv().await {
+        if let Some((relayed_by, msg)) = clients.recv().await {
             match msg {
-                Response::Notice(n) => {
-                    panic!("Error: {}", &*n);
-                }
                 Response::EndOfStoredEvents(x) => {
                     let subscription_id = &*x;
                     if &subscription_id[0..5] == "temp:" {
@@ -292,17 +334,23 @@ async fn main() {
                         //println!("Skip storing: {}", event.id.to_string(),);
                         continue;
                     }
+                    let created_at = event.inner.created_at.to_rfc3339();
                     let _ = query(
-                        r#"INSERT INTO events(id, public_key, kind, event) VALUES(?, ?, ?, ?)"#,
+                        r#"
+                        INSERT INTO events(id, public_key, kind, event, discovered_by, discovered_at, created_at)
+                        VALUES(?, ?, ?, ?, ?, strftime('%Y-%m-%d %H:%M:%S','now'), ?)
+                        "#,
                     )
                     .bind(event.id.to_string())
                     .bind(event.inner.public_key.to_string())
                     .bind(kind.to_string())
                     .bind(serde_json::to_string(&event).unwrap())
+                    .bind(relayers.get(&relayed_by).unwrap_or(&0))
+                    .bind(created_at)
                     .execute(&conn)
                     .await;
 
-                    println!("Stored: {} (from {})", event.id.to_string(), hostname);
+                    //println!("Stored: {} (from {})", event.id.to_string(), relayed_by);
                 }
                 _ => {}
             };