nut21.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. //! 21 Clear Auth
  2. use std::collections::HashSet;
  3. use std::str::FromStr;
  4. use regex::Regex;
  5. use serde::{Deserialize, Serialize};
  6. use strum::IntoEnumIterator;
  7. use strum_macros::EnumIter;
  8. use thiserror::Error;
  9. /// NUT21 Error
  10. #[derive(Debug, Error)]
  11. pub enum Error {
  12. /// Invalid regex pattern
  13. #[error("Invalid regex pattern: {0}")]
  14. InvalidRegex(#[from] regex::Error),
  15. }
  16. /// Clear Auth Settings
  17. #[derive(Debug, Clone, PartialEq, Eq, Hash, Default, Serialize)]
  18. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  19. pub struct Settings {
  20. /// Openid discovery
  21. pub openid_discovery: String,
  22. /// Client ID
  23. pub client_id: String,
  24. /// Protected endpoints
  25. pub protected_endpoints: Vec<ProtectedEndpoint>,
  26. }
  27. impl Settings {
  28. /// Create new [`Settings`]
  29. pub fn new(
  30. openid_discovery: String,
  31. client_id: String,
  32. protected_endpoints: Vec<ProtectedEndpoint>,
  33. ) -> Self {
  34. Self {
  35. openid_discovery,
  36. client_id,
  37. protected_endpoints,
  38. }
  39. }
  40. }
  41. // Custom deserializer for Settings to expand regex patterns in protected endpoints
  42. impl<'de> Deserialize<'de> for Settings {
  43. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  44. where
  45. D: serde::Deserializer<'de>,
  46. {
  47. // Define a temporary struct to deserialize the raw data
  48. #[derive(Deserialize)]
  49. struct RawSettings {
  50. openid_discovery: String,
  51. client_id: String,
  52. protected_endpoints: Vec<RawProtectedEndpoint>,
  53. }
  54. #[derive(Deserialize)]
  55. struct RawProtectedEndpoint {
  56. method: Method,
  57. path: String,
  58. }
  59. // Deserialize into the temporary struct
  60. let raw = RawSettings::deserialize(deserializer)?;
  61. // Process protected endpoints, expanding regex patterns if present
  62. let mut protected_endpoints = HashSet::new();
  63. for raw_endpoint in raw.protected_endpoints {
  64. let expanded_paths = matching_route_paths(&raw_endpoint.path).map_err(|e| {
  65. serde::de::Error::custom(format!(
  66. "Invalid regex pattern '{}': {}",
  67. raw_endpoint.path, e
  68. ))
  69. })?;
  70. for path in expanded_paths {
  71. protected_endpoints.insert(ProtectedEndpoint::new(raw_endpoint.method, path));
  72. }
  73. }
  74. // Create the final Settings struct
  75. Ok(Settings {
  76. openid_discovery: raw.openid_discovery,
  77. client_id: raw.client_id,
  78. protected_endpoints: protected_endpoints.into_iter().collect(),
  79. })
  80. }
  81. }
  82. /// List of the methods and paths that are protected
  83. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
  84. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  85. pub struct ProtectedEndpoint {
  86. /// HTTP Method
  87. pub method: Method,
  88. /// Route path
  89. pub path: RoutePath,
  90. }
  91. impl ProtectedEndpoint {
  92. /// Create [`ProtectedEndpoint`]
  93. pub fn new(method: Method, path: RoutePath) -> Self {
  94. Self { method, path }
  95. }
  96. }
  97. /// HTTP method
  98. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
  99. #[serde(rename_all = "UPPERCASE")]
  100. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  101. pub enum Method {
  102. /// Get
  103. Get,
  104. /// POST
  105. Post,
  106. }
  107. /// Route path
  108. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, EnumIter)]
  109. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  110. #[serde(rename_all = "snake_case")]
  111. pub enum RoutePath {
  112. /// Bolt11 Mint Quote
  113. #[serde(rename = "/v1/mint/quote/bolt11")]
  114. MintQuoteBolt11,
  115. /// Bolt11 Mint
  116. #[serde(rename = "/v1/mint/bolt11")]
  117. MintBolt11,
  118. /// Bolt11 Melt Quote
  119. #[serde(rename = "/v1/melt/quote/bolt11")]
  120. MeltQuoteBolt11,
  121. /// Bolt11 Melt
  122. #[serde(rename = "/v1/melt/bolt11")]
  123. MeltBolt11,
  124. /// Swap
  125. #[serde(rename = "/v1/swap")]
  126. Swap,
  127. /// Checkstate
  128. #[serde(rename = "/v1/checkstate")]
  129. Checkstate,
  130. /// Restore
  131. #[serde(rename = "/v1/restore")]
  132. Restore,
  133. /// Mint Blind Auth
  134. #[serde(rename = "/v1/auth/blind/mint")]
  135. MintBlindAuth,
  136. /// Bolt12 Mint Quote
  137. #[serde(rename = "/v1/mint/quote/bolt12")]
  138. MintQuoteBolt12,
  139. /// Bolt12 Mint
  140. #[serde(rename = "/v1/mint/bolt12")]
  141. MintBolt12,
  142. /// Bolt12 Melt Quote
  143. #[serde(rename = "/v1/melt/quote/bolt12")]
  144. MeltQuoteBolt12,
  145. /// Bolt12 Quote
  146. #[serde(rename = "/v1/melt/bolt12")]
  147. MeltBolt12,
  148. /// WebSocket
  149. #[serde(rename = "/v1/ws")]
  150. Ws,
  151. }
  152. /// Returns [`RoutePath`]s that match regex
  153. pub fn matching_route_paths(pattern: &str) -> Result<Vec<RoutePath>, Error> {
  154. let regex = Regex::from_str(pattern)?;
  155. Ok(RoutePath::iter()
  156. .filter(|path| regex.is_match(&path.to_string()))
  157. .collect())
  158. }
  159. impl std::fmt::Display for RoutePath {
  160. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  161. // Use serde to serialize to a JSON string, then extract the value without quotes
  162. let json_str = match serde_json::to_string(self) {
  163. Ok(s) => s,
  164. Err(_) => return write!(f, "<error>"),
  165. };
  166. // Remove the quotes from the JSON string
  167. let path = json_str.trim_matches('"');
  168. write!(f, "{path}")
  169. }
  170. }
  171. #[cfg(test)]
  172. mod tests {
  173. use super::*;
  174. #[test]
  175. fn test_matching_route_paths_all() {
  176. // Regex that matches all paths
  177. let paths = matching_route_paths(".*").unwrap();
  178. // Should match all variants
  179. assert_eq!(paths.len(), RoutePath::iter().count());
  180. // Verify all variants are included
  181. assert!(paths.contains(&RoutePath::MintQuoteBolt11));
  182. assert!(paths.contains(&RoutePath::MintBolt11));
  183. assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
  184. assert!(paths.contains(&RoutePath::MeltBolt11));
  185. assert!(paths.contains(&RoutePath::Swap));
  186. assert!(paths.contains(&RoutePath::Checkstate));
  187. assert!(paths.contains(&RoutePath::Restore));
  188. assert!(paths.contains(&RoutePath::MintBlindAuth));
  189. assert!(paths.contains(&RoutePath::MintQuoteBolt12));
  190. assert!(paths.contains(&RoutePath::MintBolt12));
  191. }
  192. #[test]
  193. fn test_matching_route_paths_mint_only() {
  194. // Regex that matches only mint paths
  195. let paths = matching_route_paths("^/v1/mint/.*").unwrap();
  196. // Should match only mint paths
  197. assert_eq!(paths.len(), 4);
  198. assert!(paths.contains(&RoutePath::MintQuoteBolt11));
  199. assert!(paths.contains(&RoutePath::MintBolt11));
  200. assert!(paths.contains(&RoutePath::MintQuoteBolt12));
  201. assert!(paths.contains(&RoutePath::MintBolt12));
  202. // Should not match other paths
  203. assert!(!paths.contains(&RoutePath::MeltQuoteBolt11));
  204. assert!(!paths.contains(&RoutePath::MeltBolt11));
  205. assert!(!paths.contains(&RoutePath::MeltQuoteBolt12));
  206. assert!(!paths.contains(&RoutePath::MeltBolt12));
  207. assert!(!paths.contains(&RoutePath::Swap));
  208. }
  209. #[test]
  210. fn test_matching_route_paths_quote_only() {
  211. // Regex that matches only quote paths
  212. let paths = matching_route_paths(".*/quote/.*").unwrap();
  213. // Should match only quote paths
  214. assert_eq!(paths.len(), 4);
  215. assert!(paths.contains(&RoutePath::MintQuoteBolt11));
  216. assert!(paths.contains(&RoutePath::MeltQuoteBolt11));
  217. assert!(paths.contains(&RoutePath::MintQuoteBolt12));
  218. assert!(paths.contains(&RoutePath::MeltQuoteBolt12));
  219. // Should not match non-quote paths
  220. assert!(!paths.contains(&RoutePath::MintBolt11));
  221. assert!(!paths.contains(&RoutePath::MeltBolt11));
  222. }
  223. #[test]
  224. fn test_matching_route_paths_no_match() {
  225. // Regex that matches nothing
  226. let paths = matching_route_paths("/nonexistent/path").unwrap();
  227. // Should match nothing
  228. assert!(paths.is_empty());
  229. }
  230. #[test]
  231. fn test_matching_route_paths_quote_bolt11_only() {
  232. // Regex that matches only quote paths
  233. let paths = matching_route_paths("/v1/mint/quote/bolt11").unwrap();
  234. // Should match only quote paths
  235. assert_eq!(paths.len(), 1);
  236. assert!(paths.contains(&RoutePath::MintQuoteBolt11));
  237. }
  238. #[test]
  239. fn test_matching_route_paths_invalid_regex() {
  240. // Invalid regex pattern
  241. let result = matching_route_paths("(unclosed parenthesis");
  242. // Should return an error for invalid regex
  243. assert!(result.is_err());
  244. assert!(matches!(result.unwrap_err(), Error::InvalidRegex(_)));
  245. }
  246. #[test]
  247. fn test_route_path_to_string() {
  248. // Test that to_string() returns the correct path strings
  249. assert_eq!(
  250. RoutePath::MintQuoteBolt11.to_string(),
  251. "/v1/mint/quote/bolt11"
  252. );
  253. assert_eq!(RoutePath::MintBolt11.to_string(), "/v1/mint/bolt11");
  254. assert_eq!(
  255. RoutePath::MeltQuoteBolt11.to_string(),
  256. "/v1/melt/quote/bolt11"
  257. );
  258. assert_eq!(RoutePath::MeltBolt11.to_string(), "/v1/melt/bolt11");
  259. assert_eq!(RoutePath::Swap.to_string(), "/v1/swap");
  260. assert_eq!(RoutePath::Checkstate.to_string(), "/v1/checkstate");
  261. assert_eq!(RoutePath::Restore.to_string(), "/v1/restore");
  262. assert_eq!(RoutePath::MintBlindAuth.to_string(), "/v1/auth/blind/mint");
  263. }
  264. #[test]
  265. fn test_settings_deserialize_direct_paths() {
  266. let json = r#"{
  267. "openid_discovery": "https://example.com/.well-known/openid-configuration",
  268. "client_id": "client123",
  269. "protected_endpoints": [
  270. {
  271. "method": "GET",
  272. "path": "/v1/mint/bolt11"
  273. },
  274. {
  275. "method": "POST",
  276. "path": "/v1/swap"
  277. }
  278. ]
  279. }"#;
  280. let settings: Settings = serde_json::from_str(json).unwrap();
  281. assert_eq!(
  282. settings.openid_discovery,
  283. "https://example.com/.well-known/openid-configuration"
  284. );
  285. assert_eq!(settings.client_id, "client123");
  286. assert_eq!(settings.protected_endpoints.len(), 2);
  287. // Check that both paths are included
  288. let paths = settings
  289. .protected_endpoints
  290. .iter()
  291. .map(|ep| (ep.method, ep.path))
  292. .collect::<Vec<_>>();
  293. assert!(paths.contains(&(Method::Get, RoutePath::MintBolt11)));
  294. assert!(paths.contains(&(Method::Post, RoutePath::Swap)));
  295. }
  296. #[test]
  297. fn test_settings_deserialize_with_regex() {
  298. let json = r#"{
  299. "openid_discovery": "https://example.com/.well-known/openid-configuration",
  300. "client_id": "client123",
  301. "protected_endpoints": [
  302. {
  303. "method": "GET",
  304. "path": "^/v1/mint/.*"
  305. },
  306. {
  307. "method": "POST",
  308. "path": "/v1/swap"
  309. }
  310. ]
  311. }"#;
  312. let settings: Settings = serde_json::from_str(json).unwrap();
  313. assert_eq!(
  314. settings.openid_discovery,
  315. "https://example.com/.well-known/openid-configuration"
  316. );
  317. assert_eq!(settings.client_id, "client123");
  318. assert_eq!(settings.protected_endpoints.len(), 5); // 3 mint paths + 1 swap path
  319. let expected_protected: HashSet<ProtectedEndpoint> = HashSet::from_iter(vec![
  320. ProtectedEndpoint::new(Method::Post, RoutePath::Swap),
  321. ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt11),
  322. ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt11),
  323. ProtectedEndpoint::new(Method::Get, RoutePath::MintQuoteBolt12),
  324. ProtectedEndpoint::new(Method::Get, RoutePath::MintBolt12),
  325. ]);
  326. let deserlized_protected = settings.protected_endpoints.into_iter().collect();
  327. assert_eq!(expected_protected, deserlized_protected);
  328. }
  329. #[test]
  330. fn test_settings_deserialize_invalid_regex() {
  331. let json = r#"{
  332. "openid_discovery": "https://example.com/.well-known/openid-configuration",
  333. "client_id": "client123",
  334. "protected_endpoints": [
  335. {
  336. "method": "GET",
  337. "path": "(unclosed parenthesis"
  338. }
  339. ]
  340. }"#;
  341. let result = serde_json::from_str::<Settings>(json);
  342. assert!(result.is_err());
  343. }
  344. #[test]
  345. fn test_settings_deserialize_exact_path_match() {
  346. let json = r#"{
  347. "openid_discovery": "https://example.com/.well-known/openid-configuration",
  348. "client_id": "client123",
  349. "protected_endpoints": [
  350. {
  351. "method": "GET",
  352. "path": "/v1/mint/quote/bolt11"
  353. }
  354. ]
  355. }"#;
  356. let settings: Settings = serde_json::from_str(json).unwrap();
  357. assert_eq!(settings.protected_endpoints.len(), 1);
  358. assert_eq!(settings.protected_endpoints[0].method, Method::Get);
  359. assert_eq!(
  360. settings.protected_endpoints[0].path,
  361. RoutePath::MintQuoteBolt11
  362. );
  363. }
  364. #[test]
  365. fn test_settings_deserialize_all_paths() {
  366. let json = r#"{
  367. "openid_discovery": "https://example.com/.well-known/openid-configuration",
  368. "client_id": "client123",
  369. "protected_endpoints": [
  370. {
  371. "method": "GET",
  372. "path": ".*"
  373. }
  374. ]
  375. }"#;
  376. let settings: Settings = serde_json::from_str(json).unwrap();
  377. assert_eq!(
  378. settings.protected_endpoints.len(),
  379. RoutePath::iter().count()
  380. );
  381. }
  382. }