lib.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. use std::fmt::Debug;
  2. use std::sync::atomic::AtomicBool;
  3. use std::sync::{Arc, OnceLock};
  4. use std::time::Duration;
  5. use cdk_common::database::Error;
  6. use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
  7. use cdk_sql_common::mint::SQLMintAuthDatabase;
  8. use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
  9. use cdk_sql_common::stmt::{Column, Statement};
  10. use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
  11. use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
  12. use futures_util::future::BoxFuture;
  13. use native_tls::TlsConnector;
  14. use postgres_native_tls::MakeTlsConnector;
  15. use tokio::sync::{Mutex, Notify};
  16. use tokio::time::timeout;
  17. use tokio_postgres::{connect, Client, Error as PgError, NoTls};
  18. mod db;
  19. mod value;
  20. #[derive(Debug)]
  21. pub struct PgConnectionPool;
  22. #[derive(Clone)]
  23. pub enum SslMode {
  24. NoTls(NoTls),
  25. NativeTls(postgres_native_tls::MakeTlsConnector),
  26. }
  27. const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
  28. const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
  29. const SSLMODE_PREFER: &str = "sslmode=prefer";
  30. const SSLMODE_ALLOW: &str = "sslmode=allow";
  31. const SSLMODE_REQUIRE: &str = "sslmode=require";
  32. impl Default for SslMode {
  33. fn default() -> Self {
  34. SslMode::NoTls(NoTls {})
  35. }
  36. }
  37. impl Debug for SslMode {
  38. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  39. let debug_text = match self {
  40. Self::NoTls(_) => "NoTls",
  41. Self::NativeTls(_) => "NativeTls",
  42. };
  43. write!(f, "SslMode::{debug_text}")
  44. }
  45. }
  46. /// Postgres configuration
  47. #[derive(Clone, Debug)]
  48. pub struct PgConfig {
  49. url: String,
  50. schema: Option<String>,
  51. tls: SslMode,
  52. }
  53. impl DatabaseConfig for PgConfig {
  54. fn default_timeout(&self) -> Duration {
  55. Duration::from_secs(10)
  56. }
  57. fn max_size(&self) -> usize {
  58. 20
  59. }
  60. }
  61. impl PgConfig {
  62. /// strip schema from the connection string
  63. fn strip_schema(input: &str) -> (Option<String>, String) {
  64. let mut schema: Option<String> = None;
  65. // Split by whitespace
  66. let mut parts = Vec::new();
  67. for token in input.split_whitespace() {
  68. if let Some(rest) = token.strip_prefix("schema=") {
  69. schema = Some(rest.to_string());
  70. } else {
  71. parts.push(token);
  72. }
  73. }
  74. let cleaned = parts.join(" ");
  75. (schema, cleaned)
  76. }
  77. }
  78. impl From<&str> for PgConfig {
  79. fn from(conn_str: &str) -> Self {
  80. let (schema, conn_str) = Self::strip_schema(conn_str);
  81. fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
  82. let mut builder = TlsConnector::builder();
  83. if accept_invalid_certs {
  84. builder.danger_accept_invalid_certs(true);
  85. }
  86. if accept_invalid_hostnames {
  87. builder.danger_accept_invalid_hostnames(true);
  88. }
  89. match builder.build() {
  90. Ok(connector) => {
  91. let make_tls_connector = MakeTlsConnector::new(connector);
  92. SslMode::NativeTls(make_tls_connector)
  93. }
  94. Err(_) => SslMode::NoTls(NoTls {}),
  95. }
  96. }
  97. let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
  98. // Strict TLS: valid certs and hostnames required
  99. build_tls(false, false)
  100. } else if conn_str.contains(SSLMODE_VERIFY_CA) {
  101. // Verify CA, but allow invalid hostnames
  102. build_tls(false, true)
  103. } else if conn_str.contains(SSLMODE_PREFER)
  104. || conn_str.contains(SSLMODE_ALLOW)
  105. || conn_str.contains(SSLMODE_REQUIRE)
  106. {
  107. // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
  108. build_tls(true, true)
  109. } else {
  110. SslMode::NoTls(NoTls {})
  111. };
  112. PgConfig {
  113. url: conn_str.to_owned(),
  114. schema,
  115. tls,
  116. }
  117. }
  118. }
  119. impl DatabasePool for PgConnectionPool {
  120. type Config = PgConfig;
  121. type Connection = PostgresConnection;
  122. type Error = PgError;
  123. fn new_resource(
  124. config: &Self::Config,
  125. still_valid: Arc<AtomicBool>,
  126. timeout: Duration,
  127. ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
  128. Ok(PostgresConnection::new(
  129. config.to_owned(),
  130. timeout,
  131. still_valid,
  132. ))
  133. }
  134. }
  135. /// A postgres connection
  136. #[derive(Debug)]
  137. pub struct PostgresConnection {
  138. timeout: Duration,
  139. error: Arc<Mutex<Option<cdk_common::database::Error>>>,
  140. result: Arc<OnceLock<Client>>,
  141. notify: Arc<Notify>,
  142. }
  143. impl PostgresConnection {
  144. /// Creates a new instance
  145. pub fn new(config: PgConfig, timeout: Duration, still_valid: Arc<AtomicBool>) -> Self {
  146. let failed = Arc::new(Mutex::new(None));
  147. let result = Arc::new(OnceLock::new());
  148. let notify = Arc::new(Notify::new());
  149. let error_clone = failed.clone();
  150. let result_clone = result.clone();
  151. let notify_clone = notify.clone();
  152. fn select_schema<'a>(
  153. conn: &'a Client,
  154. schema: &'a str,
  155. ) -> BoxFuture<'a, Result<(), Error>> {
  156. Box::pin(async move {
  157. conn.batch_execute(&format!(
  158. r#"
  159. CREATE SCHEMA IF NOT EXISTS "{schema}";
  160. SET search_path TO "{schema}"
  161. "#
  162. ))
  163. .await
  164. .map_err(|e| Error::Database(Box::new(e)))
  165. })
  166. }
  167. tokio::spawn(async move {
  168. match config.tls {
  169. SslMode::NoTls(tls) => {
  170. let (client, connection) = match connect(&config.url, tls).await {
  171. Ok((client, connection)) => (client, connection),
  172. Err(err) => {
  173. *error_clone.lock().await =
  174. Some(cdk_common::database::Error::Database(Box::new(err)));
  175. still_valid.store(false, std::sync::atomic::Ordering::Release);
  176. notify_clone.notify_waiters();
  177. return;
  178. }
  179. };
  180. let still_valid_for_spawn = still_valid.clone();
  181. tokio::spawn(async move {
  182. let _ = connection.await;
  183. still_valid_for_spawn.store(false, std::sync::atomic::Ordering::Release);
  184. });
  185. if let Some(schema) = config.schema.as_ref() {
  186. if let Err(err) = select_schema(&client, schema).await {
  187. *error_clone.lock().await = Some(err);
  188. still_valid.store(false, std::sync::atomic::Ordering::Release);
  189. notify_clone.notify_waiters();
  190. return;
  191. }
  192. }
  193. let _ = result_clone.set(client);
  194. notify_clone.notify_waiters();
  195. }
  196. SslMode::NativeTls(tls) => {
  197. let (client, connection) = match connect(&config.url, tls).await {
  198. Ok((client, connection)) => (client, connection),
  199. Err(err) => {
  200. *error_clone.lock().await =
  201. Some(cdk_common::database::Error::Database(Box::new(err)));
  202. still_valid.store(false, std::sync::atomic::Ordering::Release);
  203. notify_clone.notify_waiters();
  204. return;
  205. }
  206. };
  207. let still_valid_for_spawn = still_valid.clone();
  208. tokio::spawn(async move {
  209. let _ = connection.await;
  210. still_valid_for_spawn.store(false, std::sync::atomic::Ordering::Release);
  211. });
  212. if let Some(schema) = config.schema.as_ref() {
  213. if let Err(err) = select_schema(&client, schema).await {
  214. *error_clone.lock().await = Some(err);
  215. still_valid.store(false, std::sync::atomic::Ordering::Release);
  216. notify_clone.notify_waiters();
  217. return;
  218. }
  219. }
  220. let _ = result_clone.set(client);
  221. notify_clone.notify_waiters();
  222. }
  223. }
  224. });
  225. Self {
  226. error: failed,
  227. timeout,
  228. result,
  229. notify,
  230. }
  231. }
  232. /// Gets the wrapped instance or the connection error. The connection is returned as reference,
  233. /// and the actual error is returned once, next times a generic error would be returned
  234. async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
  235. if let Some(client) = self.result.get() {
  236. return Ok(client);
  237. }
  238. if let Some(error) = self.error.lock().await.take() {
  239. return Err(error);
  240. }
  241. if timeout(self.timeout, self.notify.notified()).await.is_err() {
  242. return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
  243. }
  244. // Check result again
  245. if let Some(client) = self.result.get() {
  246. Ok(client)
  247. } else if let Some(error) = self.error.lock().await.take() {
  248. Err(error)
  249. } else {
  250. Err(cdk_common::database::Error::Internal(
  251. "Failed connection".to_owned(),
  252. ))
  253. }
  254. }
  255. }
  256. #[async_trait::async_trait]
  257. impl DatabaseConnector for PostgresConnection {
  258. type Transaction = GenericTransactionHandler<Self>;
  259. }
  260. #[async_trait::async_trait]
  261. impl DatabaseExecutor for PostgresConnection {
  262. fn name() -> &'static str {
  263. "postgres"
  264. }
  265. async fn execute(&self, statement: Statement) -> Result<usize, Error> {
  266. pg_execute(self.inner().await?, statement).await
  267. }
  268. async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
  269. pg_fetch_one(self.inner().await?, statement).await
  270. }
  271. async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
  272. pg_fetch_all(self.inner().await?, statement).await
  273. }
  274. async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
  275. pg_pluck(self.inner().await?, statement).await
  276. }
  277. async fn batch(&self, statement: Statement) -> Result<(), Error> {
  278. pg_batch(self.inner().await?, statement).await
  279. }
  280. }
  281. /// Mint DB implementation with PostgreSQL
  282. pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
  283. /// Mint Auth database with Postgres
  284. #[cfg(feature = "auth")]
  285. pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
  286. /// Mint DB implementation with PostgresSQL
  287. pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
  288. #[cfg(test)]
  289. mod test {
  290. use cdk_common::mint_db_test;
  291. use super::*;
  292. async fn provide_db(test_id: String) -> MintPgDatabase {
  293. let db_url = std::env::var("CDK_MINTD_DATABASE_URL")
  294. .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
  295. .unwrap_or("host=localhost user=test password=test dbname=testdb port=5433".to_owned());
  296. let db_url = format!("{db_url} schema={test_id}");
  297. let db = MintPgDatabase::new(db_url.as_str())
  298. .await
  299. .expect("database");
  300. db
  301. }
  302. mint_db_test!(provide_db);
  303. }