mod.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. //! SQL Mint Auth
  2. use std::collections::HashMap;
  3. use std::fmt::Debug;
  4. use std::str::FromStr;
  5. use std::sync::Arc;
  6. use async_trait::async_trait;
  7. use cdk_common::database::{self, MintAuthDatabase, MintAuthTransaction};
  8. use cdk_common::mint::MintKeySetInfo;
  9. use cdk_common::nuts::{AuthProof, BlindSignature, Id, PublicKey, State};
  10. use cdk_common::{AuthRequired, ProtectedEndpoint};
  11. use migrations::MIGRATIONS;
  12. use tracing::instrument;
  13. use super::SQLTransaction;
  14. use crate::column_as_string;
  15. use crate::common::migrate;
  16. use crate::database::{ConnectionWithTransaction, DatabaseExecutor};
  17. use crate::mint::keys::sql_row_to_keyset_info;
  18. use crate::mint::signatures::sql_row_to_blind_signature;
  19. use crate::mint::Error;
  20. use crate::pool::{DatabasePool, Pool, PooledResource};
  21. use crate::stmt::query;
  22. /// Mint SQL Database
  23. #[derive(Debug, Clone)]
  24. pub struct SQLMintAuthDatabase<RM>
  25. where
  26. RM: DatabasePool + 'static,
  27. {
  28. pool: Arc<Pool<RM>>,
  29. }
  30. impl<RM> SQLMintAuthDatabase<RM>
  31. where
  32. RM: DatabasePool + 'static,
  33. {
  34. /// Creates a new instance
  35. pub async fn new<X>(db: X) -> Result<Self, Error>
  36. where
  37. X: Into<RM::Config>,
  38. {
  39. let pool = Pool::new(db.into());
  40. Self::migrate(pool.get().map_err(|e| Error::Database(Box::new(e)))?).await?;
  41. Ok(Self { pool })
  42. }
  43. /// Migrate
  44. async fn migrate(conn: PooledResource<RM>) -> Result<(), Error> {
  45. let tx = ConnectionWithTransaction::new(conn).await?;
  46. migrate(&tx, RM::Connection::name(), MIGRATIONS).await?;
  47. tx.commit().await?;
  48. Ok(())
  49. }
  50. }
