lib.rs 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. use std::fmt::Debug;
  2. use std::sync::atomic::AtomicBool;
  3. use std::sync::{Arc, OnceLock};
  4. use std::time::Duration;
  5. use cdk_common::database::Error;
  6. use cdk_sql_common::database::{DatabaseConnector, DatabaseExecutor, GenericTransactionHandler};
  7. use cdk_sql_common::mint::SQLMintAuthDatabase;
  8. use cdk_sql_common::pool::{DatabaseConfig, DatabasePool};
  9. use cdk_sql_common::stmt::{Column, Statement};
  10. use cdk_sql_common::{SQLMintDatabase, SQLWalletDatabase};
  11. use db::{pg_batch, pg_execute, pg_fetch_all, pg_fetch_one, pg_pluck};
  12. use tokio::sync::{Mutex, Notify};
  13. use tokio::time::timeout;
  14. use tokio_postgres::{connect, Client, Error as PgError, NoTls};
  15. mod db;
  16. mod value;
/// Postgres "pool" type plugged into the generic SQL database layer.
///
/// It implements [`DatabasePool`] by creating one [`PostgresConnection`] each time the pool
/// asks for a new resource (see `new_resource`).
#[derive(Debug)]
pub struct PgConnectionPool;
  19. #[derive(Clone)]
  20. pub enum SslMode {
  21. NoTls(NoTls),
  22. NativeTls(postgres_native_tls::MakeTlsConnector),
  23. }
  24. impl Default for SslMode {
  25. fn default() -> Self {
  26. SslMode::NoTls(NoTls {})
  27. }
  28. }
  29. impl Debug for SslMode {
  30. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  31. let debug_text = match self {
  32. Self::NoTls(_) => "NoTls",
  33. Self::NativeTls(_) => "NativeTls",
  34. };
  35. write!(f, "SslMode::{debug_text}")
  36. }
  37. }
  38. /// Postgres configuration
  39. #[derive(Clone, Debug)]
  40. pub struct PgConfig {
  41. url: String,
  42. tls: SslMode,
  43. }
  44. impl DatabaseConfig for PgConfig {
  45. fn default_timeout(&self) -> Duration {
  46. Duration::from_secs(10)
  47. }
  48. fn max_size(&self) -> usize {
  49. 20
  50. }
  51. }
  52. impl From<&str> for PgConfig {
  53. fn from(value: &str) -> Self {
  54. PgConfig {
  55. url: value.to_owned(),
  56. tls: Default::default(),
  57. }
  58. }
  59. }
  60. impl DatabasePool for PgConnectionPool {
  61. type Config = PgConfig;
  62. type Connection = PostgresConnection;
  63. type Error = PgError;
  64. fn new_resource(
  65. config: &Self::Config,
  66. still_valid: Arc<AtomicBool>,
  67. timeout: Duration,
  68. ) -> Result<Self::Connection, cdk_sql_common::pool::Error<Self::Error>> {
  69. Ok(PostgresConnection::new(
  70. config.to_owned(),
  71. timeout,
  72. still_valid,
  73. ))
  74. }
  75. }
