mod.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. use crate::{
  2. amount::AmountCents,
  3. asset::AssetId,
  4. storage::{Error, Storage},
  5. transaction::{from_db, Type},
  6. AccountId, Amount, Asset, AssetManager, Payment, PaymentId, Status, TransactionId,
  7. };
  8. use chrono::NaiveDateTime;
  9. use futures::TryStreamExt;
  10. use sqlx::{sqlite::SqliteRow, Executor, Row};
  11. use std::{collections::HashMap, marker::PhantomData};
  12. mod batch;
  13. pub use batch::Batch;
  14. pub struct Sqlite<'a> {
  15. db: sqlx::SqlitePool,
  16. asset_manager: AssetManager,
  17. _phantom: PhantomData<&'a ()>,
  18. }
  19. impl<'a> Sqlite<'a> {
  20. pub fn new(db: sqlx::SqlitePool, asset_manager: AssetManager) -> Self {
  21. Self {
  22. db,
  23. asset_manager,
  24. _phantom: PhantomData,
  25. }
  26. }
  27. pub async fn setup(&self) -> Result<(), sqlx::Error> {
  28. let mut x = self.db.begin().await?;
  29. x.execute(
  30. r#"CREATE TABLE IF NOT EXISTS "payments" (
  31. "transaction_id" VARCHAR(66) NOT NULL,
  32. "position_id" INTEGER NOT NULL,
  33. "asset_id" TEXT NOT NULL,
  34. "cents" TEXT NOT NULL,
  35. "status" INTEGER NOT NULL,
  36. "to" VARCHAR(71) NOT NULL,
  37. "spent_by" TEXT,
  38. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  39. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  40. PRIMARY KEY ("transaction_id", "position_id")
  41. );
  42. CREATE INDEX IF NOT EXISTS payments_to ON payments ("to", "asset_id", "status", "spent_by");
  43. CREATE TABLE IF NOT EXISTS "transactions" (
  44. "transaction_id" VARCHAR(66) NOT NULL,
  45. "status" INTEGER NOT NULL,
  46. "type" INTEGER NOT NULL,
  47. "reference" TEXT NOT NULL,
  48. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  49. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  50. PRIMARY KEY ("transaction_id")
  51. );
  52. CREATE TABLE IF NOT EXISTS "transaction_accounts" (
  53. "transaction_id" VARCHAR(66) NOT NULL,
  54. "account_id" VARCHAR(71) NOT NULL,
  55. "type" INTEGER NOT NULL,
  56. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  57. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  58. PRIMARY KEY("account_id", "transaction_id")
  59. );
  60. CREATE INDEX "type" ON "transaction_accounts" ("account_id", "type", "created_at");
  61. CREATE TABLE IF NOT EXISTS "transaction_input_payments" (
  62. "transaction_id" VARCHAR(66) NOT NULL,
  63. "payment_transaction_id" VARCHAR(66) NOT NULL,
  64. "payment_position_id" INTEGER NOT NULL,
  65. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  66. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  67. PRIMARY KEY ("transaction_id", "payment_transaction_id", "payment_position_id")
  68. );
  69. "#,
  70. )
  71. .await
  72. .expect("valid");
  73. x.commit().await?;
  74. Ok(())
  75. }
  76. #[inline]
  77. fn sql_row_to_payment(&self, row: SqliteRow) -> Result<Payment, Error> {
  78. let id = PaymentId {
  79. transaction: row
  80. .try_get::<String, usize>(0)
  81. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  82. .as_str()
  83. .try_into()
  84. .map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))?,
  85. position: row
  86. .try_get::<i64, usize>(1)
  87. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  88. .try_into()
  89. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?,
  90. };
  91. let cents = row
  92. .try_get::<String, usize>(3)
  93. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  94. .parse::<i128>()
  95. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  96. Ok(Payment {
  97. id,
  98. amount: self
  99. .asset_manager
  100. .asset(
  101. row.try_get::<String, usize>(2)
  102. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  103. .parse()
  104. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  105. )
  106. .map_err(|e| Error::Storage(e.to_string()))?
  107. .new_amount(cents),
  108. to: row
  109. .try_get::<String, usize>(4)
  110. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?
  111. .as_str()
  112. .try_into()
  113. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?,
  114. status: row
  115. .try_get::<u32, usize>(5)
  116. .map_err(|_| Error::Storage("Invalid `status`".to_string()))?
  117. .try_into()
  118. .map_err(|_| Error::Storage("Invalid status".to_string()))?,
  119. spent_by: row
  120. .try_get::<Option<String>, usize>(6)
  121. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?
  122. .map(|s| s.as_str().try_into())
  123. .transpose()
  124. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?,
  125. })
  126. }
  127. }
  128. #[async_trait::async_trait]
  129. impl<'a> Storage<'a, Batch<'a>> for Sqlite<'a> {
  130. async fn begin(&'a self) -> Result<Batch<'a>, Error> {
  131. self.db
  132. .begin()
  133. .await
  134. .map(|x| Batch::new(x))
  135. .map_err(|x| Error::Storage(x.to_string()))
  136. }
  137. async fn get_payment(&self, id: PaymentId) -> Result<Payment, Error> {
  138. let mut conn = self
  139. .db
  140. .acquire()
  141. .await
  142. .map_err(|e| Error::Storage(e.to_string()))?;
  143. let row = sqlx::query(
  144. r#"
  145. SELECT
  146. "p"."transaction_id",
  147. "p"."position_id",
  148. "p"."asset_id",
  149. "p"."cents",
  150. "p"."to",
  151. "p"."status",
  152. "p"."spent_by"
  153. FROM
  154. "payments" "p"
  155. WHERE
  156. "p"."transaction_id" = ?
  157. AND "p"."position_id" = ?
  158. LIMIT 1
  159. "#,
  160. )
  161. .bind(id.transaction.to_string())
  162. .bind(id.position.to_string())
  163. .fetch_optional(&mut *conn)
  164. .await
  165. .map_err(|e| Error::Storage(e.to_string()))?
  166. .ok_or(Error::NotFound)?;
  167. self.sql_row_to_payment(row)
  168. }
  169. async fn get_balance(&self, account: &AccountId) -> Result<Vec<Amount>, Error> {
  170. let mut conn = self
  171. .db
  172. .acquire()
  173. .await
  174. .map_err(|e| Error::Storage(e.to_string()))?;
  175. let mut result = sqlx::query(
  176. r#"
  177. SELECT
  178. "asset_id",
  179. "cents"
  180. FROM
  181. "payments"
  182. WHERE
  183. "to" = ? AND status = ? AND "spent_by" IS NULL
  184. "#,
  185. )
  186. .bind(account.to_string())
  187. .bind::<u32>(Status::Settled.into())
  188. .fetch(&mut *conn);
  189. let mut balances = HashMap::<Asset, Amount>::new();
  190. while let Some(row) = result
  191. .try_next()
  192. .await
  193. .map_err(|e| Error::Storage(e.to_string()))?
  194. {
  195. let asset = self
  196. .asset_manager
  197. .asset(
  198. row.try_get::<String, usize>(0)
  199. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  200. .parse()
  201. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  202. )
  203. .map_err(|e| Error::Storage(e.to_string()))?;
  204. let cents = row
  205. .try_get::<String, usize>(1)
  206. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  207. .parse::<i128>()
  208. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  209. let new_amount = asset.new_amount(cents);
  210. if let Some(amount) = balances.get_mut(&asset) {
  211. *amount = amount
  212. .checked_add(&new_amount)
  213. .ok_or(Error::Storage("amount overflow".to_owned()))?;
  214. } else {
  215. balances.insert(asset, new_amount);
  216. }
  217. }
  218. Ok(balances.into_iter().map(|(_, v)| v).collect())
  219. }
  220. async fn get_unspent_payments(
  221. &self,
  222. account: &AccountId,
  223. asset: AssetId,
  224. mut target_amount: AmountCents,
  225. ) -> Result<Vec<Payment>, Error> {
  226. let mut conn = self
  227. .db
  228. .acquire()
  229. .await
  230. .map_err(|e| Error::Storage(e.to_string()))?;
  231. let mut result = sqlx::query(
  232. r#"
  233. SELECT
  234. "p"."transaction_id",
  235. "p"."position_id",
  236. "p"."asset_id",
  237. "p"."cents",
  238. "p"."to",
  239. "p"."status",
  240. "p"."spent_by"
  241. FROM
  242. "payments" as "p"
  243. WHERE
  244. "p"."to" = ? AND "p"."asset_id" = ? AND status = ? AND "p"."spent_by" IS NULL
  245. ORDER BY cents ASC
  246. "#,
  247. )
  248. .bind(account.to_string())
  249. .bind(asset.to_string())
  250. .bind::<u32>(Status::Settled.into())
  251. .fetch(&mut *conn);
  252. let mut to_return = vec![];
  253. while let Some(row) = result
  254. .try_next()
  255. .await
  256. .map_err(|e| Error::Storage(e.to_string()))?
  257. {
  258. let row = self.sql_row_to_payment(row)?;
  259. target_amount -= row.amount.cents();
  260. to_return.push(row);
  261. if target_amount <= 0 {
  262. break;
  263. }
  264. }
  265. if target_amount <= 0 {
  266. Ok(to_return)
  267. } else {
  268. Err(Error::NotEnoughUnspentPayments(target_amount))
  269. }
  270. }
  271. async fn get_transaction(
  272. &self,
  273. transaction_id: &TransactionId,
  274. ) -> Result<from_db::Transaction, Error> {
  275. let mut conn = self
  276. .db
  277. .acquire()
  278. .await
  279. .map_err(|e| Error::Storage(e.to_string()))?;
  280. let transaction_row = sqlx::query(
  281. r#"
  282. SELECT
  283. "t"."status",
  284. "t"."type",
  285. "t"."reference",
  286. "t"."created_at",
  287. "t"."updated_at"
  288. FROM
  289. "transactions" "t"
  290. WHERE
  291. "t"."transaction_id" = ?
  292. "#,
  293. )
  294. .bind(transaction_id.to_string())
  295. .fetch_optional(&mut *conn)
  296. .await
  297. .map_err(|e| Error::Storage(e.to_string()))?
  298. .ok_or(Error::NotFound)?;
  299. let mut spend_result = sqlx::query(
  300. r#"
  301. SELECT
  302. "p"."transaction_id",
  303. "p"."position_id",
  304. "p"."asset_id",
  305. "p"."cents",
  306. "p"."to",
  307. "p"."status",
  308. "p"."spent_by"
  309. FROM
  310. "payments" "p"
  311. INNER JOIN
  312. "transaction_input_payments" "tp"
  313. ON (
  314. "tp"."payment_transaction_id" = "p"."transaction_id"
  315. AND "tp"."payment_position_id" = "p"."position_id"
  316. )
  317. WHERE
  318. "tp"."transaction_id" = ?
  319. "#,
  320. )
  321. .bind(transaction_id.to_string())
  322. .fetch(&mut *conn);
  323. let mut spend = vec![];
  324. while let Some(row) = spend_result
  325. .try_next()
  326. .await
  327. .map_err(|e| Error::Storage(e.to_string()))?
  328. {
  329. spend.push(self.sql_row_to_payment(row)?);
  330. }
  331. drop(spend_result);
  332. let mut create_result = sqlx::query(
  333. r#"
  334. SELECT
  335. "p"."transaction_id",
  336. "p"."position_id",
  337. "p"."asset_id",
  338. "p"."cents",
  339. "p"."to",
  340. "p"."status",
  341. "p"."spent_by"
  342. FROM
  343. "payments" "p"
  344. WHERE
  345. "p"."transaction_id" = ?
  346. "#,
  347. )
  348. .bind(transaction_id.to_string())
  349. .fetch(&mut *conn);
  350. let mut create = vec![];
  351. while let Some(row) = create_result
  352. .try_next()
  353. .await
  354. .map_err(|e| Error::Storage(e.to_string()))?
  355. {
  356. create.push(self.sql_row_to_payment(row)?);
  357. }
  358. let status = transaction_row
  359. .try_get::<u32, usize>(0)
  360. .map_err(|_| Error::Storage("Invalid status".to_string()))?
  361. .try_into()
  362. .map_err(|_| Error::Storage("Invalid status".to_string()))?;
  363. let typ = transaction_row
  364. .try_get::<u32, usize>(1)
  365. .map_err(|_| Error::Storage("Invalid type".to_string()))?
  366. .try_into()
  367. .map_err(|_| Error::Storage("Invalid type".to_string()))?;
  368. let reference = transaction_row
  369. .try_get::<String, usize>(2)
  370. .map_err(|_| Error::Storage("Invalid reference".to_string()))?;
  371. let created_at = transaction_row
  372. .try_get::<NaiveDateTime, usize>(3)
  373. .map_err(|e| Error::InvalidDate(e.to_string()))?
  374. .and_utc();
  375. let updated_at = transaction_row
  376. .try_get::<NaiveDateTime, usize>(4)
  377. .map_err(|e| Error::InvalidDate(e.to_string()))?
  378. .and_utc();
  379. Ok(from_db::Transaction {
  380. id: transaction_id.clone(),
  381. spend,
  382. create,
  383. typ,
  384. status,
  385. reference,
  386. created_at,
  387. updated_at,
  388. })
  389. }
  390. async fn get_transactions(
  391. &self,
  392. account: &AccountId,
  393. typ: Option<Type>,
  394. ) -> Result<Vec<from_db::Transaction>, Error> {
  395. let mut conn = self
  396. .db
  397. .acquire()
  398. .await
  399. .map_err(|e| Error::Storage(e.to_string()))?;
  400. let sql = sqlx::query(if typ.is_some() {
  401. r#"SELECT "transaction_id" FROM "transaction_accounts" WHERE "account_id" = ? AND "type" = ?" ORDER BY "created_at" DESC"#
  402. } else {
  403. r#"SELECT "transaction_id" FROM "transaction_accounts" WHERE "account_id" = ? ORDER BY "created_at" DESC"#
  404. }).bind(account.to_string());
  405. let ids = if let Some(typ) = typ {
  406. sql.bind::<u32>(typ.into())
  407. } else {
  408. sql
  409. }
  410. .fetch_all(&mut *conn)
  411. .await
  412. .map_err(|e| Error::Storage(e.to_string()))?
  413. .iter()
  414. .map(|row| {
  415. let id: Result<TransactionId, _> = row
  416. .try_get::<String, usize>(0)
  417. .map_err(|_| Error::Storage("Invalid transaction_id".to_string()))?
  418. .as_str()
  419. .try_into();
  420. id.map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))
  421. })
  422. .collect::<Result<Vec<_>, Error>>()?;
  423. drop(conn);
  424. let mut transactions = vec![];
  425. for id in ids.into_iter() {
  426. transactions.push(self.get_transaction(&id).await?);
  427. }
  428. Ok(transactions)
  429. }
  430. }