auth_wallet.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. use std::collections::HashMap;
  2. use std::sync::Arc;
  3. use cdk_common::database::{self, WalletDatabase};
  4. use cdk_common::mint_url::MintUrl;
  5. use cdk_common::{AuthProof, Id, Keys, MintInfo};
  6. use serde::{Deserialize, Serialize};
  7. use tokio::sync::RwLock;
  8. use tracing::instrument;
  9. use super::AuthMintConnector;
  10. use crate::amount::SplitTarget;
  11. use crate::dhke::construct_proofs;
  12. use crate::nuts::nut22::MintAuthRequest;
  13. use crate::nuts::{
  14. nut12, AuthRequired, AuthToken, BlindAuthToken, CurrencyUnit, KeySetInfo, PreMintSecrets,
  15. Proofs, ProtectedEndpoint, State,
  16. };
  17. use crate::types::ProofInfo;
  18. use crate::wallet::mint_connector::AuthHttpClient;
  19. use crate::wallet::mint_metadata_cache::MintMetadataCache;
  20. use crate::{Amount, Error, OidcClient};
  21. /// JWT Claims structure for decoding tokens
  22. #[derive(Debug, Serialize, Deserialize)]
  23. struct _Claims {
  24. /// Subject
  25. sub: Option<String>,
  26. /// Expiration time (as UTC timestamp)
  27. exp: Option<u64>,
  28. /// Issued at (as UTC timestamp)
  29. iat: Option<u64>,
  30. }
  31. /// CDK Auth Wallet
  32. ///
  33. /// A [`AuthWallet`] is for auth operations with a single mint.
  34. #[derive(Debug, Clone)]
  35. pub struct AuthWallet {
  36. /// Mint Url
  37. pub mint_url: MintUrl,
  38. /// Storage backend
  39. pub localstore: Arc<dyn WalletDatabase<Err = database::Error> + Send + Sync>,
  40. /// Mint metadata cache (lock-free cached access to keys, keysets, and mint info)
  41. pub metadata_cache: Arc<MintMetadataCache>,
  42. /// Protected methods
  43. pub protected_endpoints: Arc<RwLock<HashMap<ProtectedEndpoint, AuthRequired>>>,
  44. /// Refresh token for auth
  45. refresh_token: Arc<RwLock<Option<String>>>,
  46. auth_client: Arc<dyn AuthMintConnector + Send + Sync>,
  47. /// OIDC client for authentication
  48. oidc_client: Arc<RwLock<Option<OidcClient>>>,
  49. }
  50. impl AuthWallet {
  51. /// Create a new [`AuthWallet`] instance
  52. pub fn new(
  53. mint_url: MintUrl,
  54. cat: Option<AuthToken>,
  55. localstore: Arc<dyn WalletDatabase<Err = database::Error> + Send + Sync>,
  56. metadata_cache: Arc<MintMetadataCache>,
  57. protected_endpoints: HashMap<ProtectedEndpoint, AuthRequired>,
  58. oidc_client: Option<OidcClient>,
  59. ) -> Self {
  60. let http_client = Arc::new(AuthHttpClient::new(mint_url.clone(), cat));
  61. Self {
  62. mint_url,
  63. localstore,
  64. metadata_cache,
  65. protected_endpoints: Arc::new(RwLock::new(protected_endpoints)),
  66. refresh_token: Arc::new(RwLock::new(None)),
  67. auth_client: http_client,
  68. oidc_client: Arc::new(RwLock::new(oidc_client)),
  69. }
  70. }
  71. /// Get the current auth token
  72. #[instrument(skip(self))]
  73. pub async fn get_auth_token(&self) -> Result<AuthToken, Error> {
  74. self.auth_client.get_auth_token().await
  75. }
  76. /// Set a new auth token
  77. #[instrument(skip_all)]
  78. pub async fn verify_cat(&self, token: AuthToken) -> Result<(), Error> {
  79. match &token {
  80. AuthToken::ClearAuth(clear_token) => {
  81. if let Some(oidc) = self.oidc_client.read().await.as_ref() {
  82. oidc.verify_cat(clear_token).await?;
  83. }
  84. Ok(())
  85. }
  86. AuthToken::BlindAuth(_) => Err(Error::Custom(
  87. "Cannot set blind auth token directly".to_string(),
  88. )),
  89. }
  90. }
  91. /// Set a new auth token
  92. #[instrument(skip_all)]
  93. pub async fn set_auth_token(&self, token: AuthToken) -> Result<(), Error> {
  94. match &token {
  95. AuthToken::ClearAuth(clear_token) => {
  96. if let Some(oidc) = self.oidc_client.read().await.as_ref() {
  97. oidc.verify_cat(clear_token).await?;
  98. }
  99. self.auth_client.set_auth_token(token).await
  100. }
  101. AuthToken::BlindAuth(_) => Err(Error::Custom(
  102. "Cannot set blind auth token directly".to_string(),
  103. )),
  104. }
  105. }
  106. /// Get the current refresh token if one exists
  107. #[instrument(skip(self))]
  108. pub async fn get_refresh_token(&self) -> Option<String> {
  109. self.refresh_token.read().await.clone()
  110. }
  111. /// Set a new refresh token
  112. #[instrument(skip(self))]
  113. pub async fn set_refresh_token(&self, token: Option<String>) {
  114. *self.refresh_token.write().await = token;
  115. }
  116. /// Get the OIDC client if one exists
  117. #[instrument(skip(self))]
  118. pub async fn get_oidc_client(&self) -> Option<OidcClient> {
  119. self.oidc_client.read().await.clone()
  120. }
  121. /// Set a new OIDC client
  122. #[instrument(skip(self))]
  123. pub async fn set_oidc_client(&self, client: Option<OidcClient>) {
  124. *self.oidc_client.write().await = client;
  125. }
  126. /// Refresh the access token using the stored refresh token
  127. #[instrument(skip(self))]
  128. pub async fn refresh_access_token(&self) -> Result<(), Error> {
  129. if let Some(oidc) = self.oidc_client.read().await.as_ref() {
  130. if let Some(refresh_token) = self.get_refresh_token().await {
  131. let mint_info = self
  132. .get_mint_info()
  133. .await?
  134. .ok_or(Error::CouldNotGetMintInfo)?;
  135. let token_response = oidc
  136. .refresh_access_token(
  137. mint_info.client_id().ok_or(Error::CouldNotGetMintInfo)?,
  138. refresh_token,
  139. )
  140. .await?;
  141. // Store new refresh token if provided
  142. self.set_refresh_token(token_response.refresh_token).await;
  143. // Set new access token
  144. self.set_auth_token(AuthToken::ClearAuth(token_response.access_token))
  145. .await?;
  146. return Ok(());
  147. }
  148. }
  149. Err(Error::Custom(
  150. "No refresh token or OIDC client available".to_string(),
  151. ))
  152. }
  153. /// Query mint for current mint information
  154. #[instrument(skip(self))]
  155. pub async fn get_mint_info(&self) -> Result<Option<MintInfo>, Error> {
  156. self.auth_client
  157. .get_mint_info()
  158. .await
  159. .map(Some)
  160. .or(Ok(None))
  161. }
  162. /// Fetch keys for mint keyset
  163. ///
  164. /// Returns keys from metadata cache if available, fetches from mint if not.
  165. #[instrument(skip(self))]
  166. pub async fn load_keyset_keys(&self, keyset_id: Id) -> Result<Keys, Error> {
  167. let metadata = self
  168. .metadata_cache
  169. .load_auth(&self.localstore, &self.auth_client)
  170. .await?;
  171. let active = metadata
  172. .active_keysets
  173. .iter()
  174. .find(|x| x.unit == CurrencyUnit::Auth)
  175. .cloned()
  176. .ok_or(Error::NoActiveKeyset)?;
  177. metadata
  178. .keys
  179. .get(&active.id)
  180. .map(|x| (*(x.clone())).clone())
  181. .ok_or(Error::NoActiveKeyset)
  182. }
  183. /// Get blind auth keysets from metadata cache
  184. ///
  185. /// Checks the metadata cache for auth keysets. If cache is not populated,
  186. /// fetches from the mint server and updates the cache.
  187. /// This is the main method for getting auth keysets in operations that can work offline
  188. /// but will fall back to online if needed.
  189. #[instrument(skip(self))]
  190. pub async fn load_mint_keysets(&self) -> Result<Vec<KeySetInfo>, Error> {
  191. let metadata = self
  192. .metadata_cache
  193. .load_auth(&self.localstore, &self.auth_client)
  194. .await?;
  195. let auth_keysets = metadata
  196. .keysets
  197. .iter()
  198. .filter_map(|(_, k)| {
  199. if k.unit == CurrencyUnit::Auth {
  200. Some((*(k.clone())).clone())
  201. } else {
  202. None
  203. }
  204. })
  205. .collect::<Vec<_>>();
  206. if !auth_keysets.is_empty() {
  207. Ok(auth_keysets)
  208. } else {
  209. Err(Error::UnknownKeySet)
  210. }
  211. }
  212. /// Refresh blind auth keysets by fetching the latest from mint
  213. ///
  214. /// Fetches the latest blind auth keyset information from the mint server,
  215. /// updating the metadata cache and database. Returns only the keysets with
  216. /// Auth currency unit. Use this when you need the most up-to-date keyset information.
  217. #[instrument(skip(self))]
  218. pub async fn refresh_keysets(&self) -> Result<Vec<KeySetInfo>, Error> {
  219. tracing::debug!("Refreshing auth keysets from mint");
  220. self.load_mint_keysets().await
  221. }
  222. /// Get the first active blind auth keyset - always goes online
  223. ///
  224. /// This method always goes online to refresh keysets from the mint and then returns
  225. /// the first active keyset found. Use this when you need the most up-to-date
  226. /// keyset information for blind auth operations.
  227. #[instrument(skip(self))]
  228. pub async fn fetch_active_keyset(&self) -> Result<KeySetInfo, Error> {
  229. let auth_keysets = self.refresh_keysets().await?;
  230. let keyset = auth_keysets.first().ok_or(Error::NoActiveKeyset)?;
  231. Ok(keyset.clone())
  232. }
  233. /// Get unspent auth proofs from local database only - offline operation
  234. ///
  235. /// Returns auth proofs from the local database that are in the Unspent state.
  236. /// This is an offline operation that does not contact the mint.
  237. #[instrument(skip(self))]
  238. pub async fn get_unspent_auth_proofs(&self) -> Result<Vec<AuthProof>, Error> {
  239. Ok(self
  240. .localstore
  241. .get_proofs(
  242. Some(self.mint_url.clone()),
  243. Some(CurrencyUnit::Auth),
  244. Some(vec![State::Unspent]),
  245. None,
  246. )
  247. .await?
  248. .into_iter()
  249. .map(|p| p.proof.try_into())
  250. .collect::<Result<Vec<AuthProof>, _>>()?)
  251. }
  252. /// Check if and what kind of auth is required for a method
  253. #[instrument(skip(self))]
  254. pub async fn is_protected(&self, method: &ProtectedEndpoint) -> Option<AuthRequired> {
  255. let protected_endpoints = self.protected_endpoints.read().await;
  256. protected_endpoints.get(method).copied()
  257. }
  258. /// Get Auth Token
  259. #[instrument(skip(self))]
  260. pub async fn get_blind_auth_token(&self) -> Result<Option<BlindAuthToken>, Error> {
  261. let unspent = self.get_unspent_auth_proofs().await?;
  262. let auth_proof = match unspent.first() {
  263. Some(proof) => {
  264. self.localstore
  265. .update_proofs(vec![], vec![proof.y()?])
  266. .await?;
  267. proof
  268. }
  269. None => return Ok(None),
  270. };
  271. Ok(Some(BlindAuthToken {
  272. auth_proof: auth_proof.clone(),
  273. }))
  274. }
  275. /// Auth for request
  276. #[instrument(skip(self))]
  277. pub async fn get_auth_for_request(
  278. &self,
  279. method: &ProtectedEndpoint,
  280. ) -> Result<Option<AuthToken>, Error> {
  281. match self.is_protected(method).await {
  282. Some(auth) => match auth {
  283. AuthRequired::Clear => {
  284. tracing::trace!("Clear auth needed for request.");
  285. self.auth_client.get_auth_token().await.map(Some)
  286. }
  287. AuthRequired::Blind => {
  288. tracing::trace!("Blind auth needed for request getting Auth proof.");
  289. let proof = self.get_blind_auth_token().await?.ok_or_else(|| {
  290. tracing::debug!(
  291. "Insufficient blind auth proofs in wallet. Must mint bats."
  292. );
  293. Error::InsufficientBlindAuthTokens
  294. })?;
  295. let auth_token = AuthToken::BlindAuth(proof.without_dleq());
  296. Ok(Some(auth_token))
  297. }
  298. },
  299. None => Ok(None),
  300. }
  301. }
  302. /// Mint blind auth
  303. #[instrument(skip(self))]
  304. pub async fn mint_blind_auth(&self, amount: Amount) -> Result<Proofs, Error> {
  305. tracing::debug!("Minting {} blind auth proofs", amount);
  306. // Check that mint is in store of mints
  307. if self
  308. .localstore
  309. .get_mint(self.mint_url.clone())
  310. .await?
  311. .is_none()
  312. {
  313. self.get_mint_info().await?;
  314. }
  315. let auth_token = self.auth_client.get_auth_token().await?;
  316. match &auth_token {
  317. AuthToken::ClearAuth(cat) => {
  318. if cat.is_empty() {
  319. tracing::warn!("Auth Cat is not set");
  320. return Err(Error::ClearAuthRequired);
  321. }
  322. if let Err(err) = self.verify_cat(auth_token).await {
  323. tracing::warn!("Current cat is invalid {}", err);
  324. }
  325. let has_refresh;
  326. {
  327. has_refresh = self.refresh_token.read().await.is_some();
  328. }
  329. if has_refresh {
  330. tracing::info!("Attempting to refresh using refresh token");
  331. self.refresh_access_token().await?;
  332. } else {
  333. tracing::warn!(
  334. "Wallet cat is invalid and there is no refresh token please reauth"
  335. );
  336. }
  337. }
  338. AuthToken::BlindAuth(_) => {
  339. tracing::error!("Blind auth set as client cat");
  340. return Err(Error::ClearAuthFailed);
  341. }
  342. }
  343. let keysets = self
  344. .load_mint_keysets()
  345. .await?
  346. .into_iter()
  347. .map(|x| (x.id, x))
  348. .collect::<HashMap<_, _>>();
  349. let active_keyset_id = self.fetch_active_keyset().await?.id;
  350. let fee_and_amounts = (
  351. keysets
  352. .get(&active_keyset_id)
  353. .map(|x| x.input_fee_ppk)
  354. .unwrap_or_default(),
  355. self.load_keyset_keys(active_keyset_id)
  356. .await?
  357. .iter()
  358. .map(|(amount, _)| amount.to_u64())
  359. .collect::<Vec<_>>(),
  360. )
  361. .into();
  362. let premint_secrets = PreMintSecrets::random(
  363. active_keyset_id,
  364. amount,
  365. &SplitTarget::Value(1.into()),
  366. &fee_and_amounts,
  367. )?;
  368. let request = MintAuthRequest {
  369. outputs: premint_secrets.blinded_messages(),
  370. };
  371. let mint_res = self.auth_client.post_mint_blind_auth(request).await?;
  372. let keys = self.load_keyset_keys(active_keyset_id).await?;
  373. // Verify the signature DLEQ is valid
  374. {
  375. assert!(mint_res.signatures.len() == premint_secrets.secrets.len());
  376. for (sig, premint) in mint_res.signatures.iter().zip(&premint_secrets.secrets) {
  377. let keys = self.load_keyset_keys(sig.keyset_id).await?;
  378. let key = keys.amount_key(sig.amount).ok_or(Error::AmountKey)?;
  379. match sig.verify_dleq(key, premint.blinded_message.blinded_secret) {
  380. Ok(_) => (),
  381. Err(nut12::Error::MissingDleqProof) => {
  382. tracing::warn!("Signature for bat returned without dleq proof.");
  383. return Err(Error::DleqProofNotProvided);
  384. }
  385. Err(_) => return Err(Error::CouldNotVerifyDleq),
  386. }
  387. }
  388. }
  389. let proofs = construct_proofs(
  390. mint_res.signatures,
  391. premint_secrets.rs(),
  392. premint_secrets.secrets(),
  393. &keys,
  394. )?;
  395. let proof_infos = proofs
  396. .clone()
  397. .into_iter()
  398. .map(|proof| {
  399. ProofInfo::new(
  400. proof,
  401. self.mint_url.clone(),
  402. State::Unspent,
  403. crate::nuts::CurrencyUnit::Auth,
  404. )
  405. })
  406. .collect::<Result<Vec<ProofInfo>, _>>()?;
  407. // Add new proofs to store
  408. self.localstore.update_proofs(proof_infos, vec![]).await?;
  409. Ok(proofs)
  410. }
  411. /// Total unspent balance of wallet
  412. #[instrument(skip(self))]
  413. pub async fn total_blind_auth_balance(&self) -> Result<Amount, Error> {
  414. Ok(Amount::from(
  415. self.get_unspent_auth_proofs().await?.len() as u64
  416. ))
  417. }
  418. }