batch.rs 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. use crate::{
  2. changelog::Changelog,
  3. storage::{self, Error},
  4. transaction, AccountId, Payment, PaymentId, Status, Transaction, TransactionId,
  5. };
  6. use serde::{de::DeserializeOwned, Serialize};
  7. use sqlx::{Row, Sqlite, Transaction as SqlxTransaction};
  8. use std::marker::PhantomData;
  9. /// Creates a new Batch for SQLite
  10. ///
  11. /// Batches are a group of updates to the databases, in which all of the are
  12. /// executed or none. Same concept as a rdbms transaction.
  13. pub struct Batch<'a> {
  14. inner: SqlxTransaction<'a, Sqlite>,
  15. x: PhantomData<&'a ()>,
  16. }
  17. impl<'a> Batch<'a> {
  18. /// Creates a new instance
  19. pub fn new(inner: SqlxTransaction<'a, Sqlite>) -> Batch<'a> {
  20. Self {
  21. inner,
  22. x: PhantomData,
  23. }
  24. }
  25. }
  26. #[async_trait::async_trait]
  27. impl<'a> storage::Batch<'a> for Batch<'a> {
  28. async fn rollback(self) -> Result<(), Error> {
  29. self.inner
  30. .rollback()
  31. .await
  32. .map_err(|e| Error::Storage(e.to_string()))
  33. }
  34. async fn commit(self) -> Result<(), Error> {
  35. self.inner
  36. .commit()
  37. .await
  38. .map_err(|e| Error::Storage(e.to_string()))
  39. }
  40. async fn store_changelogs<T: DeserializeOwned + Serialize + Send + Sync>(
  41. &mut self,
  42. changelog: &[Changelog<T>],
  43. ) -> Result<(), Error> {
  44. for change in changelog.iter() {
  45. let change_bytes =
  46. bincode::serialize(&change.change).map_err(|e| Error::Storage(e.to_string()))?;
  47. sqlx::query(
  48. r#"
  49. INSERT INTO "changelog"("id", "previous", "object_id", "change", "created_at")
  50. VALUES(?, ?, ?, ?, ?)
  51. ON CONFLICT("id")
  52. DO NOTHING
  53. "#,
  54. )
  55. .bind(change.id().map_err(|e| Error::Storage(e.to_string()))?)
  56. .bind(&change.previous)
  57. .bind(&change.object_id)
  58. .bind(change_bytes)
  59. .bind(change.updated_at)
  60. .execute(&mut *self.inner)
  61. .await
  62. .map_err(|e| Error::Storage(e.to_string()))?;
  63. }
  64. Ok(())
  65. }
  66. async fn update_payment(
  67. &mut self,
  68. payment_id: &PaymentId,
  69. spent_by: &TransactionId,
  70. spent_status: Status,
  71. ) -> Result<(), Error> {
  72. let settled: u32 = Status::Settled.into();
  73. let spent_by_val = if spent_status.is_rollback() {
  74. None
  75. } else {
  76. Some(spent_by.to_string())
  77. };
  78. let spent_by_status_val: Option<u32> = if spent_status.is_rollback() {
  79. None
  80. } else {
  81. Some(spent_status.into())
  82. };
  83. let result = sqlx::query(
  84. r#"
  85. UPDATE
  86. "payments"
  87. SET
  88. "spent_by" = ?,
  89. "spent_by_status" = ?
  90. WHERE
  91. "transaction_id" = ?
  92. AND "position_id" = ?
  93. AND "status" = ?
  94. AND ("spent_by_status" IS NULL OR "spent_by_status" != ?)
  95. AND ("spent_by" = ? OR "spent_by" IS NULL)
  96. "#,
  97. )
  98. .bind(spent_by_val)
  99. .bind(spent_by_status_val)
  100. .bind(payment_id.transaction.to_string())
  101. .bind(payment_id.position.to_string())
  102. .bind(settled)
  103. .bind(settled)
  104. .bind(spent_by.to_string())
  105. .execute(&mut *self.inner)
  106. .await
  107. .map_err(|e| Error::SpendPayment(e.to_string()))?;
  108. if result.rows_affected() == 1 {
  109. Ok(())
  110. } else {
  111. Err(Error::NoUpdate)
  112. }
  113. }
  114. async fn get_payment_status(
  115. &mut self,
  116. transaction_id: &TransactionId,
  117. ) -> Result<Option<Status>, Error> {
  118. let row = sqlx::query(
  119. r#"
  120. SELECT
  121. "p"."status"
  122. FROM
  123. "payments" "p"
  124. WHERE
  125. "p"."transaction_id" = ?
  126. LIMIT 1
  127. "#,
  128. )
  129. .bind(transaction_id.to_string())
  130. .fetch_optional(&mut *self.inner)
  131. .await
  132. .map_err(|e| Error::Storage(e.to_string()))?;
  133. if let Some(row) = row {
  134. let status = row
  135. .try_get::<u32, usize>(0)
  136. .map_err(|_| Error::Storage("failed to parse status".to_owned()))?;
  137. status
  138. .try_into()
  139. .map(|x| Some(x))
  140. .map_err(|_| Error::Storage("failed to parse status".to_owned()))
  141. } else {
  142. return Ok(None);
  143. }
  144. }
  145. async fn store_new_payment(&mut self, payment: &Payment) -> Result<(), Error> {
  146. sqlx::query(
  147. r#"
  148. INSERT INTO payments("transaction_id", "position_id", "to", "cents", "asset_id", "status", "spent_by_status")
  149. VALUES (?, ?, ?, ?, ?, ?, ?)
  150. ON CONFLICT("transaction_id", "position_id")
  151. DO UPDATE SET "status" = excluded."status"
  152. "#,
  153. )
  154. .bind(payment.id.transaction.to_string())
  155. .bind(payment.id.position.to_string())
  156. .bind(payment.to.to_string())
  157. .bind(payment.amount.cents().to_string())
  158. .bind(payment.amount.asset().id)
  159. .bind::<u32>((&payment.status).into())
  160. .bind::<Option<u32>>(None)
  161. .execute(&mut *self.inner)
  162. .await
  163. .map_err(|e| Error::Storage(e.to_string()))?;
  164. Ok(())
  165. }
  166. async fn store_transaction(&mut self, transaction: &Transaction) -> Result<(), Error> {
  167. sqlx::query(
  168. r#"
  169. INSERT INTO "transactions"("transaction_id", "status", "type", "reference", "last_version", "created_at", "updated_at")
  170. VALUES(?, ?, ?, ?, ?, ?, ?)
  171. ON CONFLICT("transaction_id")
  172. DO UPDATE SET "status" = excluded."status", "updated_at" = excluded."updated_at", "last_version" = excluded."last_version"
  173. "#,
  174. )
  175. .bind(transaction.id().to_string())
  176. .bind::<u32>(transaction.status().into())
  177. .bind::<u32>(transaction.typ().into())
  178. .bind(transaction.reference())
  179. .bind(transaction.last_version())
  180. .bind(transaction.created_at())
  181. .bind(transaction.updated_at())
  182. .execute(&mut *self.inner)
  183. .await
  184. .map_err(|e| Error::Storage(e.to_string()))?;
  185. for payment in transaction.spends().iter() {
  186. sqlx::query(
  187. r#"
  188. INSERT INTO "transaction_input_payments"("transaction_id", "payment_transaction_id", "payment_position_id")
  189. VALUES(?, ?, ?)
  190. ON CONFLICT("transaction_id", "payment_transaction_id", "payment_position_id")
  191. DO NOTHING
  192. "#,
  193. )
  194. .bind(transaction.id().to_string())
  195. .bind(payment.id.transaction.to_string())
  196. .bind(payment.id.position.to_string())
  197. .execute(&mut *self.inner)
  198. .await
  199. .map_err(|e| Error::Storage(e.to_string()))?;
  200. }
  201. Ok(())
  202. }
  203. async fn relate_account_to_transaction(
  204. &mut self,
  205. transaction: &Transaction,
  206. account: &AccountId,
  207. ) -> Result<(), Error> {
  208. sqlx::query(
  209. r#"
  210. INSERT INTO "transaction_accounts"("transaction_id", "account_id", "type", "created_at", "updated_at")
  211. VALUES(?, ?, ?, ?, ?)
  212. ON CONFLICT("transaction_id", "account_id")
  213. DO NOTHING
  214. "#,
  215. )
  216. .bind(transaction.id().to_string())
  217. .bind(account.to_string())
  218. .bind::<u32>(transaction.typ().into())
  219. .bind(transaction.created_at())
  220. .bind(transaction.updated_at())
  221. .execute(&mut *self.inner)
  222. .await
  223. .map_err(|e| Error::Storage(e.to_string()))?;
  224. Ok(())
  225. }
  226. }