/// Auth-database migrations, pulled in from a file generated into `OUT_DIR` at build time.
///
/// The included file provides `MIGRATIONS`, which this module re-exports via
/// `use migrations::MIGRATIONS` and `SQLMintAuthDatabase::migrate` applies.
/// NOTE(review): the file is presumably written by this crate's build script — confirm
/// against build.rs.
// Generated code is not rustfmt-formatted, so formatting is disabled for this module.
#[rustfmt::skip]
mod migrations {
    include!(concat!(env!("OUT_DIR"), "/migrations_mint_auth.rs"));
}
  55. #[async_trait]
  56. impl<RM> MintAuthTransaction<database::Error> for SQLTransaction<RM>
  57. where
  58. RM: DatabasePool + 'static,
  59. {
  60. #[instrument(skip(self))]
  61. async fn set_active_keyset(&mut self, id: Id) -> Result<(), database::Error> {
  62. tracing::info!("Setting auth keyset {id} active");
  63. query(
  64. r#"
  65. UPDATE keyset
  66. SET active = CASE
  67. WHEN id = :id THEN TRUE
  68. ELSE FALSE
  69. END;
  70. "#,
  71. )?
  72. .bind("id", id.to_string())
  73. .execute(&self.inner)
  74. .await?;
  75. Ok(())
  76. }
  77. async fn add_keyset_info(&mut self, keyset: MintKeySetInfo) -> Result<(), database::Error> {
  78. query(
  79. r#"
  80. INSERT INTO
  81. keyset (
  82. id, unit, active, valid_from, valid_to, derivation_path,
  83. amounts, input_fee_ppk, derivation_path_index
  84. )
  85. VALUES (
  86. :id, :unit, :active, :valid_from, :valid_to, :derivation_path,
  87. :amounts, :input_fee_ppk, :derivation_path_index
  88. )
  89. ON CONFLICT(id) DO UPDATE SET
  90. unit = excluded.unit,
  91. active = excluded.active,
  92. valid_from = excluded.valid_from,
  93. valid_to = excluded.valid_to,
  94. derivation_path = excluded.derivation_path,
  95. amounts = excluded.amounts,
  96. input_fee_ppk = excluded.input_fee_ppk,
  97. derivation_path_index = excluded.derivation_path_index
  98. "#,
  99. )?
  100. .bind("id", keyset.id.to_string())
  101. .bind("unit", keyset.unit.to_string())
  102. .bind("active", keyset.active)
  103. .bind("valid_from", keyset.valid_from as i64)
  104. .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
  105. .bind("derivation_path", keyset.derivation_path.to_string())
  106. .bind("amounts", serde_json::to_string(&keyset.amounts).ok())
  107. .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
  108. .bind("derivation_path_index", keyset.derivation_path_index)
  109. .execute(&self.inner)
  110. .await?;
  111. Ok(())
  112. }
  113. async fn add_proof(&mut self, proof: AuthProof) -> Result<(), database::Error> {
  114. let y = proof.y()?;
  115. if let Err(err) = query(
  116. r#"
  117. INSERT INTO proof
  118. (y, keyset_id, secret, c, state)
  119. VALUES
  120. (:y, :keyset_id, :secret, :c, :state)
  121. "#,
  122. )?
  123. .bind("y", y.to_bytes().to_vec())
  124. .bind("keyset_id", proof.keyset_id.to_string())
  125. .bind("secret", proof.secret.to_string())
  126. .bind("c", proof.c.to_bytes().to_vec())
  127. .bind("state", "UNSPENT".to_string())
  128. .execute(&self.inner)
  129. .await
  130. {
  131. tracing::debug!("Attempting to add known proof. Skipping.... {:?}", err);
  132. }
  133. Ok(())
  134. }
  135. async fn update_proof_state(
  136. &mut self,
  137. y: &PublicKey,
  138. proofs_state: State,
  139. ) -> Result<Option<State>, Self::Err> {
  140. let current_state = query(r#"SELECT state FROM proof WHERE y = :y FOR UPDATE"#)?
  141. .bind("y", y.to_bytes().to_vec())
  142. .pluck(&self.inner)
  143. .await?
  144. .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
  145. .transpose()?;
  146. query(r#"UPDATE proof SET state = :new_state WHERE y = :y"#)?
  147. .bind("y", y.to_bytes().to_vec())
  148. .bind("new_state", proofs_state.to_string())
  149. .execute(&self.inner)
  150. .await?;
  151. Ok(current_state)
  152. }
  153. async fn add_blind_signatures(
  154. &mut self,
  155. blinded_messages: &[PublicKey],
  156. blind_signatures: &[BlindSignature],
  157. ) -> Result<(), database::Error> {
  158. for (message, signature) in blinded_messages.iter().zip(blind_signatures) {
  159. query(
  160. r#"
  161. INSERT
  162. INTO blind_signature
  163. (blinded_message, amount, keyset_id, c)
  164. VALUES
  165. (:blinded_message, :amount, :keyset_id, :c)
  166. "#,
  167. )?
  168. .bind("blinded_message", message.to_bytes().to_vec())
  169. .bind("amount", u64::from(signature.amount) as i64)
  170. .bind("keyset_id", signature.keyset_id.to_string())
  171. .bind("c", signature.c.to_bytes().to_vec())
  172. .execute(&self.inner)
  173. .await?;
  174. }
  175. Ok(())
  176. }
  177. async fn add_protected_endpoints(
  178. &mut self,
  179. protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
  180. ) -> Result<(), database::Error> {
  181. for (endpoint, auth) in protected_endpoints.iter() {
  182. if let Err(err) = query(
  183. r#"
  184. INSERT INTO protected_endpoints
  185. (endpoint, auth)
  186. VALUES (:endpoint, :auth)
  187. ON CONFLICT (endpoint) DO UPDATE SET
  188. auth = EXCLUDED.auth;
  189. "#,
  190. )?
  191. .bind("endpoint", serde_json::to_string(endpoint)?)
  192. .bind("auth", serde_json::to_string(auth)?)
  193. .execute(&self.inner)
  194. .await
  195. {
  196. tracing::debug!(
  197. "Attempting to add protected endpoint. Skipping.... {:?}",
  198. err
  199. );
  200. }
  201. }
  202. Ok(())
  203. }
  204. async fn remove_protected_endpoints(
  205. &mut self,
  206. protected_endpoints: Vec<ProtectedEndpoint>,
  207. ) -> Result<(), database::Error> {
  208. query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
  209. .bind_vec(
  210. "endpoints",
  211. protected_endpoints
  212. .iter()
  213. .map(serde_json::to_string)
  214. .collect::<Result<_, _>>()?,
  215. )
  216. .execute(&self.inner)
  217. .await?;
  218. Ok(())
  219. }
  220. }
  221. #[async_trait]
  222. impl<RM> MintAuthDatabase for SQLMintAuthDatabase<RM>
  223. where
  224. RM: DatabasePool + 'static,
  225. {
  226. type Err = database::Error;
  227. async fn begin_transaction<'a>(
  228. &'a self,
  229. ) -> Result<Box<dyn MintAuthTransaction<database::Error> + Send + Sync + 'a>, database::Error>
  230. {
  231. Ok(Box::new(SQLTransaction {
  232. inner: ConnectionWithTransaction::new(
  233. self.pool.get().map_err(|e| Error::Database(Box::new(e)))?,
  234. )
  235. .await?,
  236. }))
  237. }
  238. async fn get_active_keyset_id(&self) -> Result<Option<Id>, Self::Err> {
  239. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  240. Ok(query(
  241. r#"
  242. SELECT
  243. id
  244. FROM
  245. keyset
  246. WHERE
  247. active = :active;
  248. "#,
  249. )?
  250. .bind("active", true)
  251. .pluck(&*conn)
  252. .await?
  253. .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
  254. .transpose()?)
  255. }
  256. async fn get_keyset_info(&self, id: &Id) -> Result<Option<MintKeySetInfo>, Self::Err> {
  257. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  258. Ok(query(
  259. r#"SELECT
  260. id,
  261. unit,
  262. active,
  263. valid_from,
  264. valid_to,
  265. derivation_path,
  266. derivation_path_index,
  267. amounts,
  268. input_fee_ppk
  269. FROM
  270. keyset
  271. WHERE id=:id"#,
  272. )?
  273. .bind("id", id.to_string())
  274. .fetch_one(&*conn)
  275. .await?
  276. .map(sql_row_to_keyset_info)
  277. .transpose()?)
  278. }
  279. async fn get_keyset_infos(&self) -> Result<Vec<MintKeySetInfo>, Self::Err> {
  280. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  281. Ok(query(
  282. r#"SELECT
  283. id,
  284. unit,
  285. active,
  286. valid_from,
  287. valid_to,
  288. derivation_path,
  289. derivation_path_index,
  290. amounts,
  291. input_fee_ppk
  292. FROM
  293. keyset
  294. WHERE id=:id"#,
  295. )?
  296. .fetch_all(&*conn)
  297. .await?
  298. .into_iter()
  299. .map(sql_row_to_keyset_info)
  300. .collect::<Result<Vec<_>, _>>()?)
  301. }
  302. async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
  303. if ys.is_empty() {
  304. return Ok(vec![]);
  305. }
  306. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  307. let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
  308. .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
  309. .fetch_all(&*conn)
  310. .await?
  311. .into_iter()
  312. .map(|row| {
  313. Ok((
  314. column_as_string!(&row[0], PublicKey::from_hex, PublicKey::from_slice),
  315. column_as_string!(&row[1], State::from_str),
  316. ))
  317. })
  318. .collect::<Result<HashMap<_, _>, Error>>()?;
  319. Ok(ys.iter().map(|y| current_states.remove(y)).collect())
  320. }
  321. async fn get_blind_signatures(
  322. &self,
  323. blinded_messages: &[PublicKey],
  324. ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
  325. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  326. let mut blinded_signatures = query(
  327. r#"SELECT
  328. keyset_id,
  329. amount,
  330. c,
  331. dleq_e,
  332. dleq_s,
  333. blinded_message,
  334. FROM
  335. blind_signature
  336. WHERE blinded_message IN (:blinded_message)
  337. "#,
  338. )?
  339. .bind_vec(
  340. "blinded_message",
  341. blinded_messages
  342. .iter()
  343. .map(|bm| bm.to_bytes().to_vec())
  344. .collect(),
  345. )
  346. .fetch_all(&*conn)
  347. .await?
  348. .into_iter()
  349. .map(|mut row| {
  350. Ok((
  351. column_as_string!(
  352. &row.pop().ok_or(Error::InvalidDbResponse)?,
  353. PublicKey::from_hex,
  354. PublicKey::from_slice
  355. ),
  356. sql_row_to_blind_signature(row)?,
  357. ))
  358. })
  359. .collect::<Result<HashMap<_, _>, Error>>()?;
  360. Ok(blinded_messages
  361. .iter()
  362. .map(|bm| blinded_signatures.remove(bm))
  363. .collect())
  364. }
  365. async fn get_auth_for_endpoint(
  366. &self,
  367. protected_endpoint: ProtectedEndpoint,
  368. ) -> Result<Option<AuthRequired>, Self::Err> {
  369. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  370. Ok(
  371. query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
  372. .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
  373. .pluck(&*conn)
  374. .await?
  375. .map(|auth| {
  376. Ok::<_, Error>(column_as_string!(
  377. auth,
  378. serde_json::from_str,
  379. serde_json::from_slice
  380. ))
  381. })
  382. .transpose()?,
  383. )
  384. }
  385. async fn get_auth_for_endpoints(
  386. &self,
  387. ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
  388. let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
  389. Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
  390. .fetch_all(&*conn)
  391. .await?
  392. .into_iter()
  393. .map(|row| {
  394. let endpoint =
  395. column_as_string!(&row[0], serde_json::from_str, serde_json::from_slice);
  396. let auth = column_as_string!(&row[1], serde_json::from_str, serde_json::from_slice);
  397. Ok((endpoint, Some(auth)))
  398. })
  399. .collect::<Result<HashMap<_, _>, Error>>()?)
  400. }
  401. }