mod.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. //! SQL Mint Auth
  2. use std::collections::HashMap;
  3. use std::fmt::Debug;
  4. use std::str::FromStr;
  5. use std::sync::Arc;
  6. use async_trait::async_trait;
  7. use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
  8. use cdk_common::mint::MintKeySetInfo;
  9. use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
  10. use cdk_common::{AuthRequired, ProtectedEndpoint};
  11. use migrations::MIGRATIONS;
  12. use tracing::instrument;
  13. use super::{sql_row_to_blind_signature, sql_row_to_keyset_info, SQLTransaction};
  14. use crate::column_as_string;
  15. use crate::common::migrate;
  16. use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
  17. use crate::mint::Error;
  18. use crate::pool::{DatabasePool, Pool, PooledResource};
  19. use crate::stmt::query;
  20. /// Mint SQL Database
  21. #[derive(Debug, Clone)]
  22. pub struct SQLMintAuthDatabase<RM>
  23. where
  24. RM: DatabasePool + 'static,
  25. {
  26. pool: Arc<Pool<RM>>,
  27. }
  28. impl<RM> SQLMintAuthDatabase<RM>
  29. where
  30. RM: DatabasePool + 'static,
  31. {
  32. /// Creates a new instance
  33. pub async fn new<X>(db: X) -> Result<Self, Error>
  34. where
  35. X: Into<RM::Config>,
  36. {
  37. let pool = Pool::new(db.into());
  38. Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
  39. Ok(Self { pool })
  40. }
  41. /// Migrate
  42. async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
  43. let tx = ConnectionWithTransaction::new(conn).await?;
  44. migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
  45. tx.commit().await?;
  46. Ok(())
  47. }
  48. }
