nut22.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
//! NUT-22: Blind Authentication
  2. use std::fmt;
  3. use bitcoin::base64::engine::general_purpose;
  4. use bitcoin::base64::Engine;
  5. use serde::{Deserialize, Serialize};
  6. use thiserror::Error;
  7. use super::nut21::ProtectedEndpoint;
  8. use crate::dhke::hash_to_curve;
  9. use crate::secret::Secret;
  10. use crate::util::hex;
  11. use crate::{BlindedMessage, Id, Proof, ProofDleq, PublicKey};
  12. /// NUT22 Error
  13. #[derive(Debug, Error)]
  14. pub enum Error {
  15. /// Invalid Prefix
  16. #[error("Invalid prefix")]
  17. InvalidPrefix,
  18. /// Dleq proof not included
  19. #[error("Dleq Proof not included for auth proof")]
  20. DleqProofNotIncluded,
  21. /// Hex Error
  22. #[error(transparent)]
  23. HexError(#[from] hex::Error),
  24. /// Base64 error
  25. #[error(transparent)]
  26. Base64Error(#[from] bitcoin::base64::DecodeError),
  27. /// Serde Json error
  28. #[error(transparent)]
  29. SerdeJsonError(#[from] serde_json::Error),
  30. /// Utf8 parse error
  31. #[error(transparent)]
  32. Utf8ParseError(#[from] std::string::FromUtf8Error),
  33. /// DHKE error
  34. #[error(transparent)]
  35. DHKE(#[from] crate::dhke::Error),
  36. }
  37. /// Blind auth settings
  38. #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
  39. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  40. pub struct Settings {
  41. /// Max number of blind auth tokens that can be minted per request
  42. pub bat_max_mint: u64,
  43. /// Protected endpoints
  44. pub protected_endpoints: Vec<ProtectedEndpoint>,
  45. }
  46. impl Settings {
  47. /// Create new [`Settings`]
  48. pub fn new(bat_max_mint: u64, protected_endpoints: Vec<ProtectedEndpoint>) -> Self {
  49. Self {
  50. bat_max_mint,
  51. protected_endpoints,
  52. }
  53. }
  54. }
  55. // Custom deserializer for Settings to expand regex patterns in protected endpoints
  56. impl<'de> Deserialize<'de> for Settings {
  57. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  58. where
  59. D: serde::Deserializer<'de>,
  60. {
  61. use std::collections::HashSet;
  62. use super::nut21::matching_route_paths;
  63. // Define a temporary struct to deserialize the raw data
  64. #[derive(Deserialize)]
  65. struct RawSettings {
  66. bat_max_mint: u64,
  67. protected_endpoints: Vec<RawProtectedEndpoint>,
  68. }
  69. #[derive(Deserialize)]
  70. struct RawProtectedEndpoint {
  71. method: super::nut21::Method,
  72. path: String,
  73. }
  74. // Deserialize into the temporary struct
  75. let raw = RawSettings::deserialize(deserializer)?;
  76. // Process protected endpoints, expanding regex patterns if present
  77. let mut protected_endpoints = HashSet::new();
  78. for raw_endpoint in raw.protected_endpoints {
  79. let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
  80. serde::de::Error::custom(format!(
  81. "Invalid regex pattern '{}': {}",
  82. raw_endpoint.path, e
  83. ))
  84. })?;
  85. for path in expanded_paths {
  86. protected_endpoints.insert(super::nut21::ProtectedEndpoint::new(
  87. raw_endpoint.method,
  88. path,
  89. ));
  90. }
  91. }
  92. // Create the final Settings struct
  93. Ok(Settings {
  94. bat_max_mint: raw.bat_max_mint,
  95. protected_endpoints: protected_endpoints.into_iter().collect(),
  96. })
  97. }
  98. }
  99. /// Auth Token
  100. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  101. pub enum AuthToken {
  102. /// Clear Auth token
  103. ClearAuth(String),
  104. /// Blind Auth token
  105. BlindAuth(BlindAuthToken),
  106. }
  107. impl fmt::Display for AuthToken {
  108. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  109. match self {
  110. Self::ClearAuth(cat) => cat.fmt(f),
  111. Self::BlindAuth(bat) => bat.fmt(f),
  112. }
  113. }
  114. }
  115. impl AuthToken {
  116. /// Header key for auth token type
  117. pub fn header_key(&self) -> String {
  118. match self {
  119. Self::ClearAuth(_) => "Clear-auth".to_string(),
  120. Self::BlindAuth(_) => "Blind-auth".to_string(),
  121. }
  122. }
  123. }
  124. /// Required Auth
  125. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
  126. pub enum AuthRequired {
  127. /// Clear Auth token
  128. Clear,
  129. /// Blind Auth token
  130. Blind,
  131. }
  132. /// Auth Proofs
  133. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  134. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  135. pub struct AuthProof {
  136. /// `Keyset id`
  137. #[serde(rename = "id")]
  138. pub keyset_id: Id,
  139. /// Secret message
  140. #[cfg_attr(feature = "swagger", schema(value_type = String))]
  141. pub secret: Secret,
  142. /// Unblinded signature
  143. #[serde(rename = "C")]
  144. #[cfg_attr(feature = "swagger", schema(value_type = String))]
  145. pub c: PublicKey,
  146. /// Auth Proof Dleq
  147. pub dleq: Option<ProofDleq>,
  148. }
  149. impl AuthProof {
  150. /// Y of AuthProof
  151. pub fn y(&self) -> Result<PublicKey, Error> {
  152. Ok(hash_to_curve(self.secret.as_bytes())?)
  153. }
  154. }
  155. impl From<AuthProof> for Proof {
  156. fn from(value: AuthProof) -> Self {
  157. Self {
  158. amount: 1.into(),
  159. keyset_id: value.keyset_id,
  160. secret: value.secret,
  161. c: value.c,
  162. witness: None,
  163. dleq: value.dleq,
  164. }
  165. }
  166. }
  167. impl TryFrom<Proof> for AuthProof {
  168. type Error = Error;
  169. fn try_from(value: Proof) -> Result<Self, Self::Error> {
  170. Ok(Self {
  171. keyset_id: value.keyset_id,
  172. secret: value.secret,
  173. c: value.c,
  174. dleq: value.dleq,
  175. })
  176. }
  177. }
  178. /// Blind Auth Token
  179. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  180. pub struct BlindAuthToken {
  181. /// [AuthProof]
  182. pub auth_proof: AuthProof,
  183. }
  184. impl BlindAuthToken {
  185. /// Create new [ `BlindAuthToken`]
  186. pub fn new(auth_proof: AuthProof) -> Self {
  187. BlindAuthToken { auth_proof }
  188. }
  189. /// Remove DLEQ
  190. ///
  191. /// We do not send the DLEQ to the mint as it links redemption and creation
  192. pub fn without_dleq(&self) -> Self {
  193. Self {
  194. auth_proof: AuthProof {
  195. keyset_id: self.auth_proof.keyset_id,
  196. secret: self.auth_proof.secret.clone(),
  197. c: self.auth_proof.c,
  198. dleq: None,
  199. },
  200. }
  201. }
  202. }
  203. impl fmt::Display for BlindAuthToken {
  204. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  205. let json_string = serde_json::to_string(&self.auth_proof).map_err(|_| fmt::Error)?;
  206. let encoded = general_purpose::URL_SAFE.encode(json_string);
  207. write!(f, "authA{encoded}")
  208. }
  209. }
  210. impl std::str::FromStr for BlindAuthToken {
  211. type Err = Error;
  212. fn from_str(s: &str) -> Result<Self, Self::Err> {
  213. // Check prefix and extract the base64 encoded part in one step
  214. let encoded = s.strip_prefix("authA").ok_or(Error::InvalidPrefix)?;
  215. // Decode the base64 URL-safe string
  216. let json_string = general_purpose::URL_SAFE.decode(encoded)?;
  217. // Convert bytes to UTF-8 string
  218. let json_str = String::from_utf8(json_string)?;
  219. // Deserialize the JSON string into AuthProof
  220. let auth_proof: AuthProof = serde_json::from_str(&json_str)?;
  221. Ok(BlindAuthToken { auth_proof })
  222. }
  223. }
  224. /// Mint auth request [NUT-XX]
  225. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
  226. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  227. pub struct MintAuthRequest {
  228. /// Outputs
  229. #[cfg_attr(feature = "swagger", schema(max_items = 1_000))]
  230. pub outputs: Vec<BlindedMessage>,
  231. }
  232. impl MintAuthRequest {
  233. /// Count of tokens
  234. pub fn amount(&self) -> u64 {
  235. self.outputs.len() as u64
  236. }
  237. }
  238. #[cfg(test)]
  239. mod tests {
  240. use std::collections::HashSet;
  241. use strum::IntoEnumIterator;
  242. use super::super::nut21::{Method, RoutePath};
  243. use super::*;
  244. #[test]
  245. fn test_settings_deserialize_direct_paths() {
  246. let json = r#"{
  247. "bat_max_mint": 10,
  248. "protected_endpoints": [
  249. {
  250. "method": "GET",
  251. "path": "/v1/mint/bolt11"
  252. },
  253. {
  254. "method": "POST",
  255. "path": "/v1/swap"
  256. }
  257. ]
  258. }"#;
  259. let settings: Settings = serde_json::from_str(json).unwrap();
  260. assert_eq!(settings.bat_max_mint, 10);
  261. assert_eq!(settings.protected_endpoints.len(), 2);
  262. // Check that both paths are included
  263. let paths = settings
  264. .protected_endpoints
  265. .iter()
  266. .map(|ep| (ep.method, ep.path))
  267. .collect::<Vec<_>>();
  268. assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
  269. assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
  270. }
  271. #[test]
  272. fn test_settings_deserialize_with_regex() {
  273. let json = r#"{
  274. "bat_max_mint": 5,
  275. "protected_endpoints": [
  276. {
  277. "method": "GET",
  278. "path": "^/v1/mint/.*"
  279. },
  280. {
  281. "method": "POST",
  282. "path": "/v1/swap"
  283. }
  284. ]
  285. }"#;
  286. let settings: Settings = serde_json::from_str(json).unwrap();
  287. assert_eq!(settings.bat_max_mint, 5);
  288. assert_eq!(settings.protected_endpoints.len(), 5); // 4 mint paths + 1 swap path
  289. let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
  290. ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
  291. ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
  292. ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
  293. ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt12),
  294. ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt12),
  295. ]);
  296. let deserialized_protected = settings.protected_endpoints.into_iter().collect();
  297. assert_eq!(expected_protected, deserialized_protected);
  298. }
  299. #[test]
  300. fn test_settings_deserialize_invalid_regex() {
  301. let json = r#"{
  302. "bat_max_mint": 5,
  303. "protected_endpoints": [
  304. {
  305. "method": "GET",
  306. "path": "(unclosed parenthesis"
  307. }
  308. ]
  309. }"#;
  310. let result = serde_json::from_str::<Settings>(json);
  311. assert!(result.is_err());
  312. }
  313. #[test]
  314. fn test_settings_deserialize_all_paths() {
  315. let json = r#"{
  316. "bat_max_mint": 5,
  317. "protected_endpoints": [
  318. {
  319. "method": "GET",
  320. "path": ".*"
  321. }
  322. ]
  323. }"#;
  324. let settings: Settings = serde_json::from_str(json).unwrap();
  325. assert_eq!(
  326. settings.protected_endpoints.len(),
  327. RoutePath::iter().count()
  328. );
  329. }
  330. }