use crate::{ amount::AmountCents, asset::AssetId, storage::Error, AccountId, Amount, Asset, AssetManager, Payment, PaymentId, Status, Storage, Transaction, TransactionId, }; use futures::TryStreamExt; use sqlx::{sqlite::SqliteRow, Executor, Row}; use std::{collections::HashMap, marker::PhantomData}; mod batch; pub use batch::Batch; pub struct Sqlite<'a> { db: sqlx::SqlitePool, asset_manager: AssetManager, _phantom: PhantomData<&'a ()>, } impl<'a> Sqlite<'a> { pub fn new(db: sqlx::SqlitePool, asset_manager: AssetManager) -> Self { Self { db, asset_manager, _phantom: PhantomData, } } pub async fn setup(&self) -> Result<(), sqlx::Error> { let mut x = self.db.begin().await?; x.execute( r#"CREATE TABLE IF NOT EXISTS "payments" ( "transaction_id" VARCHAR(66) NOT NULL, "position_id" INTEGER NOT NULL, "asset_id" TEXT NOT NULL, "cents" TEXT NOT NULL, "status" INTEGER NOT NULL, "to" VARCHAR(71) NOT NULL, "spent_by" TEXT, "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP, "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY ("transaction_id", "position_id") ); CREATE INDEX IF NOT EXISTS payments_to ON payments ("to", "asset_id", "status", "spent_by"); CREATE TABLE IF NOT EXISTS "transactions" ( "transaction_id" VARCHAR(66) NOT NULL, "status" INTEGER NOT NULL, "reference" TEXT NOT NULL, "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP, "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY ("transaction_id") ); CREATE TABLE IF NOT EXISTS "transaction_payments" ( "transaction_id" VARCHAR(66) NOT NULL, "payment_transaction_id" VARCHAR(66) NOT NULL, "payment_position_id" INTEGER NOT NULL, "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP, "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP, PRIMARY KEY ("transaction_id", "payment_transaction_id", "payment_position_id") ); "#, ) .await .expect("valid"); x.commit().await?; Ok(()) } #[inline] fn sql_row_to_payment(&self, row: SqliteRow) -> Result { let id = PaymentId { transaction: row .try_get::(0) .map_err(|_| Error::Storage("Invalid payment_id".to_string()))? .as_str() .try_into() .map_err(|_| Error::Storage("Invalid transaction_id length".to_string()))?, position: row .try_get::(1) .map_err(|_| Error::Storage("Invalid payment_id".to_string()))? .try_into() .map_err(|_| Error::Storage("Invalid payment_id".to_string()))?, }; let cents = row .try_get::(3) .map_err(|_| Error::Storage("Invalid cents".to_string()))? .parse::() .map_err(|_| Error::Storage("Invalid cents".to_string()))?; Ok(Payment { id, amount: self .asset_manager .asset( row.try_get::(2) .map_err(|_| Error::Storage("Invalid asset_id".to_string()))? .parse() .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?, ) .map_err(|e| Error::Storage(e.to_string()))? .new_amount(cents), to: row .try_get::(4) .map_err(|_| Error::Storage("Invalid `to`".to_string()))? .as_str() .try_into() .map_err(|_| Error::Storage("Invalid `to`".to_string()))?, status: row .try_get::(5) .map_err(|_| Error::Storage("Invalid `status`".to_string()))? .try_into() .map_err(|_| Error::Storage("Invalid status".to_string()))?, spent_by: row .try_get::, usize>(6) .map_err(|_| Error::Storage("Invalid spent_by".to_string()))? .map(|s| s.as_str().try_into()) .transpose() .map_err(|_| Error::Storage("Invalid spent_by".to_string()))?, }) } } #[async_trait::async_trait] impl<'a> Storage<'a, Batch<'a>> for Sqlite<'a> { async fn begin(&'a self) -> Result, Error> { self.db .begin() .await .map(|x| Batch::new(x)) .map_err(|x| Error::Storage(x.to_string())) } async fn get_payment(&self, id: PaymentId) -> Result { let mut conn = self .db .acquire() .await .map_err(|e| Error::Storage(e.to_string()))?; let row = sqlx::query( r#" SELECT "p"."transaction_id", "p"."position_id", "p"."asset_id", "p"."cents", "p"."to", "p"."status", "p"."spent_by" FROM "payments" "p" WHERE "p"."transaction_id" = ? AND "p"."position_id" = ? LIMIT 1 "#, ) .bind(id.transaction.to_string()) .bind(id.position.to_string()) .fetch_optional(&mut *conn) .await .map_err(|e| Error::Storage(e.to_string()))? .ok_or(Error::NotFound)?; self.sql_row_to_payment(row) } async fn get_balance(&self, account: &AccountId) -> Result, Error> { let mut conn = self .db .acquire() .await .map_err(|e| Error::Storage(e.to_string()))?; let mut result = sqlx::query( r#" SELECT "asset_id", "cents" FROM "payments" WHERE "to" = ? AND status = ? AND "spent_by" IS NULL "#, ) .bind(account.to_string()) .bind::(Status::Settled.into()) .fetch(&mut *conn); let mut balances = HashMap::::new(); while let Some(row) = result .try_next() .await .map_err(|e| Error::Storage(e.to_string()))? { let asset = self .asset_manager .asset( row.try_get::(0) .map_err(|_| Error::Storage("Invalid asset_id".to_string()))? .parse() .map_err(|_| Error::Storage("Invalid asset_id".to_string()))?, ) .map_err(|e| Error::Storage(e.to_string()))?; let cents = row .try_get::(1) .map_err(|_| Error::Storage("Invalid cents".to_string()))? .parse::() .map_err(|_| Error::Storage("Invalid cents".to_string()))?; let new_amount = asset.new_amount(cents); if let Some(amount) = balances.get_mut(&asset) { *amount = amount .checked_add(&new_amount) .ok_or(Error::Storage("amount overflow".to_owned()))?; } else { balances.insert(asset, new_amount); } } Ok(balances.into_iter().map(|(_, v)| v).collect()) } async fn get_unspent_payments( &self, account: &AccountId, asset: AssetId, mut target_amount: AmountCents, ) -> Result, Error> { let mut conn = self .db .acquire() .await .map_err(|e| Error::Storage(e.to_string()))?; let mut result = sqlx::query( r#" SELECT "p"."transaction_id", "p"."position_id", "p"."asset_id", "p"."cents", "p"."to", "p"."status", "p"."spent_by" FROM "payments" as "p" WHERE "p"."to" = ? AND "p"."asset_id" = ? AND status = ? AND "p"."spent_by" IS NULL ORDER BY cents ASC "#, ) .bind(account.to_string()) .bind(asset.to_string()) .bind::(Status::Settled.into()) .fetch(&mut *conn); let mut to_return = vec![]; while let Some(row) = result .try_next() .await .map_err(|e| Error::Storage(e.to_string()))? { let row = self.sql_row_to_payment(row)?; target_amount -= row.amount.cents(); to_return.push(row); if target_amount <= 0 { break; } } if target_amount <= 0 { Ok(to_return) } else { Err(Error::NotEnoughUnspentPayments(target_amount)) } } async fn get_transaction(&self, transaction_id: &TransactionId) -> Result { let mut conn = self .db .acquire() .await .map_err(|e| Error::Storage(e.to_string()))?; let transaction_row = sqlx::query( r#" SELECT "t"."status", "t"."reference" FROM "transactions" "t" WHERE "t"."transaction_id" = ? "#, ) .bind(transaction_id.to_string()) .fetch_optional(&mut *conn) .await .map_err(|e| Error::Storage(e.to_string()))? .ok_or(Error::NotFound)?; let mut spend_result = sqlx::query( r#" SELECT "p"."transaction_id", "p"."position_id", "p"."asset_id", "p"."cents", "p"."to", "p"."status", "p"."spent_by" FROM "payments" "p" INNER JOIN "transaction_payments" "tp" ON ( "tp"."payment_transaction_id" = "p"."transaction_id" AND "tp"."payment_position_id" = "p"."position_id" ) WHERE "tp"."transaction_id" = ? "#, ) .bind(transaction_id.to_string()) .fetch(&mut *conn); let mut spend = vec![]; while let Some(row) = spend_result .try_next() .await .map_err(|e| Error::Storage(e.to_string()))? { spend.push(self.sql_row_to_payment(row)?); } drop(spend_result); let mut create_result = sqlx::query( r#" SELECT "p"."transaction_id", "p"."position_id", "p"."asset_id", "p"."cents", "p"."to", "p"."status", "p"."spent_by" FROM "payments" "p" WHERE "p"."transaction_id" = ? "#, ) .bind(transaction_id.to_string()) .fetch(&mut *conn); let mut create = vec![]; while let Some(row) = create_result .try_next() .await .map_err(|e| Error::Storage(e.to_string()))? { create.push(self.sql_row_to_payment(row)?); } let status = transaction_row .try_get::(0) .map_err(|_| Error::Storage("Invalid status".to_string()))? .try_into() .map_err(|_| Error::Storage("Invalid status".to_string()))?; let reference = transaction_row .try_get::(1) .map_err(|_| Error::Storage("Invalid reference".to_string()))?; Ok(Transaction { id: transaction_id.clone(), is_external_deposit: spend.is_empty(), spend, create, status, reference, }) } }