
Add support for sqlx with SQLite

Cesar Rodas, 2 months ago
commit b82c260f17

File diff suppressed because it is too large
+ 766 - 36
Cargo.lock


+ 1 - 1
Cargo.toml

@@ -9,7 +9,7 @@ members = [
     "crates/client",
     "crates/relayer",
     "crates/storage/base",
-    "crates/storage/rocksdb", "crates/dump", "crates/storage/memory", "crates/personal-relayer", "crates/subscription-manager",
+    "crates/storage/rocksdb", "crates/dump", "crates/storage/memory", "crates/personal-relayer", "crates/subscription-manager", "crates/storage/sqlite",
 ]
 
 [dependencies]

+ 6 - 1
crates/storage/base/src/storage.rs

@@ -1,3 +1,5 @@
+use std::hash::Hash;
+
 use crate::Error;
 use futures::Stream;
 use nostr_rs_types::types::{Event, Filter};
@@ -24,7 +26,10 @@ pub trait Storage: Send + Sync {
     async fn set_local_event(&self, event: &Event) -> Result<(), Error>;
 
     /// Returns an event by its ID
-    async fn get_event<T: AsRef<[u8]> + Send + Sync>(&self, id: T) -> Result<Option<Event>, Error>;
+    async fn get_event<T: AsRef<[u8]> + Send + Hash + Sync>(
+        &self,
+        id: T,
+    ) -> Result<Option<Event>, Error>;
 
     /// Get events from the database with a given filter
     ///
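
The new `Hash` bound on `get_event` exists because the SQLite backend (below) probes an in-memory `CuckooFilter` with the id before touching the database. A minimal sketch of that interaction, using names from the diff below (illustrative only):

    use cuckoofilter::CuckooFilter;
    use std::hash::{DefaultHasher, Hash};

    fn maybe_stored<T: AsRef<[u8]> + Hash>(
        filter: &CuckooFilter<DefaultHasher>,
        id: T,
    ) -> bool {
        // `CuckooFilter::contains` requires its argument to be `Hash`,
        // which is what forces the extra bound on the trait method.
        filter.contains(&id)
    }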

+ 20 - 0
crates/storage/sqlite/Cargo.toml

@@ -0,0 +1,20 @@
+[package]
+name = "sqlite"
+version = "0.1.0"
+edition = "2021"
+
+[dependencies]
+async-trait = "0.1.81"
+cuckoofilter = "0.5.0"
+futures = "0.3.30"
+hex = "0.4.3"
+nostr-rs-storage-base = { path = "../base" }
+nostr-rs-subscription-manager = { path = "../../subscription-manager" }
+nostr-rs-types = { path = "../../types" }
+serde = "1.0.210"
+serde_json = "1.0.128"
+sqlx = { version = "0.8.2", features = ["runtime-tokio", "sqlite"] }
+tokio = { version = "1.40.0", features = ["full"] }
+
+[dev-dependencies]
+nostr-rs-storage-base = { path = "../base", features = ["test"] }
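
The `runtime-tokio` and `sqlite` features tie sqlx to the Tokio executor and to the (bundled) SQLite driver. A standalone smoke test under exactly these features might look like this (illustrative; not part of the commit):

    use sqlx::sqlite::SqlitePoolOptions;

    #[tokio::main]
    async fn main() -> Result<(), sqlx::Error> {
        // In-memory database; exercises the "sqlite" + "runtime-tokio" features.
        let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
        let one: i64 = sqlx::query_scalar("SELECT 1").fetch_one(&pool).await?;
        assert_eq!(one, 1);
        Ok(())
    }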

+ 484 - 0
crates/storage/sqlite/src/lib.rs

@@ -0,0 +1,484 @@
+use cuckoofilter::CuckooFilter;
+use futures::{Stream, StreamExt};
+use nostr_rs_storage_base::{Error, Storage};
+use nostr_rs_subscription_manager::SortedFilter;
+use nostr_rs_types::types::{Event, Filter};
+use sqlx::{
+    error::ErrorKind,
+    pool::PoolOptions,
+    sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqliteSynchronous},
+    Pool, QueryBuilder, Row, Sqlite,
+};
+use std::{
+    hash::{DefaultHasher, Hash},
+    marker::PhantomData,
+    pin::Pin,
+    str::FromStr,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    task::{Context, Poll},
+    time::Duration,
+};
+use tokio::{
+    sync::{mpsc, RwLock},
+    task::JoinHandle,
+    time::sleep,
+};
+
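+/// SQLite-backed storage split across two connection pools: `event_db` holds
+/// the raw `events` table, while `index_db` holds the index and local-event
+/// tables created in `initialize_databases`.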
+pub struct SQLite {
+    event_db: SqlitePool,
+    index_db: SqlitePool,
+    indexers_running: Arc<AtomicUsize>,
+    filter: Arc<RwLock<Option<CuckooFilter<DefaultHasher>>>>,
+}
+
+impl SQLite {
+    pub async fn new(db_path1: &str, db_path2: &str) -> Result<Self, sqlx::Error> {
+        // Configure options for the first database
+        let connect_options1 = SqliteConnectOptions::from_str(db_path1)?
+            .create_if_missing(true)
+            .journal_mode(SqliteJournalMode::Wal)
+            .synchronous(SqliteSynchronous::Normal)
+            .pragma("temp_store", "MEMORY")
+            .pragma("mmap_size", "30000000000")
+            .pragma("page_size", "32768")
+            .pragma("journal_size_limit", "6144000");
+
+        // Configure options for the second database
+        let connect_options2 = SqliteConnectOptions::from_str(db_path2)?
+            .create_if_missing(true)
+            .journal_mode(SqliteJournalMode::Wal)
+            .synchronous(SqliteSynchronous::Normal)
+            .pragma("temp_store", "MEMORY")
+            .pragma("mmap_size", "30000000000")
+            .pragma("page_size", "32768")
+            .pragma("journal_size_limit", "6144000");
+
+        // Create connection pools for both databases
+        let event_db = PoolOptions::new()
+            .max_connections(5)
+            .connect_with(connect_options1)
+            .await?;
+
+        let index_db = PoolOptions::new()
+            .max_connections(5)
+            .connect_with(connect_options2)
+            .await?;
+
+        Self::initialize_databases(&event_db, &index_db).await?;
+
+        // Return the SQLite struct containing both pools
+        Ok(Self {
+            event_db,
+            index_db,
+            indexers_running: Default::default(),
+            filter: Default::default(),
+        })
+    }
+
+    // Initialize tables in both databases
+    async fn initialize_databases(
+        event_db: &Pool<Sqlite>,
+        index_db: &Pool<Sqlite>,
+    ) -> Result<(), sqlx::Error> {
+        // Create the table `events` in the first database
+        sqlx::query(
+            "CREATE TABLE IF NOT EXISTS events (
+                id VARCHAR(64) PRIMARY KEY,
+                content JSONB NOT NULL,
+                is_valid BOOLEAN NOT NULL DEFAULT 1,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            )",
+        )
+        .execute(event_db)
+        .await?;
+
+        // Create the index-side tables and their indexes in the second database
+        sqlx::query(
+            "CREATE TABLE IF NOT EXISTS events_by_authors (
+                id VARCHAR(64) PRIMARY KEY,
+                author_id VARCHAR(64) NOT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            );
+            CREATE TABLE IF NOT EXISTS local_events (
+                id VARCHAR(64) PRIMARY KEY,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            );
+            CREATE INDEX IF NOT EXISTS sorted_local_events ON local_events (created_at);
+            CREATE TABLE IF NOT EXISTS event_index (
+                id INTEGER PRIMARY KEY AUTOINCREMENT,
+                event_id VARCHAR(64) NOT NULL,
+                author_id VARCHAR(64) NOT NULL,
+                kind INT NOT NULL,
+                tag_name VARCHAR(64) DEFAULT NULL,
+                tag_value VARCHAR(64) DEFAULT NULL,
+                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+            );
+            CREATE INDEX IF NOT EXISTS by_id ON event_index (event_id, created_at DESC);
+            CREATE INDEX IF NOT EXISTS by_author_id ON event_index (author_id, kind, created_at DESC);
+            CREATE INDEX IF NOT EXISTS by_tag ON event_index (tag_name, tag_value, created_at DESC);
+            CREATE INDEX IF NOT EXISTS sorted ON event_index (created_at DESC);
+            ",
+        )
+        .execute(index_db)
+        .await?;
+
+        Ok(())
+    }
+}
+
+pub struct Cursor<'a> {
+    receiver: mpsc::Receiver<Result<Event, Error>>,
+    join_handle: JoinHandle<()>,
+    _phantom: PhantomData<&'a ()>,
+}
+
+impl<'a> Drop for Cursor<'a> {
+    fn drop(&mut self) {
+        self.join_handle.abort();
+    }
+}
+
+impl<'a> Stream for Cursor<'a> {
+    type Item = Result<Event, Error>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = Pin::into_inner(self);
+        this.receiver.poll_recv(cx)
+    }
+}
+
+impl SQLite {
+    fn build_index(indexing: Arc<AtomicUsize>, pool: Pool<Sqlite>, event: Event, is_retry: bool) {
+        if !is_retry {
+            indexing.fetch_add(1, Ordering::Relaxed);
+        }
+
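+        // Index in a background task so callers are not blocked;
+        // `indexers_running` is what `is_flushing()` reports on.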
+        tokio::spawn(async move {
+            let mut indexes = vec![];
+
+            let event_id = event.id.to_string();
+            let author_id = event.author().to_string();
+            let created_at = event.created_at().timestamp();
+            let kind: u32 = event.kind().into();
+
+            let mut index = QueryBuilder::<Sqlite>::new(
+                "INSERT INTO event_index(event_id, author_id, kind, created_at) VALUES(",
+            );
+
+            let mut sep = index.separated(",");
+            sep.push_bind(&event_id)
+                .push_bind(&author_id)
+                .push_bind(kind)
+                .push_bind(created_at);
+
+            index.push(")");
+
+            indexes.push(index);
+
+            for tag in event.tags() {
+                let tag_name = tag.get_identifier();
+                if let Some(tag_value) = tag.get_indexable_value() {
+                    let mut index =
+                        QueryBuilder::<Sqlite>::new("INSERT INTO event_index(event_id, author_id, kind, tag_name, tag_value, created_at) VALUES(");
+                    let mut sep = index.separated(",");
+                    sep.push_bind(&event_id)
+                        .push_bind(&author_id)
+                        .push_bind(kind)
+                        .push_bind(tag_name)
+                        .push_bind(tag_value.to_string())
+                        .push_bind(created_at);
+
+                    index.push(")");
+
+                    indexes.push(index);
+                }
+            }
+
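+            // Keep retrying until the pool hands out a transaction.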
+            let mut tx = loop {
+                if let Ok(tx) = pool.begin().await {
+                    break tx;
+                }
+                sleep(Duration::from_millis(1)).await;
+            };
+
+            for mut index in indexes {
+                if index.build().execute(&mut *tx).await.is_err() {
+                    return Self::build_index(indexing, pool, event, true);
+                }
+            }
+
+            if tx.commit().await.is_err() {
+                return Self::build_index(indexing, pool, event, true);
+            }
+
+            indexing.fetch_sub(1, Ordering::Relaxed);
+        });
+    }
+
+    fn create_cursor(
+        &self,
+        sql_query: String,
+        args: Vec<String>,
+        filter: Option<SortedFilter>,
+    ) -> Cursor<'_> {
+        let event_db = self.event_db.clone();
+        let index_db = self.index_db.clone();
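+        // Bounded channel between the producer task and the `Cursor`; if the
+        // consumer falls 1,000 events behind, `try_send` below fails and the
+        // producer gives up early.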
+        let (sender, receiver) = mpsc::channel(1_000);
+        let join_handle = tokio::spawn(async move {
+            let mut cursor = args
+                .into_iter()
+                .fold(sqlx::query(&sql_query), |sql, arg| sql.bind(arg))
+                .fetch(&index_db);
+
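+            // Stream matching ids from the index database, hydrating each
+            // event from the event database.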
+            while let Some(Ok(row)) = cursor.next().await {
+                let event = sqlx::query("SELECT content FROM events WHERE id = ?")
+                    .bind(row.get::<String, _>(0))
+                    .fetch_one(&event_db)
+                    .await
+                    .map_err(|e| Error::Internal(e.to_string()))
+                    .and_then(|row| {
+                        row.try_get::<serde_json::Value, _>(0)
+                            .map_err(|e| Error::Internal(e.to_string()))
+                    })
+                    .and_then(|json| {
+                        serde_json::from_value::<Event>(json)
+                            .map_err(|e| Error::Internal(e.to_string()))
+                    });
+
+                if let Ok(event) = &event {
+                    if filter
+                        .as_ref()
+                        .map(|f| !f.check_event(&event))
+                        .unwrap_or_default()
+                    {
+                        continue;
+                    }
+                }
+
+                if sender.try_send(event).is_err() {
+                    break;
+                }
+            }
+        });
+
+        Cursor {
+            receiver,
+            join_handle,
+            _phantom: PhantomData,
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl Storage for SQLite {
+    type Cursor<'a> = Cursor<'a>;
+
+    async fn get_local_events(&self, limit: Option<usize>) -> Result<Cursor<'_>, Error> {
+        let (sql_query, args) = if let Some(limit) = limit {
+            (
+                "SELECT * FROM local_events LIMIT ? ORDER BY created_at DESC",
+                vec![limit.to_string()],
+            )
+        } else {
+            (
+                "SELECT * FROM local_events ORDER BY created_at DESC",
+                vec![],
+            )
+        };
+
+        Ok(self.create_cursor(sql_query.to_owned(), args, None))
+    }
+
+    /// In other database implementations, a lot more work is done to select
+    /// the best index with the fewest candidates; in this implementation all
+    /// of that logic is offloaded to the database engine.
+    async fn get_by_filter(&self, filter: Filter) -> Result<Self::Cursor<'_>, Error> {
+        let mut args = vec![];
+        let mut where_stmt = vec![];
+
+        if !filter.authors.is_empty() {
+            where_stmt.push(format!(
+                "author_id IN ({})",
+                "?,".repeat(filter.authors.len()).trim_end_matches(',')
+            ));
+            for arg in filter.authors.iter().map(|x| x.to_string()) {
+                args.push(arg);
+            }
+        }
+        if !filter.ids.is_empty() {
+            where_stmt.push(format!(
+                "event_id IN ({})",
+                "?,".repeat(filter.ids.len()).trim_end_matches(',')
+            ));
+            for arg in filter.ids.iter().map(|x| x.to_string()) {
+                args.push(arg);
+            }
+        }
+        if !filter.tags.is_empty() {
+            let mut where_tag = vec![];
+            for tag in &filter.tags {
+                where_tag.push(format!(
+                    "(tag_name = ? AND tag_value IN ({}))",
+                    "?,".repeat(tag.1.len()).trim_end_matches(',')
+                ));
+
+                // Bind the tag name once, then one bind per value, matching
+                // the `tag_name = ? AND tag_value IN (...)` placeholders.
+                args.push(tag.0.clone());
+                for arg in tag.1.iter().map(|x| x.to_string()) {
+                    args.push(arg);
+                }
+            }
+            where_stmt.push(format!("({})", where_tag.join(" OR ")));
+        }
+        if !filter.kinds.is_empty() {
+            where_stmt.push(format!(
+                "kind IN ({})",
+                "?,".repeat(filter.kinds.len()).trim_end_matches(',')
+            ));
+
+            for arg in filter.kinds.iter().map(|id| {
+                let id: u32 = (*id).into();
+                id.to_string()
+            }) {
+                args.push(arg);
+            }
+        }
+
+        if let Some(since) = &filter.since {
+            where_stmt.push("created_at >= ?".to_string());
+            args.push(since.timestamp().to_string());
+        }
+
+        if let Some(until) = &filter.until {
+            where_stmt.push("created_at <= ?".to_string());
+            args.push(until.timestamp().to_string());
+        }
+
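+        // The filter treats `limit == 0` as "no limit"; cap it so the SQL
+        // below always carries an explicit LIMIT clause.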
+        let limit = if filter.limit == 0 {
+            1024 * 1024
+        } else {
+            filter.limit
+        };
+
+        let sql_query = if where_stmt.is_empty() {
+            format!(
+                "SELECT DISTINCT event_id FROM event_index ORDER BY created_at DESC LIMIT {}",
+                limit
+            )
+        } else {
+            format!(
+                "SELECT DISTINCT event_id FROM event_index WHERE {} ORDER BY created_at DESC LIMIT {}",
+                where_stmt.join(" AND "),
+                limit
+            )
+        };
+
+        Ok(self.create_cursor(sql_query, args, Some(filter.into())))
+    }
+
+    fn is_flushing(&self) -> bool {
+        self.indexers_running.load(Ordering::SeqCst) > 0
+    }
+
+    async fn set_local_event(&self, event: &Event) -> Result<(), Error> {
+        let mut tx = self
+            .index_db
+            .begin()
+            .await
+            .map_err(|e| Error::Internal(e.to_string()))?;
+
+        sqlx::query("INSERT INTO local_events (id) VALUES (?)")
+            .bind(hex::encode(&event.id))
+            .execute(&mut *tx)
+            .await
+            .map_err(|e| Error::Internal(e.to_string()))?;
+
+        tx.commit()
+            .await
+            .map_err(|e| Error::Internal(e.to_string()))?;
+
+        Ok(())
+    }
+
+    async fn get_event<T: AsRef<[u8]> + Send + Sync + Hash>(
+        &self,
+        id: T,
+    ) -> Result<Option<Event>, Error> {
+        if let Some(filter) = self.filter.read().await.as_ref() {
+            if !filter.contains(&id) {
+                return Ok(None);
+            }
+        }
+
+        let row =
+            sqlx::query("SELECT content FROM events WHERE id LIKE ? ORDER BY created_at DESC")
+                .bind(format!("{}%", hex::encode(id)))
+                .fetch_optional(&self.event_db)
+                .await
+                .map_err(|e| Error::Internal(e.to_string()))?;
+
+        let row = if let Some(row) = row {
+            row
+        } else {
+            return Ok(None);
+        };
+
+        let json = row
+            .try_get::<serde_json::Value, _>(0)
+            .map_err(|e| Error::Internal(e.to_string()))?;
+
+        Ok(Some(serde_json::from_value(json)?))
+    }
+
+    async fn store(&self, event: &Event) -> Result<bool, Error> {
+        if let Some(filter) = self.filter.read().await.as_ref() {
+            if !filter.contains(&event.id) {
+                return Ok(false);
+            }
+        }
+
+        if let Err(err) = sqlx::query("INSERT INTO events VALUES(?, ?, ?, ?)")
+            .bind(event.id.to_string())
+            .bind(serde_json::to_string(event)?)
+            .bind(true)
+            .bind(event.created_at().timestamp())
+            .execute(&self.event_db)
+            .await
+        {
+            if let sqlx::Error::Database(err_db) = &err {
+                let err_kind = err_db.kind();
+                if err_kind == ErrorKind::UniqueViolation || err_kind == ErrorKind::CheckViolation {
+                    return Ok(false);
+                }
+            }
+
+            return Err(Error::Internal(err.to_string()));
+        }
+
+        Self::build_index(
+            self.indexers_running.clone(),
+            self.index_db.clone(),
+            event.to_owned(),
+            false,
+        );
+
+        Ok(true)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    async fn new_instance(_path: &str) -> SQLite {
+        SQLite::new(":memory:", ":memory:")
+            .await
+            .expect("valid sqlite")
+    }
+
+    async fn destroy_instance(_path: &str) {}
+
+    nostr_rs_storage_base::storage_test!(SQLite, new_instance, destroy_instance);
+}
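
Taken together, a minimal usage sketch for the new backend. The crate name `sqlite` and `SQLite::new` come from the diff above; the `sqlite://` paths and the all-default `Filter` are assumptions for illustration:

    use futures::StreamExt;
    use nostr_rs_storage_base::Storage;

    #[tokio::main]
    async fn main() -> Result<(), sqlx::Error> {
        // One database for raw events, one for the indexes.
        let db = sqlite::SQLite::new("sqlite://events.db", "sqlite://index.db").await?;

        // Stream events matching a filter; an all-default Filter is assumed here.
        let mut cursor = db
            .get_by_filter(Default::default())
            .await
            .expect("cursor");
        while let Some(event) = cursor.next().await {
            if let Ok(event) = event {
                println!("{}", event.id.to_string());
            }
        }
        Ok(())
    }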

+ 8 - 0
crates/types/src/types/filter.rs

@@ -42,6 +42,14 @@ impl TagValue {
         bytes.extend_from_slice(&value);
         bytes
     }
+
+    /// Convert the value into a string
+    pub fn to_string(&self) -> String {
+        match self {
+            TagValue::Id(id) => id.to_string(),
+            TagValue::String(s) => s.clone(),
+        }
+    }
 }
 
 fn deserialize_tags<'de, D>(deserializer: D) -> Result<HashMap<String, HashSet<TagValue>>, D::Error>
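
One note on this helper: an inherent `to_string` shadows the `ToString::to_string` that a `Display` impl would otherwise provide (clippy's `inherent_to_string` lint), so implementing `Display` would be the more idiomatic shape. Expected behavior, with a made-up value:

    let value = TagValue::String("wss://relay.example.com".to_string());
    assert_eq!(value.to_string(), "wss://relay.example.com");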

Some files were not shown because too many files changed in this diff