Selaa lähdekoodia

Fix postgres tests

Add concept of schema, so each test is isolated
Cesar Rodas 2 kuukautta sitten
vanhempi
säilyke
faf6859567
2 muutettua tiedostoa jossa 72 lisäystä ja 10 poistoa
  1. 6 1
      crates/cdk-common/src/database/mint/test/mod.rs
  2. 66 9
      crates/cdk-postgres/src/lib.rs

+ 6 - 1
crates/cdk-common/src/database/mint/test/mod.rs

@@ -242,7 +242,12 @@ macro_rules! mint_db_test {
         $(
             #[tokio::test]
             async fn $name() {
-                cdk_common::database::mint::test::$name($make_db_fn().await).await;
+                use std::time::{SystemTime, UNIX_EPOCH};
+                let now = SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .expect("Time went backwards");
+
+                cdk_common::database::mint::test::$name($make_db_fn(format!("test_{}_{}", now.as_nanos(), stringify!($name))).await).await;
             }
         )+
     };

+ 66 - 9
crates/cdk-postgres/src/lib.rs

@@ -10,6 +10,7 @@ use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
 use cdk_sql_common::stmt::{Column, Statement};
 use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
 use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
+use futures_util::future::BoxFuture;
 use native_tls::TlsConnector;
 use postgres_native_tls::MakeTlsConnector;
 use tokio::sync::{Mutex, Notify};
@@ -54,6 +55,7 @@ impl Debug for SslMode {
 #[derive(Clone, Debug)]
 pub struct PgConfig {
     url: String,
+    schema: Option<String>,
     tls: SslMode,
 }
 
@@ -67,8 +69,29 @@ impl DatabaseConfig for PgConfig {
     }
 }
 
+impl PgConfig {
+    /// strip schema from the connection string
+    fn strip_schema(input: &str) -> (Option<String>, String) {
+        let mut schema: Option<String> = None;
+
+        // Split by whitespace
+        let mut parts = Vec::new();
+        for token in input.split_whitespace() {
+            if let Some(rest) = token.strip_prefix("schema=") {
+                schema = Some(rest.to_string());
+            } else {
+                parts.push(token);
+            }
+        }
+
+        let cleaned = parts.join(" ");
+        (schema, cleaned)
+    }
+}
+
 impl From<&str> for PgConfig {
     fn from(conn_str: &str) -> Self {
+        let (schema, conn_str) = Self::strip_schema(conn_str);
         fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
             let mut builder = TlsConnector::builder();
             if accept_invalid_certs {
@@ -105,6 +128,7 @@ impl From<&str> for PgConfig {
 
         PgConfig {
             url: conn_str.to_owned(),
+            schema,
             tls,
         }
     }
@@ -149,6 +173,22 @@ impl PostgresConnection {
         let result_clone = result.clone();
         let notify_clone = notify.clone();
 
+        fn select_schema<'a>(
+            conn: &'a Client,
+            schema: &'a str,
+        ) -> BoxFuture<'a, Result<(), Error>> {
+            Box::pin(async move {
+                conn.batch_execute(&format!(
+                    r#"
+                    CREATE SCHEMA IF NOT EXISTS "{schema}";
+                    SET search_path TO "{schema}"
+                    "#
+                ))
+                .await
+                .map_err(|e| Error::Database(Box::new(e)))
+            })
+        }
+
         tokio::spawn(async move {
             match config.tls {
                 SslMode::NoTls(tls) => {
@@ -163,11 +203,21 @@ impl PostgresConnection {
                         }
                     };
 
+                    let still_valid_for_spawn = still_valid.clone();
                     tokio::spawn(async move {
                         let _ = connection.await;
-                        still_valid.store(false, std::sync::atomic::Ordering::Release);
+                        still_valid_for_spawn.store(false, std::sync::atomic::Ordering::Release);
                     });
 
+                    if let Some(schema) = config.schema.as_ref() {
+                        if let Err(err) = select_schema(&client, schema).await {
+                            *error_clone.lock().await = Some(err);
+                            still_valid.store(false, std::sync::atomic::Ordering::Release);
+                            notify_clone.notify_waiters();
+                            return;
+                        }
+                    }
+
                     let _ = result_clone.set(client);
                     notify_clone.notify_waiters();
                 }
@@ -183,11 +233,21 @@ impl PostgresConnection {
                         }
                     };
 
+                    let still_valid_for_spawn = still_valid.clone();
                     tokio::spawn(async move {
                         let _ = connection.await;
-                        still_valid.store(false, std::sync::atomic::Ordering::Release);
+                        still_valid_for_spawn.store(false, std::sync::atomic::Ordering::Release);
                     });
 
+                    if let Some(schema) = config.schema.as_ref() {
+                        if let Err(err) = select_schema(&client, schema).await {
+                            *error_clone.lock().await = Some(err);
+                            still_valid.store(false, std::sync::atomic::Ordering::Release);
+                            notify_clone.notify_waiters();
+                            return;
+                        }
+                    }
+
                     let _ = result_clone.set(client);
                     notify_clone.notify_waiters();
                 }
@@ -275,22 +335,19 @@ pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
 #[cfg(test)]
 mod test {
     use cdk_common::mint_db_test;
-    use once_cell::sync::Lazy;
-    use tokio::sync::Mutex;
 
     use super::*;
 
-    static MIGRATION_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
-
-    async fn provide_db() -> MintPgDatabase {
-        let m = MIGRATION_LOCK.lock().await;
+    async fn provide_db(test_id: String) -> MintPgDatabase {
         let db_url = std::env::var("CDK_MINTD_DATABASE_URL")
             .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
             .unwrap_or("host=localhost user=test password=test dbname=testdb port=5433".to_owned());
+
+        let db_url = format!("{db_url} schema={test_id}");
+
         let db = MintPgDatabase::new(db_url.as_str())
             .await
             .expect("database");
-        drop(m);
         db
     }