/// Migrations for the mint auth schema.
///
/// The module body is generated at build time and included from
/// `$OUT_DIR/migrations_mint_auth.rs`. It provides the `MIGRATIONS` item imported at
/// the top of this file. NOTE(review): presumably this is written by the crate's build
/// script; confirm.
// Generated code is included verbatim, so rustfmt is told not to touch this module.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
  53. #[async_trait]
  54. impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
  55. where
  56. RM: DatabasePool + 'static,
  57. {
  58. #[instrument(skip(self))]
  59. async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
  60. tracing::info!("Setting auth keyset {id} active");
  61. query(
  62. r#"
  63. UPDATE keyset
  64. SET active = CASE
  65. WHEN id = :id THEN TRUE
  66. ELSE FALSE
  67. END;
  68. "#,
  69. )?
  70. .bind("id", id.to_string())
  71. .execute(&self.inner)
  72. .await?;
  73. Ok(())
  74. }
  75. async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
  76. query(
  77. r#"
  78. INSERT INTO
  79. keyset (
  80. id, unit, active, valid_from, valid_to, derivation_path,
  81. amounts, input_fee_ppk, derivation_path_index
  82. )
  83. VALUES (
  84. :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
  85. :amounts, :input_fee_ppk, :derivation_path_index
  86. )
  87. ON CONFLICT(id) DO UPDATE SET
  88. unit = excluded.unit,
  89. active = excluded.active,
  90. valid_from = excluded.valid_from,
  91. valid_to = excluded.valid_to,
  92. derivation_path = excluded.derivation_path,
  93. amounts = excluded.amounts,
  94. input_fee_ppk = excluded.input_fee_ppk,
  95. derivation_path_index = excluded.derivation_path_index
  96. "#,
  97. )?
  98. .bind("id", keyset.id.to_string())
  99. .bind("unit", keyset.unit.to_string())
  100. .bind("active", keyset.active)
  101. .bind("valid_from", keyset.valid_from as i64)
  102. .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
  103. .bind("derivation_path", keyset.derivation_path.to_string())
  104. .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
  105. .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
  106. .bind("derivation_path_index", keyset.derivation_path_index)
  107. .execute(&self.inner)
  108. .await?;
  109. Ok(())
  110. }
  111. async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
  112. let y = proof.y()?;
  113. if let Err(err) = query(
  114. r#"
  115. INSERT INTO proof
  116. (y, keyset_id, secret, c, state)
  117. VALUES
  118. (:y, :keyset_id, :secret, :c, :state)
  119. "#,
  120. )?
  121. .bind("y", y.to_bytes().to_vec())
  122. .bind("keyset_id", proof.keyset_id.to_string())
  123. .bind("secret", proof.secret.to_string())
  124. .bind("c", proof.c.to_bytes().to_vec())
  125. .bind("state", "UNSPENT".to_string())
  126. .execute(&self.inner)
  127. .await
  128. {
  129. tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
  130. }
  131. self.locked_records.lock(y);
  132. Ok(())
  133. }
  134. async fn update_proof_state(
  135. &mut self,
  136. y: &PublicKey,
  137. proofs_state: State,
  138. ) -> Result<Option<State>, Self::Err> {
  139. self.locked_records.is_locked(y)?;
  140. let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
  141. .bind("y", y.to_bytes().to_vec())
  142. .pluck(&self.inner)
  143. .await?
  144. .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
  145. .transpose()?;
  146. query(r#"UPDATE proof SET state = :new_state WHERE y = :y"#)?
  147. .bind("y", y.to_bytes().to_vec())
  148. .bind("new_state", proofs_state.to_string())
  149. .execute(&self.inner)
  150. .await?;
  151. Ok(current_state)
  152. }
  153. async fn add_blind_signatures(
  154. &mut self,
  155. blinded_messages: &[PublicKey],
  156. blind_signatures: &[BlindSignature],
  157. ) -> Result<(), database::Error> {
  158. for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
  159. query(
  160. r#"
  161. INSERT
  162. INTO blind_signature
  163. (blinded_message, amount, keyset_id, c)
  164. VALUES
  165. (:blinded_message, :amount, :keyset_id, :c)
  166. "#,
  167. )?
  168. .bind("blinded_message", message.to_bytes().to_vec())
  169. .bind("amount", u64::from(signature.amount) as i64)
  170. .bind("keyset_id", signature.keyset_id.to_string())
  171. .bind("c", signature.c.to_bytes().to_vec())
  172. .execute(&self.inner)
  173. .await?;
  174. }
  175. Ok(())
  176. }
  177. async fn add_protected_endpoints(
  178. &mut self,
  179. protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
  180. ) -> Result<(), database::Error> {
  181. for (endpoint, auth) in protected_endpoints.iter() {
  182. if let Err(err) = query(
  183. r#"
  184. INSERT INTO protected_endpoints
  185. (endpoint, auth)
  186. VALUES (:endpoint, :auth)
  187. ON CONFLICT (endpoint) DO UPDATE SET
  188. auth = EXCLUDED.auth;
  189. "#,
  190. )?
  191. .bind("endpoint", serde_json::to_string(endpoint)?)
  192. .bind("auth", serde_json::to_string(auth)?)
  193. .execute(&self.inner)
  194. .await
  195. {
  196. tracing::debug!(
  197. "Attempting to add protected endpoint. Skipping.... {:?}",
  198. err
  199. );
  200. }
  201. }
  202. Ok(())
  203. }
  204. async fn remove_protected_endpoints(
  205. &mut self,
  206. protected_endpoints: Vec<ProtectedEndpoint>,
  207. ) -> Result<(), database::Error> {
  208. query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
  209. .bind_vec(
  210. "endpoints",
  211. protected_endpoints
  212. .iter()
  213. .map(serde_json::to_string)
  214. .collect::<Result<_, _>>()?,
  215. )
  216. .execute(&self.inner)
  217. .await?;
  218. Ok(())
  219. }
  220. }
  221. #[async_trait]
  222. impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
  223. where
  224. RM: DatabasePool + 'static,
  225. {
  226. type Err = database::Error;
  227. async fn begin_transaction<'a>(
  228. &'a self,
  229. ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
  230. {
  231. Ok(Box::new(SQLTransaction {
  232. inner: ConnectionWithTransaction::new(
  233. self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
  234. )
  235. .await?,
  236. locked_records: Default::default(),
  237. }))
  238. }
  239. async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
  240. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  241. Ok(query(
  242. r#"
  243. SELECT
  244. id
  245. FROM
  246. keyset
  247. WHERE
  248. active = :active;
  249. "#,
  250. )?
  251. .bind("active", true)
  252. .pluck(&*conn)
  253. .await?
  254. .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
  255. .transpose()?)
  256. }
  257. async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
  258. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  259. Ok(query(
  260. r#"SELECT
  261. id,
  262. unit,
  263. active,
  264. valid_from,
  265. valid_to,
  266. derivation_path,
  267. derivation_path_index,
  268. amounts,
  269. input_fee_ppk
  270. FROM
  271. keyset
  272. WHERE id=:id"#,
  273. )?
  274. .bind("id", id.to_string())
  275. .fetch_one(&*conn)
  276. .await?
  277. .map(sql_row_to_keyset_info)
  278. .transpose()?)
  279. }
  280. async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
  281. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  282. Ok(query(
  283. r#"SELECT
  284. id,
  285. unit,
  286. active,
  287. valid_from,
  288. valid_to,
  289. derivation_path,
  290. derivation_path_index,
  291. amounts,
  292. input_fee_ppk
  293. FROM
  294. keyset
  295. WHERE id=:id"#,
  296. )?
  297. .fetch_all(&*conn)
  298. .await?
  299. .into_iter()
  300. .map(sql_row_to_keyset_info)
  301. .collect::<Result<Vec<_>, _>>()?)
  302. }
  303. async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
  304. if ys.is_empty() {
  305. return Ok(vec![]);
  306. }
  307. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  308. let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
  309. .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
  310. .fetch_all(&*conn)
  311. .await?
  312. .into_iter()
  313. .map(|row| {
  314. Ok((
  315. column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
  316. column_as_string!(&row[1], State::from_str),
  317. ))
  318. })
  319. .collect::<Result<HashMap<_, _>, Error>>()?;
  320. Ok(ys.iter().map(|y| current_states.remove(y)).collect())
  321. }
  322. async fn get_blind_signatures(
  323. &self,
  324. blinded_messages: &[PublicKey],
  325. ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
  326. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  327. let mut blinded_signatures = query(
  328. r#"SELECT
  329. keyset_id,
  330. amount,
  331. c,
  332. dleq_e,
  333. dleq_s,
  334. blinded_message,
  335. FROM
  336. blind_signature
  337. WHERE blinded_message IN (:blinded_message)
  338. "#,
  339. )?
  340. .bind_vec(
  341. "blinded_message",
  342. blinded_messages
  343. .iter()
  344. .map(|bm| bm.to_bytes().to_vec())
  345. .collect(),
  346. )
  347. .fetch_all(&*conn)
  348. .await?
  349. .into_iter()
  350. .map(|mut row| {
  351. Ok((
  352. column_as_string!(
  353. &row.pop().ok_or(Error::InvalidDbResponse)?,
  354. PublicKey::from_hex,
  355. PublicKey::from_slice
  356. ),
  357. sql_row_to_blind_signature(row)?,
  358. ))
  359. })
  360. .collect::<Result<HashMap<_, _>, Error>>()?;
  361. Ok(blinded_messages
  362. .iter()
  363. .map(|bm| blinded_signatures.remove(bm))
  364. .collect())
  365. }
  366. async fn get_auth_for_endpoint(
  367. &self,
  368. protected_endpoint: ProtectedEndpoint,
  369. ) -> Result<Option<AuthRequired>, Self::Err> {
  370. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  371. Ok(
  372. query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
  373. .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
  374. .pluck(&*conn)
  375. .await?
  376. .map(|auth| {
  377. Ok::<_, Error>(column_as_string!(
  378. auth,
  379. serde_json::from_str,
  380. serde_json::from_slice
  381. ))
  382. })
  383. .transpose()?,
  384. )
  385. }
  386. async fn get_auth_for_endpoints(
  387. &self,
  388. ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
  389. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  390. Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
  391. .fetch_all(&*conn)
  392. .await?
  393. .into_iter()
  394. .map(|row| {
  395. let endpoint =
  396. column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
  397. let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
  398. Ok((endpoint, Some(auth)))
  399. })
  400. .collect::<Result<HashMap<_, _>, Error>>()?)
  401. }
  402. }