database.rs 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. //! Database traits definition
  2. use std::fmt::Debug;
  3. use std::marker::PhantomData;
  4. use std::ops::{Deref, DerefMut};
  5. use cdk_common::database::Error;
  6. use crate::stmt::{query, Column, Statement};
  7. /// Database Executor
  8. ///
  9. /// This trait defines the expectations of a database execution
  10. #[async_trait::async_trait]
  11. pub trait DatabaseExecutor: Debug + Sync + Send {
  12. /// Database driver name
  13. fn name() -> &'static str;
  14. /// Executes a query and returns the affected rows
  15. async fn execute(&self, statement: Statement) -> Result<usize, Error>;
  16. /// Runs the query and returns the first row or None
  17. async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error>;
  18. /// Runs the query and returns the first row or None
  19. async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error>;
  20. /// Fetches the first row and column from a query
  21. async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error>;
  22. /// Batch execution
  23. async fn batch(&self, statement: Statement) -> Result<(), Error>;
  24. }
  25. /// Database transaction trait
  26. #[async_trait::async_trait]
  27. pub trait DatabaseTransaction<DB>
  28. where
  29. DB: DatabaseExecutor,
  30. {
  31. /// Consumes the current transaction committing the changes
  32. async fn commit(conn: &mut DB) -> Result<(), Error>;
  33. /// Begin a transaction
  34. async fn begin(conn: &mut DB) -> Result<(), Error>;
  35. /// Consumes the transaction rolling back all changes
  36. async fn rollback(conn: &mut DB) -> Result<(), Error>;
  37. }
  38. /// Database connection with a transaction
  39. #[derive(Debug)]
  40. pub struct ConnectionWithTransaction<DB, W>
  41. where
  42. DB: DatabaseConnector + 'static,
  43. W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
  44. {
  45. inner: Option<W>,
  46. }
  47. impl<DB, W> ConnectionWithTransaction<DB, W>
  48. where
  49. DB: DatabaseConnector,
  50. W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
  51. {
  52. /// Creates a new transaction
  53. pub async fn new(mut inner: W) -> Result<Self, Error> {
  54. DB::Transaction::begin(inner.deref_mut()).await?;
  55. Ok(Self { inner: Some(inner) })
  56. }
  57. /// Commits the transaction consuming it and releasing the connection back to the pool (or
  58. /// disconnecting)
  59. pub async fn commit(mut self) -> Result<(), Error> {
  60. let mut conn = self
  61. .inner
  62. .take()
  63. .ok_or(Error::Internal("Missing connection".to_owned()))?;
  64. DB::Transaction::commit(&mut conn).await?;
  65. Ok(())
  66. }
  67. /// Rollback the transaction consuming it and releasing the connection back to the pool (or
  68. /// disconnecting)
  69. pub async fn rollback(mut self) -> Result<(), Error> {
  70. let mut conn = self
  71. .inner
  72. .take()
  73. .ok_or(Error::Internal("Missing connection".to_owned()))?;
  74. DB::Transaction::rollback(&mut conn).await?;
  75. Ok(())
  76. }
  77. }
  78. impl<DB, W> Drop for ConnectionWithTransaction<DB, W>
  79. where
  80. DB: DatabaseConnector,
  81. W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
  82. {
  83. fn drop(&mut self) {
  84. if let Some(mut conn) = self.inner.take() {
  85. tokio::spawn(async move {
  86. let _ = DB::Transaction::rollback(conn.deref_mut()).await;
  87. });
  88. }
  89. }
  90. }
  91. #[async_trait::async_trait]
  92. impl<DB, W> DatabaseExecutor for ConnectionWithTransaction<DB, W>
  93. where
  94. DB: DatabaseConnector,
  95. W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
  96. {
  97. fn name() -> &'static str {
  98. "Transaction"
  99. }
  100. /// Executes a query and returns the affected rows
  101. async fn execute(&self, statement: Statement) -> Result<usize, Error> {
  102. self.inner
  103. .as_ref()
  104. .ok_or(Error::Internal("Missing internal connection".to_owned()))?
  105. .execute(statement)
  106. .await
  107. }
  108. /// Runs the query and returns the first row or None
  109. async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
  110. self.inner
  111. .as_ref()
  112. .ok_or(Error::Internal("Missing internal connection".to_owned()))?
  113. .fetch_one(statement)
  114. .await
  115. }
  116. /// Runs the query and returns the first row or None
  117. async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
  118. self.inner
  119. .as_ref()
  120. .ok_or(Error::Internal("Missing internal connection".to_owned()))?
  121. .fetch_all(statement)
  122. .await
  123. }
  124. /// Fetches the first row and column from a query
  125. async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
  126. self.inner
  127. .as_ref()
  128. .ok_or(Error::Internal("Missing internal connection".to_owned()))?
  129. .pluck(statement)
  130. .await
  131. }
  132. /// Batch execution
  133. async fn batch(&self, statement: Statement) -> Result<(), Error> {
  134. self.inner
  135. .as_ref()
  136. .ok_or(Error::Internal("Missing internal connection".to_owned()))?
  137. .batch(statement)
  138. .await
  139. }
  140. }
  141. /// Generic transaction handler for SQLite
  142. pub struct GenericTransactionHandler<W>(PhantomData<W>);
  143. #[async_trait::async_trait]
  144. impl<W> DatabaseTransaction<W> for GenericTransactionHandler<W>
  145. where
  146. W: DatabaseExecutor,
  147. {
  148. /// Consumes the current transaction committing the changes
  149. async fn commit(conn: &mut W) -> Result<(), Error> {
  150. query("COMMIT")?.execute(conn).await?;
  151. Ok(())
  152. }
  153. /// Begin a transaction
  154. async fn begin(conn: &mut W) -> Result<(), Error> {
  155. query("START TRANSACTION")?.execute(conn).await?;
  156. Ok(())
  157. }
  158. /// Consumes the transaction rolling back all changes
  159. async fn rollback(conn: &mut W) -> Result<(), Error> {
  160. query("ROLLBACK")?.execute(conn).await?;
  161. Ok(())
  162. }
  163. }
  164. /// Database connector
  165. #[async_trait::async_trait]
  166. pub trait DatabaseConnector: Debug + DatabaseExecutor + Send + Sync {
  167. /// Database static trait for the database
  168. type Transaction: DatabaseTransaction<Self>
  169. where
  170. Self: Sized;
  171. }