mod.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. use crate::{
  2. amount::AmountCents, asset::AssetId, storage::Error, AccountId, Amount, Asset, AssetManager,
  3. Payment, PaymentId, Status, Storage, Transaction, TransactionId,
  4. };
  5. use futures::TryStreamExt;
  6. use sqlx::{sqlite::SqliteRow, Executor, Row};
  7. use std::{collections::HashMap, marker::PhantomData};
  8. mod batch;
  9. pub use batch::Batch;
  10. pub struct Sqlite<'a> {
  11. db: sqlx::SqlitePool,
  12. asset_manager: AssetManager,
  13. _phantom: PhantomData<&'a ()>,
  14. }
  15. impl<'a> Sqlite<'a> {
  16. pub fn new(db: sqlx::SqlitePool, asset_manager: AssetManager) -> Self {
  17. Self {
  18. db,
  19. asset_manager,
  20. _phantom: PhantomData,
  21. }
  22. }
  23. pub async fn setup(&self) -> Result<(), sqlx::Error> {
  24. let mut x = self.db.begin().await?;
  25. x.execute(
  26. r#"CREATE TABLE IF NOT EXISTS "payments" (
  27. "transaction_id" VARCHAR(66) NOT NULL,
  28. "position_id" INTEGER NOT NULL,
  29. "asset_id" TEXT NOT NULL,
  30. "cents" TEXT NOT NULL,
  31. "status" INTEGER NOT NULL,
  32. "to" VARCHAR(71) NOT NULL,
  33. "spent_by" TEXT,
  34. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  35. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  36. PRIMARY KEY ("transaction_id", "position_id")
  37. );
  38. CREATE INDEX IF NOT EXISTS payments_to ON payments ("to", "asset_id", "status", "spent_by");
  39. CREATE TABLE IF NOT EXISTS "transactions" (
  40. "transaction_id" VARCHAR(66) NOT NULL,
  41. "status" INTEGER NOT NULL,
  42. "reference" TEXT NOT NULL,
  43. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  44. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  45. PRIMARY KEY ("transaction_id")
  46. );
  47. CREATE TABLE IF NOT EXISTS "transaction_payments" (
  48. "transaction_id" VARCHAR(66) NOT NULL,
  49. "payment_transaction_id" VARCHAR(66) NOT NULL,
  50. "payment_position_id" INTEGER NOT NULL,
  51. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  52. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  53. PRIMARY KEY ("transaction_id", "payment_transaction_id", "payment_position_id")
  54. );
  55. "#,
  56. )
  57. .await
  58. .expect("valid");
  59. x.commit().await?;
  60. Ok(())
  61. }
  62. #[inline]
  63. fn sql_row_to_payment(&self, row: SqliteRow) -> Result<Payment, Error> {
  64. let id = PaymentId {
  65. transaction: row
  66. .try_get::<String, usize>(0)
  67. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  68. .as_str()
  69. .try_into()
  70. .map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))?,
  71. position: row
  72. .try_get::<i64, usize>(1)
  73. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  74. .try_into()
  75. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?,
  76. };
  77. let cents = row
  78. .try_get::<String, usize>(3)
  79. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  80. .parse::<i128>()
  81. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  82. Ok(Payment {
  83. id,
  84. amount: self
  85. .asset_manager
  86. .asset(
  87. row.try_get::<String, usize>(2)
  88. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  89. .parse()
  90. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  91. )
  92. .map_err(|e| Error::Storage(e.to_string()))?
  93. .new_amount(cents),
  94. to: row
  95. .try_get::<String, usize>(4)
  96. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?
  97. .as_str()
  98. .try_into()
  99. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?,
  100. status: row
  101. .try_get::<u32, usize>(5)
  102. .map_err(|_| Error::Storage("Invalid `status`".to_string()))?
  103. .try_into()
  104. .map_err(|_| Error::Storage("Invalid status".to_string()))?,
  105. spent_by: row
  106. .try_get::<Option<String>, usize>(6)
  107. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?
  108. .map(|s| s.as_str().try_into())
  109. .transpose()
  110. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?,
  111. })
  112. }
  113. }
  114. #[async_trait::async_trait]
  115. impl<'a> Storage<'a, Batch<'a>> for Sqlite<'a> {
  116. async fn begin(&'a self) -> Result<Batch<'a>, Error> {
  117. self.db
  118. .begin()
  119. .await
  120. .map(|x| Batch::new(x))
  121. .map_err(|x| Error::Storage(x.to_string()))
  122. }
  123. async fn get_payment(&self, id: PaymentId) -> Result<Payment, Error> {
  124. let mut conn = self
  125. .db
  126. .acquire()
  127. .await
  128. .map_err(|e| Error::Storage(e.to_string()))?;
  129. let row = sqlx::query(
  130. r#"
  131. SELECT
  132. "p"."transaction_id",
  133. "p"."position_id",
  134. "p"."asset_id",
  135. "p"."cents",
  136. "p"."to",
  137. "p"."status",
  138. "p"."spent_by"
  139. FROM
  140. "payments" "p"
  141. WHERE
  142. "p"."transaction_id" = ?
  143. AND "p"."position_id" = ?
  144. LIMIT 1
  145. "#,
  146. )
  147. .bind(id.transaction.to_string())
  148. .bind(id.position.to_string())
  149. .fetch_optional(&mut *conn)
  150. .await
  151. .map_err(|e| Error::Storage(e.to_string()))?
  152. .ok_or(Error::NotFound)?;
  153. self.sql_row_to_payment(row)
  154. }
  155. async fn get_balance(&self, account: &AccountId) -> Result<Vec<Amount>, Error> {
  156. let mut conn = self
  157. .db
  158. .acquire()
  159. .await
  160. .map_err(|e| Error::Storage(e.to_string()))?;
  161. let mut result = sqlx::query(
  162. r#"
  163. SELECT
  164. "asset_id",
  165. "cents"
  166. FROM
  167. "payments"
  168. WHERE
  169. "to" = ? AND status = ? AND "spent_by" IS NULL
  170. "#,
  171. )
  172. .bind(account.to_string())
  173. .bind::<u32>(Status::Settled.into())
  174. .fetch(&mut *conn);
  175. let mut balances = HashMap::<Asset, Amount>::new();
  176. while let Some(row) = result
  177. .try_next()
  178. .await
  179. .map_err(|e| Error::Storage(e.to_string()))?
  180. {
  181. let asset = self
  182. .asset_manager
  183. .asset(
  184. row.try_get::<String, usize>(0)
  185. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  186. .parse()
  187. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  188. )
  189. .map_err(|e| Error::Storage(e.to_string()))?;
  190. let cents = row
  191. .try_get::<String, usize>(1)
  192. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  193. .parse::<i128>()
  194. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  195. let new_amount = asset.new_amount(cents);
  196. if let Some(amount) = balances.get_mut(&asset) {
  197. *amount = amount
  198. .checked_add(&new_amount)
  199. .ok_or(Error::Storage("amount overflow".to_owned()))?;
  200. } else {
  201. balances.insert(asset, new_amount);
  202. }
  203. }
  204. Ok(balances.into_iter().map(|(_, v)| v).collect())
  205. }
  206. async fn get_unspent_payments(
  207. &self,
  208. account: &AccountId,
  209. asset: AssetId,
  210. mut target_amount: AmountCents,
  211. ) -> Result<Vec<Payment>, Error> {
  212. let mut conn = self
  213. .db
  214. .acquire()
  215. .await
  216. .map_err(|e| Error::Storage(e.to_string()))?;
  217. let mut result = sqlx::query(
  218. r#"
  219. SELECT
  220. "p"."transaction_id",
  221. "p"."position_id",
  222. "p"."asset_id",
  223. "p"."cents",
  224. "p"."to",
  225. "p"."status",
  226. "p"."spent_by"
  227. FROM
  228. "payments" as "p"
  229. WHERE
  230. "p"."to" = ? AND "p"."asset_id" = ? AND status = ? AND "p"."spent_by" IS NULL
  231. ORDER BY cents ASC
  232. "#,
  233. )
  234. .bind(account.to_string())
  235. .bind(asset.to_string())
  236. .bind::<u32>(Status::Settled.into())
  237. .fetch(&mut *conn);
  238. let mut to_return = vec![];
  239. while let Some(row) = result
  240. .try_next()
  241. .await
  242. .map_err(|e| Error::Storage(e.to_string()))?
  243. {
  244. let row = self.sql_row_to_payment(row)?;
  245. target_amount -= row.amount.cents();
  246. to_return.push(row);
  247. if target_amount <= 0 {
  248. break;
  249. }
  250. }
  251. if target_amount <= 0 {
  252. Ok(to_return)
  253. } else {
  254. Err(Error::NotEnoughUnspentPayments(target_amount))
  255. }
  256. }
  257. async fn get_transaction(&self, transaction_id: &TransactionId) -> Result<Transaction, Error> {
  258. let mut conn = self
  259. .db
  260. .acquire()
  261. .await
  262. .map_err(|e| Error::Storage(e.to_string()))?;
  263. let transaction_row = sqlx::query(
  264. r#"
  265. SELECT
  266. "t"."status",
  267. "t"."reference"
  268. FROM
  269. "transactions" "t"
  270. WHERE
  271. "t"."transaction_id" = ?
  272. "#,
  273. )
  274. .bind(transaction_id.to_string())
  275. .fetch_optional(&mut *conn)
  276. .await
  277. .map_err(|e| Error::Storage(e.to_string()))?
  278. .ok_or(Error::NotFound)?;
  279. let mut spend_result = sqlx::query(
  280. r#"
  281. SELECT
  282. "p"."transaction_id",
  283. "p"."position_id",
  284. "p"."asset_id",
  285. "p"."cents",
  286. "p"."to",
  287. "p"."status",
  288. "p"."spent_by"
  289. FROM
  290. "payments" "p"
  291. INNER JOIN
  292. "transaction_payments" "tp"
  293. ON (
  294. "tp"."payment_transaction_id" = "p"."transaction_id"
  295. AND "tp"."payment_position_id" = "p"."position_id"
  296. )
  297. WHERE
  298. "tp"."transaction_id" = ?
  299. "#,
  300. )
  301. .bind(transaction_id.to_string())
  302. .fetch(&mut *conn);
  303. let mut spend = vec![];
  304. while let Some(row) = spend_result
  305. .try_next()
  306. .await
  307. .map_err(|e| Error::Storage(e.to_string()))?
  308. {
  309. spend.push(self.sql_row_to_payment(row)?);
  310. }
  311. drop(spend_result);
  312. let mut create_result = sqlx::query(
  313. r#"
  314. SELECT
  315. "p"."transaction_id",
  316. "p"."position_id",
  317. "p"."asset_id",
  318. "p"."cents",
  319. "p"."to",
  320. "p"."status",
  321. "p"."spent_by"
  322. FROM
  323. "payments" "p"
  324. WHERE
  325. "p"."transaction_id" = ?
  326. "#,
  327. )
  328. .bind(transaction_id.to_string())
  329. .fetch(&mut *conn);
  330. let mut create = vec![];
  331. while let Some(row) = create_result
  332. .try_next()
  333. .await
  334. .map_err(|e| Error::Storage(e.to_string()))?
  335. {
  336. create.push(self.sql_row_to_payment(row)?);
  337. }
  338. let status = transaction_row
  339. .try_get::<u32, usize>(0)
  340. .map_err(|_| Error::Storage("Invalid status".to_string()))?
  341. .try_into()
  342. .map_err(|_| Error::Storage("Invalid status".to_string()))?;
  343. let reference = transaction_row
  344. .try_get::<String, usize>(1)
  345. .map_err(|_| Error::Storage("Invalid reference".to_string()))?;
  346. Ok(Transaction {
  347. id: transaction_id.clone(),
  348. is_external_deposit: spend.is_empty(),
  349. spend,
  350. create,
  351. status,
  352. reference,
  353. })
  354. }
  355. }