lib.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. use std::fmt::Debug;
  2. use std::sync::atomic::AtomicBool;
  3. use std::sync::{Arc, OnceLock};
  4. use std::time::Duration;
  5. use cdk_common::database::Error;
  6. use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
  7. use cdk_sql_common::mint::SQLMintAuthDatabase;
  8. use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
  9. use cdk_sql_common::stmt::{Column, Statement};
  10. use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
  11. use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
  12. use native_tls::TlsConnector;
  13. use postgres_native_tls::MakeTlsConnector;
  14. use tokio::sync::{Mutex, Notify};
  15. use tokio::time::timeout;
  16. use tokio_postgres::{connect, Client, Error as PgError, NoTls};
  17. mod db;
  18. mod value;
  19. #[derive(Debug)]
  20. pub struct PgConnectionPool;
  21. #[derive(Clone)]
  22. pub enum SslMode {
  23. NoTls(NoTls),
  24. NativeTls(postgres_native_tls::MakeTlsConnector),
  25. }
  26. const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
  27. const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
  28. const SSLMODE_PREFER: &str = "sslmode=prefer";
  29. const SSLMODE_ALLOW: &str = "sslmode=allow";
  30. const SSLMODE_REQUIRE: &str = "sslmode=require";
  31. impl Default for SslMode {
  32. fn default() -> Self {
  33. SslMode::NoTls(NoTls {})
  34. }
  35. }
  36. impl Debug for SslMode {
  37. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  38. let debug_text = match self {
  39. Self::NoTls(_) => "NoTls",
  40. Self::NativeTls(_) => "NativeTls",
  41. };
  42. write!(f, "SslMode::{debug_text}")
  43. }
  44. }
  45. /// Postgres configuration
  46. #[derive(Clone, Debug)]
  47. pub struct PgConfig {
  48. url: String,
  49. schema: Option<String>,
  50. tls: SslMode,
  51. }
  52. impl DatabaseConfig for PgConfig {
  53. fn default_timeout(&self) -> Duration {
  54. Duration::from_secs(10)
  55. }
  56. fn max_size(&self) -> usize {
  57. 20
  58. }
  59. }
  60. impl PgConfig {
  61. /// strip schema from the connection string
  62. fn strip_schema(input: &str) -> (Option<String>, String) {
  63. let mut schema: Option<String> = None;
  64. // Split by whitespace
  65. let mut parts = Vec::new();
  66. for token in input.split_whitespace() {
  67. if let Some(rest) = token.strip_prefix("schema=") {
  68. schema = Some(rest.to_string());
  69. } else {
  70. parts.push(token);
  71. }
  72. }
  73. let cleaned = parts.join(" ");
  74. (schema, cleaned)
  75. }
  76. }
  77. impl From<&str> for PgConfig {
  78. fn from(conn_str: &str) -> Self {
  79. let (schema, conn_str) = Self::strip_schema(conn_str);
  80. fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
  81. let mut builder = TlsConnector::builder();
  82. if accept_invalid_certs {
  83. builder.danger_accept_invalid_certs(true);
  84. }
  85. if accept_invalid_hostnames {
  86. builder.danger_accept_invalid_hostnames(true);
  87. }
  88. match builder.build() {
  89. Ok(connector) => {
  90. let make_tls_connector = MakeTlsConnector::new(connector);
  91. SslMode::NativeTls(make_tls_connector)
  92. }
  93. Err(_) => SslMode::NoTls(NoTls {}),
  94. }
  95. }
  96. let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
  97. // Strict TLS: valid certs and hostnames required
  98. build_tls(false, false)
  99. } else if conn_str.contains(SSLMODE_VERIFY_CA) {
  100. // Verify CA, but allow invalid hostnames
  101. build_tls(false, true)
  102. } else if conn_str.contains(SSLMODE_PREFER)
  103. || conn_str.contains(SSLMODE_ALLOW)
  104. || conn_str.contains(SSLMODE_REQUIRE)
  105. {
  106. // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
  107. build_tls(true, true)
  108. } else {
  109. SslMode::NoTls(NoTls {})
  110. };
  111. PgConfig {
  112. url: conn_str.to_owned(),
  113. schema,
  114. tls,
  115. }
  116. }
  117. }
  118. impl DatabasePool for PgConnectionPool {
  119. type Config = PgConfig;
  120. type Connection = PostgresConnection;
  121. type Error = PgError;
  122. fn new_resource(
  123. config: &Self::Config,
  124. stale: Arc<AtomicBool>,
  125. timeout: Duration,
  126. ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
  127. Ok(PostgresConnection::new(config.to_owned(), timeout, stale))
  128. }
  129. }
  130. /// A postgres connection
  131. #[derive(Debug)]
  132. pub struct PostgresConnection {
  133. timeout: Duration,
  134. error: Arc<Mutex<Option<cdk_common::database::Error>>>,
  135. result: Arc<OnceLock<Client>>,
  136. notify: Arc<Notify>,
  137. }
  138. impl PostgresConnection {
  139. /// Creates a new instance
  140. pub fn new(config: PgConfig, timeout: Duration, stale: Arc<AtomicBool>) -> Self {
  141. let failed = Arc::new(Mutex::new(None));
  142. let result = Arc::new(OnceLock::new());
  143. let notify = Arc::new(Notify::new());
  144. let error_clone = failed.clone();
  145. let result_clone = result.clone();
  146. let notify_clone = notify.clone();
  147. async fn select_schema(conn: &Client, schema: &str) -> Result<(), Error> {
  148. conn.batch_execute(&format!(
  149. r#"
  150. CREATE SCHEMA IF NOT EXISTS "{schema}";
  151. SET search_path TO "{schema}"
  152. "#
  153. ))
  154. .await
  155. .map_err(|e| Error::Database(Box::new(e)))
  156. }
  157. tokio::spawn(async move {
  158. match config.tls {
  159. SslMode::NoTls(tls) => {
  160. let (client, connection) = match connect(&config.url, tls).await {
  161. Ok((client, connection)) => (client, connection),
  162. Err(err) => {
  163. *error_clone.lock().await =
  164. Some(cdk_common::database::Error::Database(Box::new(err)));
  165. stale.store(false, std::sync::atomic::Ordering::Release);
  166. notify_clone.notify_waiters();
  167. return;
  168. }
  169. };
  170. let stale_for_spawn = stale.clone();
  171. tokio::spawn(async move {
  172. let _ = connection.await;
  173. stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
  174. });
  175. if let Some(schema) = config.schema.as_ref() {
  176. if let Err(err) = select_schema(&client, schema).await {
  177. *error_clone.lock().await = Some(err);
  178. stale.store(false, std::sync::atomic::Ordering::Release);
  179. notify_clone.notify_waiters();
  180. return;
  181. }
  182. }
  183. let _ = result_clone.set(client);
  184. notify_clone.notify_waiters();
  185. }
  186. SslMode::NativeTls(tls) => {
  187. let (client, connection) = match connect(&config.url, tls).await {
  188. Ok((client, connection)) => (client, connection),
  189. Err(err) => {
  190. *error_clone.lock().await =
  191. Some(cdk_common::database::Error::Database(Box::new(err)));
  192. stale.store(false, std::sync::atomic::Ordering::Release);
  193. notify_clone.notify_waiters();
  194. return;
  195. }
  196. };
  197. let stale_for_spawn = stale.clone();
  198. tokio::spawn(async move {
  199. let _ = connection.await;
  200. stale_for_spawn.store(true, std::sync::atomic::Ordering::Release);
  201. });
  202. if let Some(schema) = config.schema.as_ref() {
  203. if let Err(err) = select_schema(&client, schema).await {
  204. *error_clone.lock().await = Some(err);
  205. stale.store(true, std::sync::atomic::Ordering::Release);
  206. notify_clone.notify_waiters();
  207. return;
  208. }
  209. }
  210. let _ = result_clone.set(client);
  211. notify_clone.notify_waiters();
  212. }
  213. }
  214. });
  215. Self {
  216. error: failed,
  217. timeout,
  218. result,
  219. notify,
  220. }
  221. }
  222. /// Gets the wrapped instance or the connection error. The connection is returned as reference,
  223. /// and the actual error is returned once, next times a generic error would be returned
  224. async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
  225. if let Some(client) = self.result.get() {
  226. return Ok(client);
  227. }
  228. if let Some(error) = self.error.lock().await.take() {
  229. return Err(error);
  230. }
  231. if timeout(self.timeout, self.notify.notified()).await.is_err() {
  232. return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
  233. }
  234. // Check result again
  235. if let Some(client) = self.result.get() {
  236. Ok(client)
  237. } else if let Some(error) = self.error.lock().await.take() {
  238. Err(error)
  239. } else {
  240. Err(cdk_common::database::Error::Internal(
  241. "Failed connection".to_owned(),
  242. ))
  243. }
  244. }
  245. }
  246. #[async_trait::async_trait]
  247. impl DatabaseConnector for PostgresConnection {
  248. type Transaction = GenericTransactionHandler<Self>;
  249. }
  250. #[async_trait::async_trait]
  251. impl DatabaseExecutor for PostgresConnection {
  252. fn name() -> &'static str {
  253. "postgres"
  254. }
  255. async fn execute(&self, statement: Statement) -> Result<usize, Error> {
  256. pg_execute(self.inner().await?, statement).await
  257. }
  258. async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
  259. pg_fetch_one(self.inner().await?, statement).await
  260. }
  261. async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
  262. pg_fetch_all(self.inner().await?, statement).await
  263. }
  264. async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
  265. pg_pluck(self.inner().await?, statement).await
  266. }
  267. async fn batch(&self, statement: Statement) -> Result<(), Error> {
  268. pg_batch(self.inner().await?, statement).await
  269. }
  270. }
  271. /// Mint DB implementation with PostgreSQL
  272. pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
  273. /// Mint Auth database with Postgres
  274. #[cfg(feature = "auth")]
  275. pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
  276. /// Mint DB implementation with PostgresSQL
  277. pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
  278. #[cfg(test)]
  279. mod test {
  280. use cdk_common::mint_db_test;
  281. use super::*;
  282. async fn provide_db(test_id: String) -> MintPgDatabase {
  283. let db_url = std::env::var("CDK_MINTD_DATABASE_URL")
  284. .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
  285. .unwrap_or("host=localhost user=test password=test dbname=testdb port=5433".to_owned());
  286. let db_url = format!("{db_url} schema={test_id}");
  287. let db = MintPgDatabase::new(db_url.as_str())
  288. .await
  289. .expect("database");
  290. db
  291. }
  292. mint_db_test!(provide_db);
  293. }