mod.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. //! SQL Mint Auth
  2. use std::collections::HashMap;
  3. use std::fmt::Debug;
  4. use std::str::FromStr;
  5. use std::sync::Arc;
  6. use async_trait::async_trait;
  7. use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
  8. use cdk_common::mint::MintKeySetInfo;
  9. use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
  10. use cdk_common::{AuthRequired, ProtectedEndpoint};
  11. use migrations::MIGRATIONS;
  12. use tracing::instrument;
  13. use super::{sql_row_to_blind_signature, sql_row_to_keyset_info, SQLTransaction};
  14. use crate::column_as_string;
  15. use crate::common::migrate;
  16. use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
  17. use crate::mint::Error;
  18. use crate::pool::{DatabasePool, Pool, PooledResource};
  19. use crate::stmt::query;
/// SQL-backed store for the mint's auth data.
///
/// Holds a reference-counted connection [`Pool`] for the backend `RM`.
/// Cloning this handle clones the `Arc`, so all clones use the same pool.
#[derive(Debug, Clone)]
pub struct SQLMintAuthDatabase<RM>
where
    RM: DatabasePool + 'static,
{
    /// Connection pool every database call draws its connection from.
    pool: Arc<Pool<RM>>,
}
  28. impl<RM> SQLMintAuthDatabase<RM>
  29. where
  30. RM: DatabasePool + 'static,
  31. {
  32. /// Creates a new instance
  33. pub async fn new<X>(db: X) -> Result<Self, Error>
  34. where
  35. X: Into<RM::Config>,
  36. {
  37. let pool = Pool::new(db.into());
  38. Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
  39. Ok(Self { pool })
  40. }
  41. /// Migrate
  42. async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
  43. let tx = ConnectionWithTransaction::new(conn).await?;
  44. migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
  45. tx.commit().await?;
  46. Ok(())
  47. }
  48. }
  49. #[rustfmt::skip]
  50. mod migrations;
  51. #[async_trait]
  52. impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
  53. where
  54. RM: DatabasePool + 'static,
  55. {
  56. #[instrument(skip(self))]
  57. async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
  58. tracing::info!("Setting auth keyset {id} active");
  59. query(
  60. r#"
  61. UPDATE keyset
  62. SET active = CASE
  63. WHEN id = :id THEN TRUE
  64. ELSE FALSE
  65. END;
  66. "#,
  67. )?
  68. .bind("id", id.to_string())
  69. .execute(&self.inner)
  70. .await?;
  71. Ok(())
  72. }
  73. async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
  74. query(
  75. r#"
  76. INSERT INTO
  77. keyset (
  78. id, unit, active, valid_from, valid_to, derivation_path,
  79. max_order, derivation_path_index
  80. )
  81. VALUES (
  82. :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
  83. :max_order, :derivation_path_index
  84. )
  85. ON CONFLICT(id) DO UPDATE SET
  86. unit = excluded.unit,
  87. active = excluded.active,
  88. valid_from = excluded.valid_from,
  89. valid_to = excluded.valid_to,
  90. derivation_path = excluded.derivation_path,
  91. max_order = excluded.max_order,
  92. derivation_path_index = excluded.derivation_path_index
  93. "#,
  94. )?
  95. .bind("id", keyset.id.to_string())
  96. .bind("unit", keyset.unit.to_string())
  97. .bind("active", keyset.active)
  98. .bind("valid_from", keyset.valid_from as i64)
  99. .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
  100. .bind("derivation_path", keyset.derivation_path.to_string())
  101. .bind("max_order", keyset.max_order)
  102. .bind("derivation_path_index", keyset.derivation_path_index)
  103. .execute(&self.inner)
  104. .await?;
  105. Ok(())
  106. }
  107. async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
  108. if let Err(err) = query(
  109. r#"
  110. INSERT INTO proof
  111. (y, keyset_id, secret, c, state)
  112. VALUES
  113. (:y, :keyset_id, :secret, :c, :state)
  114. "#,
  115. )?
  116. .bind("y", proof.y()?.to_bytes().to_vec())
  117. .bind("keyset_id", proof.keyset_id.to_string())
  118. .bind("secret", proof.secret.to_string())
  119. .bind("c", proof.c.to_bytes().to_vec())
  120. .bind("state", "UNSPENT".to_string())
  121. .execute(&self.inner)
  122. .await
  123. {
  124. tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
  125. }
  126. Ok(())
  127. }
  128. async fn update_proof_state(
  129. &mut self,
  130. y: &PublicKey,
  131. proofs_state: State,
  132. ) -> Result<Option<State>, Self::Err> {
  133. let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
  134. .bind("y", y.to_bytes().to_vec())
  135. .pluck(&self.inner)
  136. .await?
  137. .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
  138. .transpose()?;
  139. query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)?
  140. .bind("y", y.to_bytes().to_vec())
  141. .bind(
  142. "state",
  143. current_state.as_ref().map(|state| state.to_string()),
  144. )
  145. .bind("new_state", proofs_state.to_string())
  146. .execute(&self.inner)
  147. .await?;
  148. Ok(current_state)
  149. }
  150. async fn add_blind_signatures(
  151. &mut self,
  152. blinded_messages: &[PublicKey],
  153. blind_signatures: &[BlindSignature],
  154. ) -> Result<(), database::Error> {
  155. for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
  156. query(
  157. r#"
  158. INSERT
  159. INTO blind_signature
  160. (blinded_message, amount, keyset_id, c)
  161. VALUES
  162. (:blinded_message, :amount, :keyset_id, :c)
  163. "#,
  164. )?
  165. .bind("blinded_message", message.to_bytes().to_vec())
  166. .bind("amount", u64::from(signature.amount) as i64)
  167. .bind("keyset_id", signature.keyset_id.to_string())
  168. .bind("c", signature.c.to_bytes().to_vec())
  169. .execute(&self.inner)
  170. .await?;
  171. }
  172. Ok(())
  173. }
  174. async fn add_protected_endpoints(
  175. &mut self,
  176. protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
  177. ) -> Result<(), database::Error> {
  178. for (endpoint, auth) in protected_endpoints.iter() {
  179. if let Err(err) = query(
  180. r#"
  181. INSERT OR REPLACE INTO protected_endpoints
  182. (endpoint, auth)
  183. VALUES (:endpoint, :auth);
  184. "#,
  185. )?
  186. .bind("endpoint", serde_json::to_string(endpoint)?)
  187. .bind("auth", serde_json::to_string(auth)?)
  188. .execute(&self.inner)
  189. .await
  190. {
  191. tracing::debug!(
  192. "Attempting to add protected endpoint. Skipping.... {:?}",
  193. err
  194. );
  195. }
  196. }
  197. Ok(())
  198. }
  199. async fn remove_protected_endpoints(
  200. &mut self,
  201. protected_endpoints: Vec<ProtectedEndpoint>,
  202. ) -> Result<(), database::Error> {
  203. query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
  204. .bind_vec(
  205. "endpoints",
  206. protected_endpoints
  207. .iter()
  208. .map(serde_json::to_string)
  209. .collect::<Result<_, _>>()?,
  210. )
  211. .execute(&self.inner)
  212. .await?;
  213. Ok(())
  214. }
  215. }
  216. #[async_trait]
  217. impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
  218. where
  219. RM: DatabasePool + 'static,
  220. {
  221. type Err = database::Error;
  222. async fn begin_transaction<'a>(
  223. &'a self,
  224. ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
  225. {
  226. Ok(Box::new(SQLTransaction {
  227. inner: ConnectionWithTransaction::new(
  228. self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
  229. )
  230. .await?,
  231. }))
  232. }
  233. async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
  234. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  235. Ok(query(
  236. r#"
  237. SELECT
  238. id
  239. FROM
  240. keyset
  241. WHERE
  242. active = :active;
  243. "#,
  244. )?
  245. .bind("active", true)
  246. .pluck(&*conn)
  247. .await?
  248. .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
  249. .transpose()?)
  250. }
  251. async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
  252. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  253. Ok(query(
  254. r#"SELECT
  255. id,
  256. unit,
  257. active,
  258. valid_from,
  259. valid_to,
  260. derivation_path,
  261. derivation_path_index,
  262. max_order,
  263. amounts,
  264. input_fee_ppk
  265. FROM
  266. keyset
  267. WHERE id=:id"#,
  268. )?
  269. .bind("id", id.to_string())
  270. .fetch_one(&*conn)
  271. .await?
  272. .map(sql_row_to_keyset_info)
  273. .transpose()?)
  274. }
  275. async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
  276. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  277. Ok(query(
  278. r#"SELECT
  279. id,
  280. unit,
  281. active,
  282. valid_from,
  283. valid_to,
  284. derivation_path,
  285. derivation_path_index,
  286. max_order,
  287. amounts,
  288. input_fee_ppk
  289. FROM
  290. keyset
  291. WHERE id=:id"#,
  292. )?
  293. .fetch_all(&*conn)
  294. .await?
  295. .into_iter()
  296. .map(sql_row_to_keyset_info)
  297. .collect::<Result<Vec<_>, _>>()?)
  298. }
  299. async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
  300. if ys.is_empty() {
  301. return Ok(vec![]);
  302. }
  303. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  304. let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
  305. .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
  306. .fetch_all(&*conn)
  307. .await?
  308. .into_iter()
  309. .map(|row| {
  310. Ok((
  311. column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
  312. column_as_string!(&row[1], State::from_str),
  313. ))
  314. })
  315. .collect::<Result<HashMap<_, _>, Error>>()?;
  316. Ok(ys.iter().map(|y| current_states.remove(y)).collect())
  317. }
  318. async fn get_blind_signatures(
  319. &self,
  320. blinded_messages: &[PublicKey],
  321. ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
  322. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  323. let mut blinded_signatures = query(
  324. r#"SELECT
  325. keyset_id,
  326. amount,
  327. c,
  328. dleq_e,
  329. dleq_s,
  330. blinded_message,
  331. FROM
  332. blind_signature
  333. WHERE blinded_message IN (:blinded_message)
  334. "#,
  335. )?
  336. .bind_vec(
  337. "blinded_message",
  338. blinded_messages
  339. .iter()
  340. .map(|bm| bm.to_bytes().to_vec())
  341. .collect(),
  342. )
  343. .fetch_all(&*conn)
  344. .await?
  345. .into_iter()
  346. .map(|mut row| {
  347. Ok((
  348. column_as_string!(
  349. &row.pop().ok_or(Error::InvalidDbResponse)?,
  350. PublicKey::from_hex,
  351. PublicKey::from_slice
  352. ),
  353. sql_row_to_blind_signature(row)?,
  354. ))
  355. })
  356. .collect::<Result<HashMap<_, _>, Error>>()?;
  357. Ok(blinded_messages
  358. .iter()
  359. .map(|bm| blinded_signatures.remove(bm))
  360. .collect())
  361. }
  362. async fn get_auth_for_endpoint(
  363. &self,
  364. protected_endpoint: ProtectedEndpoint,
  365. ) -> Result<Option<AuthRequired>, Self::Err> {
  366. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  367. Ok(
  368. query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
  369. .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
  370. .pluck(&*conn)
  371. .await?
  372. .map(|auth| {
  373. Ok::<_, Error>(column_as_string!(
  374. auth,
  375. serde_json::from_str,
  376. serde_json::from_slice
  377. ))
  378. })
  379. .transpose()?,
  380. )
  381. }
  382. async fn get_auth_for_endpoints(
  383. &self,
  384. ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
  385. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  386. Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
  387. .fetch_all(&*conn)
  388. .await?
  389. .into_iter()
  390. .map(|row| {
  391. let endpoint =
  392. column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
  393. let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
  394. Ok((endpoint, Some(auth)))
  395. })
  396. .collect::<Result<HashMap<_, _>, Error>>()?)
  397. }
  398. }