/// A postgres connection.
///
/// The connection is established by a background task started in `PostgresConnection::new`;
/// the `Arc`-wrapped fields below are the state shared with that task.
#[derive(Debug)]
pub struct PostgresConnection {
    /// Maximum time a caller waits for the background connect to finish.
    timeout: Duration,
    /// Error recorded by the background task when connecting failed.
    error: Arc<Mutex<Option<cdk_common::database::Error>>>,
    /// The connected client, filled in by the background task on success.
    result: Arc<OnceLock<Client>>,
    /// Signalled (via `notify_waiters`) by the background task once it has finished.
    notify: Arc<Notify>,
}
  84. impl PostgresConnection {
  85. /// Creates a new instance
  86. pub fn new(config: PgConfig, timeout: Duration, still_valid: Arc<AtomicBool>) -> Self {
  87. let failed = Arc::new(Mutex::new(None));
  88. let result = Arc::new(OnceLock::new());
  89. let notify = Arc::new(Notify::new());
  90. let error_clone = failed.clone();
  91. let result_clone = result.clone();
  92. let notify_clone = notify.clone();
  93. tokio::spawn(async move {
  94. match config.tls {
  95. SslMode::NoTls(tls) => {
  96. let (client, connection) = match connect(&config.url, tls).await {
  97. Ok((client, connection)) => (client, connection),
  98. Err(err) => {
  99. *error_clone.lock().await =
  100. Some(cdk_common::database::Error::Database(Box::new(err)));
  101. still_valid.store(false, std::sync::atomic::Ordering::Release);
  102. notify_clone.notify_waiters();
  103. return;
  104. }
  105. };
  106. tokio::spawn(async move {
  107. let _ = connection.await;
  108. still_valid.store(false, std::sync::atomic::Ordering::Release);
  109. });
  110. let _ = result_clone.set(client);
  111. notify_clone.notify_waiters();
  112. }
  113. SslMode::NativeTls(tls) => {
  114. let (client, connection) = match connect(&config.url, tls).await {
  115. Ok((client, connection)) => (client, connection),
  116. Err(err) => {
  117. *error_clone.lock().await =
  118. Some(cdk_common::database::Error::Database(Box::new(err)));
  119. still_valid.store(false, std::sync::atomic::Ordering::Release);
  120. notify_clone.notify_waiters();
  121. return;
  122. }
  123. };
  124. tokio::spawn(async move {
  125. let _ = connection.await;
  126. still_valid.store(false, std::sync::atomic::Ordering::Release);
  127. });
  128. let _ = result_clone.set(client);
  129. notify_clone.notify_waiters();
  130. }
  131. }
  132. });
  133. Self {
  134. error: failed,
  135. timeout,
  136. result,
  137. notify,
  138. }
  139. }
  140. /// Gets the wrapped instance or the connection error. The connection is returned as reference,
  141. /// and the actual error is returned once, next times a generic error would be returned
  142. async fn inner(&self) -> Result<&Client, cdk_common::database::Error> {
  143. if let Some(client) = self.result.get() {
  144. return Ok(client);
  145. }
  146. if let Some(error) = self.error.lock().await.take() {
  147. return Err(error);
  148. }
  149. if timeout(self.timeout, self.notify.notified()).await.is_err() {
  150. return Err(cdk_common::database::Error::Internal("Timeout".to_owned()));
  151. }
  152. // Check result again
  153. if let Some(client) = self.result.get() {
  154. Ok(client)
  155. } else if let Some(error) = self.error.lock().await.take() {
  156. Err(error)
  157. } else {
  158. Err(cdk_common::database::Error::Internal(
  159. "Failed connection".to_owned(),
  160. ))
  161. }
  162. }
  163. }
  164. #[async_trait::async_trait]
  165. impl DatabaseConnector for PostgresConnection {
  166. type Transaction = GenericTransactionHandler<Self>;
  167. }
  168. #[async_trait::async_trait]
  169. impl DatabaseExecutor for PostgresConnection {
  170. fn name() -> &'static str {
  171. "postgres"
  172. }
  173. async fn execute(&self, statement: Statement) -> Result<usize, Error> {
  174. pg_execute(self.inner().await?, statement).await
  175. }
  176. async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
  177. pg_fetch_one(self.inner().await?, statement).await
  178. }
  179. async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
  180. pg_fetch_all(self.inner().await?, statement).await
  181. }
  182. async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
  183. pg_pluck(self.inner().await?, statement).await
  184. }
  185. async fn batch(&self, statement: Statement) -> Result<(), Error> {
  186. pg_batch(self.inner().await?, statement).await
  187. }
  188. }
  189. /// Mint DB implementation with PostgreSQL
  190. pub type MintPgDatabase = SQLMintDatabase<PgConnectionPool>;
  191. /// Mint Auth database with Postgres
  192. #[cfg(feature = "auth")]
  193. pub type MintPgAuthDatabase = SQLMintAuthDatabase<PgConnectionPool>;
  194. /// Mint DB implementation with PostgresSQL
  195. pub type WalletPgDatabase = SQLWalletDatabase<PgConnectionPool>;
  196. #[cfg(test)]
  197. mod test {
  198. use cdk_common::mint_db_test;
  199. use once_cell::sync::Lazy;
  200. use tokio::sync::Mutex;
  201. use super::*;
  202. static MIGRATION_LOCK: Lazy<Mutex<()>> = Lazy::new(|| Mutex::new(()));
  203. async fn provide_db() -> MintPgDatabase {
  204. let m = MIGRATION_LOCK.lock().await;
  205. let db_url = std::env::var("CDK_MINTD_DATABASE_URL")
  206. .or_else(|_| std::env::var("PG_DB_URL")) // Fallback for compatibility
  207. .unwrap_or("host=localhost user=test password=test dbname=testdb port=5433".to_owned());
  208. let db = MintPgDatabase::new(db_url.as_str())
  209. .await
  210. .expect("database");
  211. drop(m);
  212. db
  213. }
  214. mint_db_test!(provide_db);
  215. }