mod.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. use crate::{
  2. amount::AmountCents, asset::AssetId, storage::Error, transaction::from_db, AccountId, Amount,
  3. Asset, AssetManager, Payment, PaymentId, Status, Storage, TransactionId,
  4. };
  5. use chrono::NaiveDateTime;
  6. use futures::TryStreamExt;
  7. use sqlx::{sqlite::SqliteRow, Executor, Row};
  8. use std::{collections::HashMap, marker::PhantomData};
  9. mod batch;
  10. pub use batch::Batch;
  11. pub struct Sqlite<'a> {
  12. db: sqlx::SqlitePool,
  13. asset_manager: AssetManager,
  14. _phantom: PhantomData<&'a ()>,
  15. }
  16. impl<'a> Sqlite<'a> {
  17. pub fn new(db: sqlx::SqlitePool, asset_manager: AssetManager) -> Self {
  18. Self {
  19. db,
  20. asset_manager,
  21. _phantom: PhantomData,
  22. }
  23. }
  24. pub async fn setup(&self) -> Result<(), sqlx::Error> {
  25. let mut x = self.db.begin().await?;
  26. x.execute(
  27. r#"CREATE TABLE IF NOT EXISTS "payments" (
  28. "transaction_id" VARCHAR(66) NOT NULL,
  29. "position_id" INTEGER NOT NULL,
  30. "asset_id" TEXT NOT NULL,
  31. "cents" TEXT NOT NULL,
  32. "status" INTEGER NOT NULL,
  33. "to" VARCHAR(71) NOT NULL,
  34. "spent_by" TEXT,
  35. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  36. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  37. PRIMARY KEY ("transaction_id", "position_id")
  38. );
  39. CREATE INDEX IF NOT EXISTS payments_to ON payments ("to", "asset_id", "status", "spent_by");
  40. CREATE TABLE IF NOT EXISTS "transactions" (
  41. "transaction_id" VARCHAR(66) NOT NULL,
  42. "status" INTEGER NOT NULL,
  43. "reference" TEXT NOT NULL,
  44. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  45. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  46. PRIMARY KEY ("transaction_id")
  47. );
  48. CREATE TABLE IF NOT EXISTS "transaction_accounts" (
  49. "transaction_id" VARCHAR(66) NOT NULL,
  50. "account_id" VARCHAR(71) NOT NULL,
  51. PRIMARY KEY("account_id", "transaction_id")
  52. );
  53. CREATE TABLE IF NOT EXISTS "transaction_input_payments" (
  54. "transaction_id" VARCHAR(66) NOT NULL,
  55. "payment_transaction_id" VARCHAR(66) NOT NULL,
  56. "payment_position_id" INTEGER NOT NULL,
  57. "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  58. "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP,
  59. PRIMARY KEY ("transaction_id", "payment_transaction_id", "payment_position_id")
  60. );
  61. "#,
  62. )
  63. .await
  64. .expect("valid");
  65. x.commit().await?;
  66. Ok(())
  67. }
  68. #[inline]
  69. fn sql_row_to_payment(&self, row: SqliteRow) -> Result<Payment, Error> {
  70. let id = PaymentId {
  71. transaction: row
  72. .try_get::<String, usize>(0)
  73. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  74. .as_str()
  75. .try_into()
  76. .map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))?,
  77. position: row
  78. .try_get::<i64, usize>(1)
  79. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?
  80. .try_into()
  81. .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?,
  82. };
  83. let cents = row
  84. .try_get::<String, usize>(3)
  85. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  86. .parse::<i128>()
  87. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  88. Ok(Payment {
  89. id,
  90. amount: self
  91. .asset_manager
  92. .asset(
  93. row.try_get::<String, usize>(2)
  94. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  95. .parse()
  96. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  97. )
  98. .map_err(|e| Error::Storage(e.to_string()))?
  99. .new_amount(cents),
  100. to: row
  101. .try_get::<String, usize>(4)
  102. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?
  103. .as_str()
  104. .try_into()
  105. .map_err(|_| Error::Storage("Invalid `to`".to_string()))?,
  106. status: row
  107. .try_get::<u32, usize>(5)
  108. .map_err(|_| Error::Storage("Invalid `status`".to_string()))?
  109. .try_into()
  110. .map_err(|_| Error::Storage("Invalid status".to_string()))?,
  111. spent_by: row
  112. .try_get::<Option<String>, usize>(6)
  113. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?
  114. .map(|s| s.as_str().try_into())
  115. .transpose()
  116. .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?,
  117. })
  118. }
  119. }
  120. #[async_trait::async_trait]
  121. impl<'a> Storage<'a, Batch<'a>> for Sqlite<'a> {
  122. async fn begin(&'a self) -> Result<Batch<'a>, Error> {
  123. self.db
  124. .begin()
  125. .await
  126. .map(|x| Batch::new(x))
  127. .map_err(|x| Error::Storage(x.to_string()))
  128. }
  129. async fn get_payment(&self, id: PaymentId) -> Result<Payment, Error> {
  130. let mut conn = self
  131. .db
  132. .acquire()
  133. .await
  134. .map_err(|e| Error::Storage(e.to_string()))?;
  135. let row = sqlx::query(
  136. r#"
  137. SELECT
  138. "p"."transaction_id",
  139. "p"."position_id",
  140. "p"."asset_id",
  141. "p"."cents",
  142. "p"."to",
  143. "p"."status",
  144. "p"."spent_by"
  145. FROM
  146. "payments" "p"
  147. WHERE
  148. "p"."transaction_id" = ?
  149. AND "p"."position_id" = ?
  150. LIMIT 1
  151. "#,
  152. )
  153. .bind(id.transaction.to_string())
  154. .bind(id.position.to_string())
  155. .fetch_optional(&mut *conn)
  156. .await
  157. .map_err(|e| Error::Storage(e.to_string()))?
  158. .ok_or(Error::NotFound)?;
  159. self.sql_row_to_payment(row)
  160. }
  161. async fn get_balance(&self, account: &AccountId) -> Result<Vec<Amount>, Error> {
  162. let mut conn = self
  163. .db
  164. .acquire()
  165. .await
  166. .map_err(|e| Error::Storage(e.to_string()))?;
  167. let mut result = sqlx::query(
  168. r#"
  169. SELECT
  170. "asset_id",
  171. "cents"
  172. FROM
  173. "payments"
  174. WHERE
  175. "to" = ? AND status = ? AND "spent_by" IS NULL
  176. "#,
  177. )
  178. .bind(account.to_string())
  179. .bind::<u32>(Status::Settled.into())
  180. .fetch(&mut *conn);
  181. let mut balances = HashMap::<Asset, Amount>::new();
  182. while let Some(row) = result
  183. .try_next()
  184. .await
  185. .map_err(|e| Error::Storage(e.to_string()))?
  186. {
  187. let asset = self
  188. .asset_manager
  189. .asset(
  190. row.try_get::<String, usize>(0)
  191. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?
  192. .parse()
  193. .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?,
  194. )
  195. .map_err(|e| Error::Storage(e.to_string()))?;
  196. let cents = row
  197. .try_get::<String, usize>(1)
  198. .map_err(|_| Error::Storage("Invalid cents".to_string()))?
  199. .parse::<i128>()
  200. .map_err(|_| Error::Storage("Invalid cents".to_string()))?;
  201. let new_amount = asset.new_amount(cents);
  202. if let Some(amount) = balances.get_mut(&asset) {
  203. *amount = amount
  204. .checked_add(&new_amount)
  205. .ok_or(Error::Storage("amount overflow".to_owned()))?;
  206. } else {
  207. balances.insert(asset, new_amount);
  208. }
  209. }
  210. Ok(balances.into_iter().map(|(_, v)| v).collect())
  211. }
  212. async fn get_unspent_payments(
  213. &self,
  214. account: &AccountId,
  215. asset: AssetId,
  216. mut target_amount: AmountCents,
  217. ) -> Result<Vec<Payment>, Error> {
  218. let mut conn = self
  219. .db
  220. .acquire()
  221. .await
  222. .map_err(|e| Error::Storage(e.to_string()))?;
  223. let mut result = sqlx::query(
  224. r#"
  225. SELECT
  226. "p"."transaction_id",
  227. "p"."position_id",
  228. "p"."asset_id",
  229. "p"."cents",
  230. "p"."to",
  231. "p"."status",
  232. "p"."spent_by"
  233. FROM
  234. "payments" as "p"
  235. WHERE
  236. "p"."to" = ? AND "p"."asset_id" = ? AND status = ? AND "p"."spent_by" IS NULL
  237. ORDER BY cents ASC
  238. "#,
  239. )
  240. .bind(account.to_string())
  241. .bind(asset.to_string())
  242. .bind::<u32>(Status::Settled.into())
  243. .fetch(&mut *conn);
  244. let mut to_return = vec![];
  245. while let Some(row) = result
  246. .try_next()
  247. .await
  248. .map_err(|e| Error::Storage(e.to_string()))?
  249. {
  250. let row = self.sql_row_to_payment(row)?;
  251. target_amount -= row.amount.cents();
  252. to_return.push(row);
  253. if target_amount <= 0 {
  254. break;
  255. }
  256. }
  257. if target_amount <= 0 {
  258. Ok(to_return)
  259. } else {
  260. Err(Error::NotEnoughUnspentPayments(target_amount))
  261. }
  262. }
  263. async fn get_transaction(
  264. &self,
  265. transaction_id: &TransactionId,
  266. ) -> Result<from_db::Transaction, Error> {
  267. let mut conn = self
  268. .db
  269. .acquire()
  270. .await
  271. .map_err(|e| Error::Storage(e.to_string()))?;
  272. let transaction_row = sqlx::query(
  273. r#"
  274. SELECT
  275. "t"."status",
  276. "t"."reference",
  277. "t"."created_at",
  278. "t"."updated_at"
  279. FROM
  280. "transactions" "t"
  281. WHERE
  282. "t"."transaction_id" = ?
  283. "#,
  284. )
  285. .bind(transaction_id.to_string())
  286. .fetch_optional(&mut *conn)
  287. .await
  288. .map_err(|e| Error::Storage(e.to_string()))?
  289. .ok_or(Error::NotFound)?;
  290. let mut spend_result = sqlx::query(
  291. r#"
  292. SELECT
  293. "p"."transaction_id",
  294. "p"."position_id",
  295. "p"."asset_id",
  296. "p"."cents",
  297. "p"."to",
  298. "p"."status",
  299. "p"."spent_by"
  300. FROM
  301. "payments" "p"
  302. INNER JOIN
  303. "transaction_input_payments" "tp"
  304. ON (
  305. "tp"."payment_transaction_id" = "p"."transaction_id"
  306. AND "tp"."payment_position_id" = "p"."position_id"
  307. )
  308. WHERE
  309. "tp"."transaction_id" = ?
  310. "#,
  311. )
  312. .bind(transaction_id.to_string())
  313. .fetch(&mut *conn);
  314. let mut spend = vec![];
  315. while let Some(row) = spend_result
  316. .try_next()
  317. .await
  318. .map_err(|e| Error::Storage(e.to_string()))?
  319. {
  320. spend.push(self.sql_row_to_payment(row)?);
  321. }
  322. drop(spend_result);
  323. let mut create_result = sqlx::query(
  324. r#"
  325. SELECT
  326. "p"."transaction_id",
  327. "p"."position_id",
  328. "p"."asset_id",
  329. "p"."cents",
  330. "p"."to",
  331. "p"."status",
  332. "p"."spent_by"
  333. FROM
  334. "payments" "p"
  335. WHERE
  336. "p"."transaction_id" = ?
  337. "#,
  338. )
  339. .bind(transaction_id.to_string())
  340. .fetch(&mut *conn);
  341. let mut create = vec![];
  342. while let Some(row) = create_result
  343. .try_next()
  344. .await
  345. .map_err(|e| Error::Storage(e.to_string()))?
  346. {
  347. create.push(self.sql_row_to_payment(row)?);
  348. }
  349. let status = transaction_row
  350. .try_get::<u32, usize>(0)
  351. .map_err(|_| Error::Storage("Invalid status".to_string()))?
  352. .try_into()
  353. .map_err(|_| Error::Storage("Invalid status".to_string()))?;
  354. let reference = transaction_row
  355. .try_get::<String, usize>(1)
  356. .map_err(|_| Error::Storage("Invalid reference".to_string()))?;
  357. let created_at = transaction_row
  358. .try_get::<NaiveDateTime, usize>(2)
  359. .map_err(|e| Error::InvalidDate(e.to_string()))?
  360. .and_utc();
  361. let updated_at = transaction_row
  362. .try_get::<NaiveDateTime, usize>(3)
  363. .map_err(|e| Error::InvalidDate(e.to_string()))?
  364. .and_utc();
  365. Ok(from_db::Transaction {
  366. id: transaction_id.clone(),
  367. spend,
  368. create,
  369. status,
  370. reference,
  371. created_at,
  372. updated_at,
  373. })
  374. }
  375. }