| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345 |
- use std::fmt::Debug;
- use std::sync::atomic::AtomicBool;
- use std::sync::{Arc, OnceLock};
- use std::time::Duration;
- use cdk_common::database::Error;
- use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
- use cdk_sql_common::mint::SQLMintAuthDatabase;
- use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
- use cdk_sql_common::stmt::{Column, Statement};
- use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
- use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
- use native_tls::TlsConnector;
- use postgres_native_tls::MakeTlsConnector;
- use tokio::sync::{Mutex, Notify};
- use tokio::time::timeout;
- use tokio_postgres::{connect, Client, Error as PgError, NoTls};
- mod db;
- mod value;
/// Marker type carrying the [`DatabasePool`] implementation for PostgreSQL.
///
/// It holds no data; connections are created through
/// `DatabasePool::new_resource`.
#[derive(Debug)]
pub struct PgConnectionPool;
- #[derive(Clone)]
- pub enum SslMode {
- NoTls(NoTls),
- NativeTls(postgres_native_tls::MakeTlsConnector),
- }
// `sslmode` settings that are looked for (as plain substrings) in the
// connection string when the TLS connector is chosen.

/// Certificate and hostname are both verified.
const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
/// Certificate chain is verified, hostname is not.
const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
/// Lenient TLS setting.
const SSLMODE_PREFER: &str = "sslmode=prefer";
/// Lenient TLS setting.
const SSLMODE_ALLOW: &str = "sslmode=allow";
/// Lenient TLS setting.
const SSLMODE_REQUIRE: &str = "sslmode=require";
- impl Default for SslMode {
- fn default() -> Self {
- SslMode::NoTls(NoTls {})
- }
- }
- impl Debug for SslMode {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- let debug_text = match self {
- Self::NoTls(_) => "NoTls",
- Self::NativeTls(_) => "NativeTls",
- };
- write!(f, "SslMode::{debug_text}")
- }
- }
- /// Postgres configuration
- #[derive(Clone, Debug)]
- pub struct PgConfig {
- url: String,
- schema: Option<String>,
- tls: SslMode,
- }
- impl DatabaseConfig for PgConfig {
- fn default_timeout(&self) -> Duration {
- Duration::from_secs(10)
- }
- fn max_size(&self) -> usize {
- 20
- }
- }
- impl PgConfig {
- /// strip schema from the connection string
- fn strip_schema(input: &str) -> (Option<String>, String) {
- let mut schema: Option<String> = None;
- // Split by whitespace
- let mut parts = Vec::new();
- for token in input.split_whitespace() {
- if let Some(rest) = token.strip_prefix("schema=") {
- schema = Some(rest.to_string());
- } else {
- parts.push(token);
- }
- }
- let cleaned = parts.join(" ");
- (schema, cleaned)
- }
- }
- impl From<&str> for PgConfig {
- fn from(conn_str: &str) -> Self {
- let (schema, conn_str) = Self::strip_schema(conn_str);
- fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
- let mut builder = TlsConnector::builder();
- if accept_invalid_certs {
- builder.danger_accept_invalid_certs(true);
- }
- if accept_invalid_hostnames {
- builder.danger_accept_invalid_hostnames(true);
- }
- match builder.build() {
- Ok(connector) => {
- let make_tls_connector = MakeTlsConnector::new(connector);
- SslMode::NativeTls(make_tls_connector)
- }
- Err(_) => SslMode::NoTls(NoTls {}),
- }
- }
- let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
- // Strict TLS: valid certs and hostnames required
- build_tls(false, false)
- } else if conn_str.contains(SSLMODE_VERIFY_CA) {
- // Verify CA, but allow invalid hostnames
- build_tls(false, true)
- } else if conn_str.contains(SSLMODE_PREFER)
- || conn_str.contains(SSLMODE_ALLOW)
- || conn_str.contains(SSLMODE_REQUIRE)
- {
- // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
- build_tls(true, true)
- } else {
- SslMode::NoTls(NoTls {})
- };
- PgConfig {
- url: conn_str.to_owned(),
- schema,
- tls,
- }
- }
- }
- impl DatabasePool for PgConnectionPool {
- type Config = PgConfig;
- type Connection = PostgresConnection;
- type Error = PgError;
- fn new_resource(
- config: &Self::Config,
- stale: Arc<AtomicBool>,
- timeout: Duration,
- ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
- Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
- }
- }
- /// A postgres connection
- #[derive(Debug)]
- pub struct PostgresConnection {
- timeout: Duration,
- error: Arc<Mutex<Option<cdk_common::database::Error>>>,
- result: Arc<OnceLock<Client>>,
- notify: Arc<Notify>,
- }
- impl PostgresConnection {
- /// Creates a new instance
- pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
- let failed = Arc::new(Mutex::new(None));
- let result = Arc::new(OnceLock::new());
- let notify = Arc::new(Notify::new());
- let error_clone = failed.clone();
- let result_clone = result.clone();
- let notify_clone = notify.clone();
- async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
- conn.batch_execute(&format!(
- r#"
- CREATE SCHEMA IF NOT EXISTS "{schema}";
- SET search_path TO "{schema}"
- "#
- ))
- .await
- .map_err(|e| Error::Database(Box::new(e)))
- }
- tokio::spawn(async move {
- match config.tls {
- SslMode::NoTls(tls) => {
- let (client, connection) = match connect(&config.url, tls).await {
- Ok((client, connection)) => (client, connection),
- Err(err) => {
- *error_clone.lock().await =
- Some(cdk_common::database::Error::Database(Box::new(err)));
- stale.store(false, std::sync::atomic::Ordering::Release);
- notify_clone.notify_waiters();
- return;
- }
- };
- let stale_for_spawn = stale.clone();
- tokio::spawn(async move {
- let _ = connection.await;
- stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
- });
- if let Some(schema) = config.schema.as_ref() {
- if let Err(err) = select_schema(&client, schema).await {
- *error_clone.lock().await = Some(err);
- stale.store(false, std::sync::atomic::Ordering::Release);
- notify_clone.notify_waiters();
- return;
- }
- }
- let _ = result_clone.set(client);
- notify_clone.notify_waiters();
- }
- SslMode::NativeTls(tls) => {
- let (client, connection) = match connect(&config.url, tls).await {
- Ok((client, connection)) => (client, connection),
- Err(err) => {
- *error_clone.lock().await =
- Some(cdk_common::database::Error::Database(Box::new(err)));
- stale.store(false, std::sync::atomic::Ordering::Release);
- notify_clone.notify_waiters();
- return;
- }
- };
- let stale_for_spawn = stale.clone();
- tokio::spawn(async move {
- let _ = connection.await;
- stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
- });
- if let Some(schema) = config.schema.as_ref() {
- if let Err(err) = select_schema(&client, schema).await {
- *error_clone.lock().await = Some(err);
- stale.store(true, std::sync::atomic::Ordering::Release);
- notify_clone.notify_waiters();
- return;
- }
- }
- let _ = result_clone.set(client);
- notify_clone.notify_waiters();
- }
- }
- });
- Self {
- error: failed,
- timeout,
- result,
- notify,
- }
- }
- /// Gets the wrapped instance or the connection error. The connection is returned as reference,
- /// and the actual error is returned once, next times a generic error would be returned
- async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
- if let Some(client) = self.result.get() {
- return Ok(client);
- }
- if let Some(error) = self.error.lock().await.take() {
- return Err(error);
- }
- if timeout(self.timeout, self.notify.notified()).await.is_err() {
- return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
- }
- // Check result again
- if let Some(client) = self.result.get() {
- Ok(client)
- } else if let Some(error) = self.error.lock().await.take() {
- Err(error)
- } else {
- Err(cdk_common::database::Error::Internal(
- "Failed connection".to_owned(),
- ))
- }
- }
- }
- #[async_trait::async_trait]
- impl DatabaseConnector for PostgresConnection {
- type Transaction = GenericTransactionHandler<Self>;
- }
- #[async_trait::async_trait]
- impl DatabaseExecutor for PostgresConnection {
- fn name() -> &'static str {
- "postgres"
- }
- async fn execute(&self, statement: Statement) -> Result<usize, Error> {
- pg_execute(self.inner().await?, statement).await
- }
- async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
- pg_fetch_one(self.inner().await?, statement).await
- }
- async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
- pg_fetch_all(self.inner().await?, statement).await
- }
- async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
- pg_pluck(self.inner().await?, statement).await
- }
- async fn batch(&self, statement: Statement) -> Result<(), Error> {
- pg_batch(self.inner().await?, statement).await
- }
- }
- /// Mint DB implementation with PostgreSQL
- pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
- /// Mint Auth database with Postgres
- #[cfg(feature = "auth")]
- pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
- /// Mint DB implementation with PostgresSQL
- pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
#[cfg(test)]
mod test {
    use cdk_common::mint_db_test;

    use super::*;

    /// Opens a mint database for one test case.
    ///
    /// The connection string comes from `CDK_MINTD_DATABASE_URL`, then
    /// `PG_DB_URL`, then a local default; `test_id` is appended as the
    /// schema name for the connection.
    async fn provide_db(test_id: String) -> MintPgDatabase {
        let base_url = std::env::var("CDK_MINTD_DATABASE_URL")
            .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
            .unwrap_or_else(|_| {
                "host=localhost user=test password=test dbname=testdb port=5433".to_owned()
            });
        let db_url = format!("{base_url} schema={test_id}");

        MintPgDatabase::new(db_url.as_str())
            .await
            .expect("database")
    }

    mint_db_test!(provide_db);
}
|