Browse Source

Fixed bug with foreign keys

[1] https://gist.github.com/crodas/bad00997c63bd5ac58db3c5bd90747ed
Cesar Rodas 1 month ago
parent
commit
960abb63a4

+ 3 - 3
crates/cdk-integration-tests/tests/fake_wallet.rs

@@ -113,7 +113,7 @@ async fn test_fake_melt_payment_fail() -> Result<()> {
     }
 
     let wallet_bal = wallet.total_balance().await?;
-    assert!(wallet_bal == 100.into());
+    assert_eq!(wallet_bal, 100.into());
 
     Ok(())
 }
@@ -256,7 +256,7 @@ async fn test_fake_melt_payment_error_unknown() -> Result<()> {
 
     // The melt should error at the payment invoice command
     let melt = wallet.melt(&melt_quote.id).await;
-    assert!(melt.is_err());
+    assert_eq!(melt.unwrap_err().to_string(), "Payment failed");
 
     let fake_description = FakeInvoiceDescription {
         pay_invoice_state: MeltQuoteState::Unknown,
@@ -271,7 +271,7 @@ async fn test_fake_melt_payment_error_unknown() -> Result<()> {
 
     // The melt should error at the payment invoice command
     let melt = wallet.melt(&melt_quote.id).await;
-    assert!(melt.is_err());
+    assert_eq!(melt.unwrap_err().to_string(), "Payment failed");
 
     let pending = wallet
         .localstore

+ 2 - 2
crates/cdk-integration-tests/tests/regtest.rs

@@ -20,7 +20,7 @@ use cdk_integration_tests::init_regtest::{
     get_mint_url, get_mint_ws_url, LND_RPC_ADDR, LND_TWO_RPC_ADDR,
 };
 use cdk_integration_tests::wait_for_mint_to_be_paid;
-use cdk_sqlite::wallet::memory;
+use cdk_sqlite::wallet::{self, memory};
 use futures::{join, SinkExt, StreamExt};
 use lightning_invoice::Bolt11Invoice;
 use ln_regtest_rs::ln_client::{ClnClient, LightningClient, LndClient};
@@ -453,7 +453,7 @@ async fn test_websocket_connection() -> Result<()> {
     let wallet = Wallet::new(
         &get_mint_url("0"),
         CurrencyUnit::Sat,
-        Arc::new(WalletMemoryDatabase::default()),
+        Arc::new(wallet::memory::empty().await?),
         &Mnemonic::generate(12)?.to_seed_normalized(""),
         None,
     )?;

+ 30 - 69
crates/cdk-sqlite/src/common.rs

@@ -1,77 +1,38 @@
-use std::fs::remove_file;
-use std::ops::Deref;
 use std::str::FromStr;
-use std::sync::atomic::AtomicU64;
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use std::time::Duration;
 
 use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
-use sqlx::{Error, Pool, Sqlite};
-
-static FILE_ID: AtomicU64 = AtomicU64::new(0);
-
-/// A wrapper around a `Pool<Sqlite>` that may delete the database file when dropped in order to by
-/// pass the SQLx bug with pools and in-memory databases.
-///
-/// [1] https://github.com/launchbadge/sqlx/issues/362
-/// [2] https://github.com/launchbadge/sqlx/issues/2510
-#[derive(Debug, Clone)]
-pub struct SqlitePool {
-    pool: Pool<Sqlite>,
-    path: String,
-    delete: bool,
-}
-
-impl Drop for SqlitePool {
-    fn drop(&mut self) {
-        if self.delete {
-            let _ = remove_file(&self.path);
-        }
-    }
-}
-
-impl Deref for SqlitePool {
-    type Target = Pool<Sqlite>;
-
-    fn deref(&self) -> &Self::Target {
-        &self.pool
-    }
-}
+use sqlx::{Error, Executor, Pool, Sqlite};
 
 #[inline(always)]
-pub async fn create_sqlite_pool(path: &str) -> Result<SqlitePool, Error> {
-    let (path, delete) = if path.ends_with(":memory:") {
-        (
-            format!(
-                "in-memory-{}-{}",
-                SystemTime::now()
-                    .duration_since(UNIX_EPOCH)
-                    .unwrap()
-                    .as_nanos(),
-                FILE_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
-            ),
-            true,
-        )
-    } else {
-        (path.to_owned(), false)
-    };
-
-    let db_options = SqliteConnectOptions::from_str(&path)?
-        .journal_mode(if delete {
-            sqlx::sqlite::SqliteJournalMode::Memory
-        } else {
-            sqlx::sqlite::SqliteJournalMode::Wal
-        })
-        .busy_timeout(Duration::from_secs(5))
+pub async fn create_sqlite_pool(path: &str) -> Result<Pool<Sqlite>, Error> {
+    let db_options = SqliteConnectOptions::from_str(path)?
+        .busy_timeout(Duration::from_secs(10))
         .read_only(false)
-        .create_if_missing(true)
-        .auto_vacuum(sqlx::sqlite::SqliteAutoVacuum::Full);
+        .create_if_missing(true);
+
+    let pool = SqlitePoolOptions::new()
+        .min_connections(1)
+        .max_connections(1)
+        .before_acquire(|conn, _meta| {
+            Box::pin(async move {
+                // Info: https://phiresky.github.io/blog/2020/sqlite-performance-tuning/
+                conn.execute(
+                    r#"
+                        PRAGMA busy_timeout = 5000;
+                        PRAGMA journal_mode = wal;
+                        PRAGMA synchronous = normal;
+                        PRAGMA temp_store = memory;
+                        PRAGMA mmap_size = 30000000000;
+                        "#,
+                )
+                .await?;
+
+                Ok(true)
+            })
+        })
+        .connect_with(db_options)
+        .await?;
 
-    Ok(SqlitePool {
-        pool: SqlitePoolOptions::new()
-            .max_connections(1)
-            .connect_with(db_options)
-            .await?,
-        delete,
-        path,
-    })
+    Ok(pool)
 }

+ 4 - 5
crates/cdk-sqlite/src/mint/mod.rs

@@ -1,7 +1,6 @@
 //! SQLite Mint
 
 use std::collections::HashMap;
-use std::ops::Deref;
 use std::path::Path;
 use std::str::FromStr;
 
@@ -21,11 +20,11 @@ use cdk_common::{
 use error::Error;
 use lightning_invoice::Bolt11Invoice;
 use sqlx::sqlite::SqliteRow;
-use sqlx::Row;
+use sqlx::{Pool, Row, Sqlite};
 use uuid::fmt::Hyphenated;
 use uuid::Uuid;
 
-use crate::common::{create_sqlite_pool, SqlitePool};
+use crate::common::create_sqlite_pool;
 
 pub mod error;
 pub mod memory;
@@ -33,7 +32,7 @@ pub mod memory;
 /// Mint SQLite Database
 #[derive(Debug, Clone)]
 pub struct MintSqliteDatabase {
-    pool: SqlitePool,
+    pool: Pool<Sqlite>,
 }
 
 impl MintSqliteDatabase {
@@ -47,7 +46,7 @@ impl MintSqliteDatabase {
     /// Migrate [`MintSqliteDatabase`]
     pub async fn migrate(&self) {
         sqlx::migrate!("./src/mint/migrations")
-            .run(self.pool.deref())
+            .run(&self.pool)
             .await
             .expect("Could not run migrations");
     }

+ 29 - 30
crates/cdk-sqlite/src/wallet/mod.rs

@@ -1,7 +1,6 @@
 //! SQLite Wallet Database
 
 use std::collections::HashMap;
-use std::ops::Deref;
 use std::path::Path;
 use std::str::FromStr;
 
@@ -18,10 +17,10 @@ use cdk_common::{
 };
 use error::Error;
 use sqlx::sqlite::SqliteRow;
-use sqlx::Row;
+use sqlx::{Pool, Row, Sqlite};
 use tracing::instrument;
 
-use crate::common::{create_sqlite_pool, SqlitePool};
+use crate::common::create_sqlite_pool;
 
 pub mod error;
 pub mod memory;
@@ -29,7 +28,7 @@ pub mod memory;
 /// Wallet SQLite Database
 #[derive(Debug, Clone)]
 pub struct WalletSqliteDatabase {
-    pool: SqlitePool,
+    pool: Pool<Sqlite>,
 }
 
 impl WalletSqliteDatabase {
@@ -43,7 +42,7 @@ impl WalletSqliteDatabase {
     /// Migrate [`WalletSqliteDatabase`]
     pub async fn migrate(&self) {
         sqlx::migrate!("./src/wallet/migrations")
-            .run(self.pool.deref())
+            .run(&self.pool)
             .await
             .expect("Could not run migrations");
     }
@@ -58,7 +57,7 @@ impl WalletSqliteDatabase {
         )
         .bind(state.to_string())
         .bind(y.to_bytes().to_vec())
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -155,7 +154,7 @@ ON CONFLICT(mint_url) DO UPDATE SET
         .bind(urls)
         .bind(motd)
         .bind(time.map(|v| v as i64))
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -171,7 +170,7 @@ WHERE mint_url=?
         "#,
         )
         .bind(mint_url.to_string())
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -188,7 +187,7 @@ WHERE mint_url=?;
         "#,
         )
         .bind(mint_url.to_string())
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let rec = match rec {
@@ -210,7 +209,7 @@ SELECT *
 FROM mint
         "#,
         )
-        .fetch_all(self.pool.deref())
+        .fetch_all(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -251,7 +250,7 @@ FROM mint
             sqlx::query(&query)
                 .bind(new_mint_url.to_string())
                 .bind(old_mint_url.to_string())
-                .execute(self.pool.deref())
+                .execute(&self.pool)
                 .await
                 .map_err(Error::from)?;
         }
@@ -282,7 +281,7 @@ FROM mint
             .bind(keyset.unit.to_string())
             .bind(keyset.active)
             .bind(keyset.input_fee_ppk as i64)
-            .execute(self.pool.deref())
+            .execute(&self.pool)
             .await
             .map_err(Error::from)?;
         }
@@ -303,7 +302,7 @@ WHERE mint_url=?
         "#,
         )
         .bind(mint_url.to_string())
-        .fetch_all(self.pool.deref())
+        .fetch_all(&self.pool)
         .await;
 
         let recs = match recs {
@@ -335,7 +334,7 @@ WHERE id=?
         "#,
         )
         .bind(keyset_id.to_string())
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let rec = match rec {
@@ -375,7 +374,7 @@ ON CONFLICT(id) DO UPDATE SET
         .bind(quote.state.to_string())
         .bind(quote.expiry as i64)
         .bind(quote.secret_key.map(|p| p.to_string()))
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -392,7 +391,7 @@ WHERE id=?;
         "#,
         )
         .bind(quote_id)
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let rec = match rec {
@@ -414,7 +413,7 @@ SELECT *
 FROM mint_quote
         "#,
         )
-        .fetch_all(self.pool.deref())
+        .fetch_all(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -435,7 +434,7 @@ WHERE id=?
         "#,
         )
         .bind(quote_id)
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -466,7 +465,7 @@ ON CONFLICT(id) DO UPDATE SET
         .bind(u64::from(quote.fee_reserve) as i64)
         .bind(quote.state.to_string())
         .bind(quote.expiry as i64)
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -483,7 +482,7 @@ WHERE id=?;
         "#,
         )
         .bind(quote_id)
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let rec = match rec {
@@ -506,7 +505,7 @@ WHERE id=?
         "#,
         )
         .bind(quote_id)
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -527,7 +526,7 @@ ON CONFLICT(id) DO UPDATE SET
         )
         .bind(Id::from(&keys).to_string())
         .bind(serde_json::to_string(&keys).map_err(Error::from)?)
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -544,7 +543,7 @@ WHERE id=?;
         "#,
         )
         .bind(keyset_id.to_string())
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let rec = match rec {
@@ -569,7 +568,7 @@ WHERE id=?
         "#,
         )
         .bind(id.to_string())
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;
 
@@ -619,7 +618,7 @@ WHERE id=?
                     .witness
                     .map(|w| serde_json::to_string(&w).unwrap()),
             )
-            .execute(self.pool.deref())
+            .execute(&self.pool)
             .await
             .map_err(Error::from)?;
         }
@@ -633,7 +632,7 @@ WHERE id=?
             "#,
             )
             .bind(y.to_bytes().to_vec())
-            .execute(self.pool.deref())
+            .execute(&self.pool)
             .await
             .map_err(Error::from)?;
         }
@@ -679,7 +678,7 @@ SELECT *
 FROM proof;
         "#,
         )
-        .fetch_all(self.pool.deref())
+        .fetch_all(&self.pool)
         .await;
 
         let recs = match recs {
@@ -749,7 +748,7 @@ WHERE id=?;
         "#,
         )
         .bind(keyset_id.to_string())
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let count = match rec {
@@ -779,7 +778,7 @@ WHERE key=?;
         "#,
         )
         .bind(verifying_key.to_bytes().to_vec())
-        .fetch_one(self.pool.deref())
+        .fetch_one(&self.pool)
         .await;
 
         let count = match rec {
@@ -814,7 +813,7 @@ ON CONFLICT(key) DO UPDATE SET
         )
         .bind(verifying_key.to_bytes().to_vec())
         .bind(last_checked)
-        .execute(self.pool.deref())
+        .execute(&self.pool)
         .await
         .map_err(Error::from)?;