mod.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. use crate::{
  2. amount::AmountCents,
  3. asset::AssetId,
  4. storage::{Error, Storage},
  5. transaction::{from_db, Type},
  6. AccountId, Amount, Asset, AssetManager, Payment, PaymentId, Status, TransactionId,
  7. };
  8. use chrono::NaiveDateTime;
  9. use futures::TryStreamExt;
  10. use sqlx::{sqlite::SqliteRow, Executor, Row};
  11. use std::{collections::HashMap, marker::PhantomData};
  12. mod batch;
  13. pub use batch::Batch;
  14. pub use sqlx::sqlite::SqlitePoolOptions;
  15. pub struct Sqlite<'a> {
  16. db: sqlx::SqlitePool,
  17. asset_manager: AssetManager,
  18. _phantom: PhantomData<&'a ()>,
  19. }
  20. impl<'a> Sqlite<'a> {
  21. pub fn new(db: sqlx::SqlitePool, asset_manager: AssetManager) -> Self {
  22. Self {
  23. db,
  24. asset_manager,
  25. _phantom: PhantomData,
  26. }
  27. }
  28. pub async fn setup(&self) -> Result<(), sqlx::Error> {
  29. let mut x = self.db.begin().await?;
  30. x.execute(
  31. r#"CREATE TABLE IF NOT EXISTS "payments" (
  32. "transaction_id" VARCHAR(66) NOT NULL,
  33. "position_id" INTEGER NOT NULL,
  34. "asset_id" TEXT NOT NULL,
  35. "cents" TEXT NOT NULL,
  36. "status" INTEGER NOT NULL,
  37. "to" VARCHAR(71) NOT NULL,
  38. "spent_by" TEXT,
  39. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  40. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  41. PRIMARY KEY ("transaction_id", "position_id")
  42. );
  43. CREATE INDEX IF NOT EXISTS payments_to ON payments ("to", "asset_id", "spent_by", "status");
  44. CREATE TABLE IF NOT EXISTS "transactions" (
  45. "transaction_id" VARCHAR(66) NOT NULL,
  46. "status" INTEGER NOT NULL,
  47. "type" INTEGER NOT NULL,
  48. "reference" TEXT NOT NULL,
  49. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  50. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  51. PRIMARY KEY ("transaction_id")
  52. );
  53. CREATE TABLE IF NOT EXISTS "transaction_accounts" (
  54. "transaction_id" VARCHAR(66) NOT NULL,
  55. "account_id" VARCHAR(71) NOT NULL,
  56. "type" INTEGER NOT NULL,
  57. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  58. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  59. PRIMARY KEY("account_id", "transaction_id")
  60. );
  61. CREATE INDEX IF NOT EXISTS "type" ON "transaction_accounts" ("account_id", "type", "created_at");
  62. CREATE TABLE IF NOT EXISTS "transaction_input_payments" (
  63. "transaction_id" VARCHAR(66) NOT NULL,
  64. "payment_transaction_id" VARCHAR(66) NOT NULL,
  65. "payment_position_id" INTEGER NOT NULL,
  66. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  67. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  68. PRIMARY KEY ("transaction_id", "payment_transaction_id", "payment_position_id")
  69. );
  70. "#,
  71. )
  72. .await
  73. .expect("valid");
  74. x.commit().await?;
  75. Ok(())
  76. }
  77. #[inline]
  78. fn sql_row_to_payment(&self, row: SqliteRow) -> Result<Payment, Error> {
  79. let id = PaymentId {
  80. transaction: row
  81. .try_get::<String, usize>(0)
  82. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  83. .as_str()
  84. .try_into()
  85. .map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))?,
  86. position: row
  87. .try_get::<i64, usize>(1)
  88. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  89. .try_into()
  90. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?,
  91. };
  92. let cents = row
  93. .try_get::<String, usize>(3)
  94. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  95. .parse::<i128>()
  96. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  97. Ok(Payment {
  98. id,
  99. amount: self
  100. .asset_manager
  101. .asset(
  102. row.try_get::<String, usize>(2)
  103. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  104. .parse()
  105. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  106. )
  107. .map_err(|e| Error::Storage(e.to_string()))?
  108. .new_amount(cents),
  109. to: row
  110. .try_get::<String, usize>(4)
  111. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?
  112. .as_str()
  113. .try_into()
  114. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?,
  115. status: row
  116. .try_get::<u32, usize>(5)
  117. .map_err(|_| Error::Storage("Invalid `status`".to_string()))?
  118. .try_into()
  119. .map_err(|_| Error::Storage("Invalid status".to_string()))?,
  120. spent_by: row
  121. .try_get::<Option<String>, usize>(6)
  122. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?
  123. .map(|s| s.as_str().try_into())
  124. .transpose()
  125. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?,
  126. })
  127. }
  128. }
  129. #[async_trait::async_trait]
  130. impl<'a> Storage<'a, Batch<'a>> for Sqlite<'a> {
  131. async fn begin(&'a self) -> Result<Batch<'a>, Error> {
  132. self.db
  133. .begin()
  134. .await
  135. .map(|x| Batch::new(x))
  136. .map_err(|x| Error::Storage(x.to_string()))
  137. }
  138. async fn get_payment(&self, id: PaymentId) -> Result<Payment, Error> {
  139. let mut conn = self
  140. .db
  141. .acquire()
  142. .await
  143. .map_err(|e| Error::Storage(e.to_string()))?;
  144. let row = sqlx::query(
  145. r#"
  146. SELECT
  147. "p"."transaction_id",
  148. "p"."position_id",
  149. "p"."asset_id",
  150. "p"."cents",
  151. "p"."to",
  152. "p"."status",
  153. "p"."spent_by"
  154. FROM
  155. "payments" "p"
  156. WHERE
  157. "p"."transaction_id" = ?
  158. AND "p"."position_id" = ?
  159. LIMIT 1
  160. "#,
  161. )
  162. .bind(id.transaction.to_string())
  163. .bind(id.position.to_string())
  164. .fetch_optional(&mut *conn)
  165. .await
  166. .map_err(|e| Error::Storage(e.to_string()))?
  167. .ok_or(Error::NotFound)?;
  168. self.sql_row_to_payment(row)
  169. }
  170. async fn get_balance(&self, account: &AccountId) -> Result<Vec<Amount>, Error> {
  171. let mut conn = self
  172. .db
  173. .acquire()
  174. .await
  175. .map_err(|e| Error::Storage(e.to_string()))?;
  176. let mut result = sqlx::query(
  177. r#"
  178. SELECT
  179. "asset_id",
  180. "cents"
  181. FROM
  182. "payments"
  183. WHERE
  184. "to" = ? AND "spent_by" IS NULL AND status = ?
  185. "#,
  186. )
  187. .bind(account.to_string())
  188. .bind::<u32>(Status::Settled.into())
  189. .fetch(&mut *conn);
  190. let mut balances = HashMap::<Asset, Amount>::new();
  191. while let Some(row) = result
  192. .try_next()
  193. .await
  194. .map_err(|e| Error::Storage(e.to_string()))?
  195. {
  196. let asset = self
  197. .asset_manager
  198. .asset(
  199. row.try_get::<String, usize>(0)
  200. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  201. .parse()
  202. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  203. )
  204. .map_err(|e| Error::Storage(e.to_string()))?;
  205. let cents = row
  206. .try_get::<String, usize>(1)
  207. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  208. .parse::<i128>()
  209. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  210. let new_amount = asset.new_amount(cents);
  211. if let Some(amount) = balances.get_mut(&asset) {
  212. *amount = amount
  213. .checked_add(&new_amount)
  214. .ok_or(Error::Storage("amount overflow".to_owned()))?;
  215. } else {
  216. balances.insert(asset, new_amount);
  217. }
  218. }
  219. Ok(balances.into_iter().map(|(_, v)| v).collect())
  220. }
  221. async fn get_unspent_payments(
  222. &self,
  223. account: &AccountId,
  224. asset: AssetId,
  225. mut target_amount: AmountCents,
  226. ) -> Result<Vec<Payment>, Error> {
  227. let mut conn = self
  228. .db
  229. .acquire()
  230. .await
  231. .map_err(|e| Error::Storage(e.to_string()))?;
  232. let mut result = sqlx::query(
  233. r#"
  234. SELECT
  235. "p"."transaction_id",
  236. "p"."position_id",
  237. "p"."asset_id",
  238. "p"."cents",
  239. "p"."to",
  240. "p"."status",
  241. "p"."spent_by"
  242. FROM
  243. "payments" as "p"
  244. WHERE
  245. "p"."to" = ? AND "p"."asset_id" = ? AND status = ? AND "p"."spent_by" IS NULL
  246. ORDER BY cents ASC
  247. "#,
  248. )
  249. .bind(account.to_string())
  250. .bind(asset.to_string())
  251. .bind::<u32>(Status::Settled.into())
  252. .fetch(&mut *conn);
  253. let mut to_return = vec![];
  254. while let Some(row) = result
  255. .try_next()
  256. .await
  257. .map_err(|e| Error::Storage(e.to_string()))?
  258. {
  259. let row = self.sql_row_to_payment(row)?;
  260. target_amount -= row.amount.cents();
  261. to_return.push(row);
  262. if target_amount <= 0 {
  263. break;
  264. }
  265. }
  266. if target_amount <= 0 {
  267. Ok(to_return)
  268. } else {
  269. Err(Error::NotEnoughUnspentPayments(target_amount))
  270. }
  271. }
  272. async fn get_transaction(
  273. &self,
  274. transaction_id: &TransactionId,
  275. ) -> Result<from_db::Transaction, Error> {
  276. let mut conn = self
  277. .db
  278. .acquire()
  279. .await
  280. .map_err(|e| Error::Storage(e.to_string()))?;
  281. let transaction_row = sqlx::query(
  282. r#"
  283. SELECT
  284. "t"."status",
  285. "t"."type",
  286. "t"."reference",
  287. "t"."created_at",
  288. "t"."updated_at"
  289. FROM
  290. "transactions" "t"
  291. WHERE
  292. "t"."transaction_id" = ?
  293. "#,
  294. )
  295. .bind(transaction_id.to_string())
  296. .fetch_optional(&mut *conn)
  297. .await
  298. .map_err(|e| Error::Storage(e.to_string()))?
  299. .ok_or(Error::NotFound)?;
  300. let mut spend_result = sqlx::query(
  301. r#"
  302. SELECT
  303. "p"."transaction_id",
  304. "p"."position_id",
  305. "p"."asset_id",
  306. "p"."cents",
  307. "p"."to",
  308. "p"."status",
  309. "p"."spent_by"
  310. FROM
  311. "transaction_input_payments" "tp"
  312. INNER JOIN
  313. "payments" "p"
  314. ON (
  315. "tp"."payment_transaction_id" = "p"."transaction_id"
  316. AND "tp"."payment_position_id" = "p"."position_id"
  317. )
  318. WHERE
  319. "tp"."transaction_id" = ?
  320. "#,
  321. )
  322. .bind(transaction_id.to_string())
  323. .fetch(&mut *conn);
  324. let mut spend = vec![];
  325. while let Some(row) = spend_result
  326. .try_next()
  327. .await
  328. .map_err(|e| Error::Storage(e.to_string()))?
  329. {
  330. spend.push(self.sql_row_to_payment(row)?);
  331. }
  332. drop(spend_result);
  333. let mut create_result = sqlx::query(
  334. r#"
  335. SELECT
  336. "p"."transaction_id",
  337. "p"."position_id",
  338. "p"."asset_id",
  339. "p"."cents",
  340. "p"."to",
  341. "p"."status",
  342. "p"."spent_by"
  343. FROM
  344. "payments" "p"
  345. WHERE
  346. "p"."transaction_id" = ?
  347. "#,
  348. )
  349. .bind(transaction_id.to_string())
  350. .fetch(&mut *conn);
  351. let mut create = vec![];
  352. while let Some(row) = create_result
  353. .try_next()
  354. .await
  355. .map_err(|e| Error::Storage(e.to_string()))?
  356. {
  357. create.push(self.sql_row_to_payment(row)?);
  358. }
  359. let status = transaction_row
  360. .try_get::<u32, usize>(0)
  361. .map_err(|_| Error::Storage("Invalid status".to_string()))?
  362. .try_into()
  363. .map_err(|_| Error::Storage("Invalid status".to_string()))?;
  364. let typ = transaction_row
  365. .try_get::<u32, usize>(1)
  366. .map_err(|_| Error::Storage("Invalid type".to_string()))?
  367. .try_into()
  368. .map_err(|_| Error::Storage("Invalid type".to_string()))?;
  369. let reference = transaction_row
  370. .try_get::<String, usize>(2)
  371. .map_err(|_| Error::Storage("Invalid reference".to_string()))?;
  372. let created_at = transaction_row
  373. .try_get::<NaiveDateTime, usize>(3)
  374. .map_err(|e| Error::InvalidDate(e.to_string()))?
  375. .and_utc();
  376. let updated_at = transaction_row
  377. .try_get::<NaiveDateTime, usize>(4)
  378. .map_err(|e| Error::InvalidDate(e.to_string()))?
  379. .and_utc();
  380. Ok(from_db::Transaction {
  381. id: transaction_id.clone(),
  382. spend,
  383. create,
  384. typ,
  385. status,
  386. reference,
  387. created_at,
  388. updated_at,
  389. })
  390. }
  391. async fn get_transactions(
  392. &self,
  393. account: &AccountId,
  394. typ: Option<Type>,
  395. ) -> Result<Vec<from_db::Transaction>, Error> {
  396. let mut conn = self
  397. .db
  398. .acquire()
  399. .await
  400. .map_err(|e| Error::Storage(e.to_string()))?;
  401. let sql = sqlx::query(if typ.is_some() {
  402. r#"SELECT "transaction_id" FROM "transaction_accounts" WHERE "account_id" = ? AND "type" = ?" ORDER BY "created_at" DESC"#
  403. } else {
  404. r#"SELECT "transaction_id" FROM "transaction_accounts" WHERE "account_id" = ? ORDER BY "created_at" DESC"#
  405. }).bind(account.to_string());
  406. let ids = if let Some(typ) = typ {
  407. sql.bind::<u32>(typ.into())
  408. } else {
  409. sql
  410. }
  411. .fetch_all(&mut *conn)
  412. .await
  413. .map_err(|e| Error::Storage(e.to_string()))?
  414. .iter()
  415. .map(|row| {
  416. let id: Result<TransactionId, _> = row
  417. .try_get::<String, usize>(0)
  418. .map_err(|_| Error::Storage("Invalid transaction_id".to_string()))?
  419. .as_str()
  420. .try_into();
  421. id.map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))
  422. })
  423. .collect::<Result<Vec<_>, Error>>()?;
  424. drop(conn);
  425. let mut transactions = vec![];
  426. for id in ids.into_iter() {
  427. transactions.push(self.get_transaction(&id).await?);
  428. }
  429. Ok(transactions)
  430. }
  431. }