Эх сурвалжийг харах

Merge pull request #1002 from asmogo/fix/psql_tls

feat: add TLS support for PostgreSQL connections
C 2 сар өмнө
parent
commit
14473d8051

+ 1 - 0
crates/cdk-postgres/Cargo.toml

@@ -31,4 +31,5 @@ uuid.workspace = true
 tokio-postgres = "0.7.13"
 futures-util = "0.3.31"
 postgres-native-tls = "0.5.1"
+native-tls = "0.2"
 once_cell.workspace = true

+ 44 - 3
crates/cdk-postgres/src/lib.rs

@@ -10,6 +10,8 @@ use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
 use cdk_sql_common::stmt::{Column, Statement};
 use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
 use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
+use native_tls::TlsConnector;
+use postgres_native_tls::MakeTlsConnector;
 use tokio::sync::{Mutex, Notify};
 use tokio::time::timeout;
 use tokio_postgres::{connect, Client, Error as PgError, NoTls};
@@ -25,6 +27,11 @@ pub enum SslMode {
     NoTls(NoTls),
     NativeTls(postgres_native_tls::MakeTlsConnector),
 }
+const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
+const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
+const SSLMODE_PREFER: &str = "sslmode=prefer";
+const SSLMODE_ALLOW: &str = "sslmode=allow";
+const SSLMODE_REQUIRE: &str = "sslmode=require";
 
 impl Default for SslMode {
     fn default() -> Self {
@@ -61,10 +68,44 @@ impl DatabaseConfig for PgConfig {
 }
 
 impl From<&str> for PgConfig {
-    fn from(value: &str) -> Self {
+    fn from(conn_str: &str) -> Self {
+        fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
+            let mut builder = TlsConnector::builder();
+            if accept_invalid_certs {
+                builder.danger_accept_invalid_certs(true);
+            }
+            if accept_invalid_hostnames {
+                builder.danger_accept_invalid_hostnames(true);
+            }
+
+            match builder.build() {
+                Ok(connector) => {
+                    let make_tls_connector = MakeTlsConnector::new(connector);
+                    SslMode::NativeTls(make_tls_connector)
+                }
+                Err(_) => SslMode::NoTls(NoTls {}),
+            }
+        }
+
+        let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
+            // Strict TLS: valid certs and hostnames required
+            build_tls(false, false)
+        } else if conn_str.contains(SSLMODE_VERIFY_CA) {
+            // Verify CA, but allow invalid hostnames
+            build_tls(false, true)
+        } else if conn_str.contains(SSLMODE_PREFER)
+            || conn_str.contains(SSLMODE_ALLOW)
+            || conn_str.contains(SSLMODE_REQUIRE)
+        {
+            // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
+            build_tls(true, true)
+        } else {
+            SslMode::NoTls(NoTls {})
+        };
+
         PgConfig {
-            url: value.to_owned(),
-            tls: Default::default(),
+            url: conn_str.to_owned(),
+            tls,
         }
     }
 }