lib.rs 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. use std::fmt::Debug;
  2. use std::sync::atomic::AtomicBool;
  3. use std::sync::{Arc, OnceLock};
  4. use std::time::Duration;
  5. use cdk_common::database::Error;
  6. use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
  7. use cdk_sql_common::mint::SQLMintAuthDatabase;
  8. use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
  9. use cdk_sql_common::stmt::{Column, Statement};
  10. use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
  11. use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
  12. use native_tls::TlsConnector;
  13. use postgres_native_tls::MakeTlsConnector;
  14. use tokio::sync::{Mutex, Notify};
  15. use tokio::time::timeout;
  16. use tokio_postgres::{connect, Client, Error as PgError, NoTls};
  17. mod db;
  18. mod value;
  19. #[derive(Debug)]
  20. pub struct PgConnectionPool;
  21. #[derive(Clone)]
  22. pub enum SslMode {
  23. NoTls(NoTls),
  24. NativeTls(postgres_native_tls::MakeTlsConnector),
  25. }
  26. const SSLMODE_VERIFY_FULL: &str = "sslmode=verify-full";
  27. const SSLMODE_VERIFY_CA: &str = "sslmode=verify-ca";
  28. const SSLMODE_PREFER: &str = "sslmode=prefer";
  29. const SSLMODE_ALLOW: &str = "sslmode=allow";
  30. const SSLMODE_REQUIRE: &str = "sslmode=require";
  31. impl Default for SslMode {
  32. fn default() -> Self {
  33. SslMode::NoTls(NoTls {})
  34. }
  35. }
  36. impl Debug for SslMode {
  37. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  38. let debug_text = match self {
  39. Self::NoTls(_) => "NoTls",
  40. Self::NativeTls(_) => "NativeTls",
  41. };
  42. write!(f, "SslMode::{debug_text}")
  43. }
  44. }
  45. /// Postgres configuration
  46. #[derive(Clone, Debug)]
  47. pub struct PgConfig {
  48. url: String,
  49. tls: SslMode,
  50. }
  51. impl DatabaseConfig for PgConfig {
  52. fn default_timeout(&self) -> Duration {
  53. Duration::from_secs(10)
  54. }
  55. fn max_size(&self) -> usize {
  56. 20
  57. }
  58. }
  59. impl From<&str> for PgConfig {
  60. fn from(conn_str: &str) -> Self {
  61. fn build_tls(accept_invalid_certs: bool, accept_invalid_hostnames: bool) -> SslMode {
  62. let mut builder = TlsConnector::builder();
  63. if accept_invalid_certs {
  64. builder.danger_accept_invalid_certs(true);
  65. }
  66. if accept_invalid_hostnames {
  67. builder.danger_accept_invalid_hostnames(true);
  68. }
  69. match builder.build() {
  70. Ok(connector) => {
  71. let make_tls_connector = MakeTlsConnector::new(connector);
  72. SslMode::NativeTls(make_tls_connector)
  73. }
  74. Err(_) => SslMode::NoTls(NoTls {}),
  75. }
  76. }
  77. let tls = if conn_str.contains(SSLMODE_VERIFY_FULL) {
  78. // Strict TLS: valid certs and hostnames required
  79. build_tls(false, false)
  80. } else if conn_str.contains(SSLMODE_VERIFY_CA) {
  81. // Verify CA, but allow invalid hostnames
  82. build_tls(false, true)
  83. } else if conn_str.contains(SSLMODE_PREFER)
  84. || conn_str.contains(SSLMODE_ALLOW)
  85. || conn_str.contains(SSLMODE_REQUIRE)
  86. {
  87. // Lenient TLS for preferred/allow/require: accept invalid certs and hostnames
  88. build_tls(true, true)
  89. } else {
  90. SslMode::NoTls(NoTls {})
  91. };
  92. PgConfig {
  93. url: conn_str.to_owned(),
  94. tls,
  95. }
  96. }
  97. }
  98. impl DatabasePool for PgConnectionPool {
  99. type Config = PgConfig;
  100. type Connection = PostgresConnection;
  101. type Error = PgError;
  102. fn new_resource(
  103. config: &Self::Config,
  104. still_valid: Arc<AtomicBool>,
  105. timeout: Duration,
  106. ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
  107. Ok(PostgresConnection::new(
  108. config.to_owned(),
  109. timeout,
  110. still_valid,
  111. ))
  112. }
  113. }
  114. /// A postgres connection
  115. #[derive(Debug)]
  116. pub struct PostgresConnection {
  117. timeout: Duration,
  118. error: Arc<Mutex<Option<cdk_common::database::Error>>>,
  119. result: Arc<OnceLock<Client>>,
  120. notify: Arc<Notify>,
  121. }
  122. impl PostgresConnection {
  123. /// Creates a new instance
  124. pub fn new(config: PgConfig, timeout: Duration, still_valid: Arc<AtomicBool>) -> Self {
  125. let failed = Arc::new(Mutex::new(None));
  126. let result = Arc::new(OnceLock::new());
  127. let notify = Arc::new(Notify::new());
  128. let error_clone = failed.clone();
  129. let result_clone = result.clone();
  130. let notify_clone = notify.clone();
  131. tokio::spawn(async move {
  132. match config.tls {
  133. SslMode::NoTls(tls) => {
  134. let (client, connection) = match connect(&config.url, tls).await {
  135. Ok((client, connection)) => (client, connection),
  136. Err(err) => {
  137. *error_clone.lock().await =
  138. Some(cdk_common::database::Error::Database(Box::new(err)));
  139. still_valid.store(false, std::sync::atomic::Ordering::Release);
  140. notify_clone.notify_waiters();
  141. return;
  142. }
  143. };
  144. tokio::spawn(async move {
  145. let _ = connection.await;
  146. still_valid.store(false, std::sync::atomic::Ordering::Release);
  147. });
  148. let _ = result_clone.set(client);
  149. notify_clone.notify_waiters();
  150. }
  151. SslMode::NativeTls(tls) => {
  152. let (client, connection) = match connect(&config.url, tls).await {
  153. Ok((client, connection)) => (client, connection),
  154. Err(err) => {
  155. *error_clone.lock().await =
  156. Some(cdk_common::database::Error::Database(Box::new(err)));
  157. still_valid.store(false, std::sync::atomic::Ordering::Release);
  158. notify_clone.notify_waiters();
  159. return;
  160. }
  161. };
  162. tokio::spawn(async move {
  163. let _ = connection.await;
  164. still_valid.store(false, std::sync::atomic::Ordering::Release);
  165. });
  166. let _ = result_clone.set(client);
  167. notify_clone.notify_waiters();
  168. }
  169. }
  170. });
  171. Self {
  172. error: failed,
  173. timeout,
  174. result,
  175. notify,
  176. }
  177. }
  178. /// Gets the wrapped instance or the connection error. The connection is returned as reference,
  179. /// and the actual error is returned once, next times a generic error would be returned
  180. async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
  181. if let Some(client) = self.result.get() {
  182. return Ok(client);
  183. }
  184. if let Some(error) = self.error.lock().await.take() {
  185. return Err(error);
  186. }
  187. if timeout(self.timeout, self.notify.notified()).await.is_err() {
  188. return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
  189. }
  190. // Check result again
  191. if let Some(client) = self.result.get() {
  192. Ok(client)
  193. } else if let Some(error) = self.error.lock().await.take() {
  194. Err(error)
  195. } else {
  196. Err(cdk_common::database::Error::Internal(
  197. "Failed connection".to_owned(),
  198. ))
  199. }
  200. }
  201. }
  202. #[async_trait::async_trait]
  203. impl DatabaseConnector for PostgresConnection {
  204. type Transaction = GenericTransactionHandler<Self>;
  205. }
  206. #[async_trait::async_trait]
  207. impl DatabaseExecutor for PostgresConnection {
  208. fn name() -> &'static str {
  209. "postgres"
  210. }
  211. async fn execute(&self, statement: Statement) -> Result<usize, Error> {
  212. pg_execute(self.inner().await?, statement).await
  213. }
  214. async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
  215. pg_fetch_one(self.inner().await?, statement).await
  216. }
  217. async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
  218. pg_fetch_all(self.inner().await?, statement).await
  219. }
  220. async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
  221. pg_pluck(self.inner().await?, statement).await
  222. }
  223. async fn batch(&self, statement: Statement) -> Result<(), Error> {
  224. pg_batch(self.inner().await?, statement).await
  225. }
  226. }
  227. /// Mint DB implementation with PostgreSQL
  228. pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
  229. /// Mint Auth database with Postgres
  230. #[cfg(feature = "auth")]
  231. pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
  232. /// Mint DB implementation with PostgresSQL
  233. pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
  234. #[cfg(test)]
  235. mod test {
  236. use cdk_common::mint_db_test;
  237. use once_cell::sync::Lazy;
  238. use tokio::sync::Mutex;
  239. use super::*;
  240. static MIGRATION_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
  241. async fn provide_db() -> MintPgDatabase {
  242. let m = MIGRATION_LOCK.lock().await;
  243. let db_url = std::env::var("CDK_MINTD_DATABASE_URL")
  244. .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
  245. .unwrap_or("host=localhost user=test password=test dbname=testdb port=5433".to_owned());
  246. let db = MintPgDatabase::new(db_url.as_str())
  247. .await
  248. .expect("database");
  249. drop(m);
  250. db
  251. }
  252. mint_db_test!(provide_db);
  253. }