transport.rs 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. //! HTTP Transport trait with a default implementation
  2. use std::fmt::Debug;
  3. use cdk_common::AuthToken;
  4. #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
  5. use hickory_resolver::config::ResolverConfig;
  6. #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
  7. use hickory_resolver::name_server::TokioConnectionProvider;
  8. #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
  9. use hickory_resolver::Resolver;
  10. use reqwest::Client;
  11. use serde::de::DeserializeOwned;
  12. use serde::Serialize;
  13. use url::Url;
  14. use super::Error;
  15. use crate::error::ErrorResponse;
  16. /// Expected HTTP Transport
  17. #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
  18. #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
  19. pub trait Transport: Default + Send + Sync + Debug + Clone {
  20. #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
  21. /// DNS resolver to get a TXT record from a domain name
  22. async fn resolve_dns_txt(&self, _domain: &str) -> Result<Vec<String>, Error>;
  23. /// Make the transport to use a given proxy
  24. fn with_proxy(
  25. &mut self,
  26. proxy: Url,
  27. host_matcher: Option<&str>,
  28. accept_invalid_certs: bool,
  29. ) -> Result<(), Error>;
  30. /// HTTP Get request
  31. async fn http_get<R>(&self, url: Url, auth: Option<AuthToken>) -> Result<R, Error>
  32. where
  33. R: DeserializeOwned;
  34. /// HTTP Post request
  35. async fn http_post<P, R>(
  36. &self,
  37. url: Url,
  38. auth_token: Option<AuthToken>,
  39. payload: &P,
  40. ) -> Result<R, Error>
  41. where
  42. P: Serialize + ?Sized + Send + Sync,
  43. R: DeserializeOwned;
  44. }
  45. /// Async transport for Http
  46. #[derive(Debug, Clone)]
  47. pub struct Async {
  48. inner: Client,
  49. }
  50. impl Default for Async {
  51. fn default() -> Self {
  52. #[cfg(not(target_arch = "wasm32"))]
  53. if rustls::crypto::CryptoProvider::get_default().is_none() {
  54. let _ = rustls::crypto::ring::default_provider().install_default();
  55. }
  56. Self {
  57. inner: Client::new(),
  58. }
  59. }
  60. }
  61. #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
  62. #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
  63. impl Transport for Async {
  64. #[cfg(target_arch = "wasm32")]
  65. fn with_proxy(
  66. &mut self,
  67. _proxy: Url,
  68. _host_matcher: Option<&str>,
  69. _accept_invalid_certs: bool,
  70. ) -> Result<(), Error> {
  71. panic!("Not supported in wasm");
  72. }
  73. #[cfg(not(target_arch = "wasm32"))]
  74. fn with_proxy(
  75. &mut self,
  76. proxy: Url,
  77. host_matcher: Option<&str>,
  78. accept_invalid_certs: bool,
  79. ) -> Result<(), Error> {
  80. let builder = reqwest::Client::builder().danger_accept_invalid_certs(accept_invalid_certs);
  81. let builder = match host_matcher {
  82. Some(pattern) => {
  83. // When a matcher is provided, only apply the proxy to matched hosts
  84. let regex = regex::Regex::new(pattern).map_err(|e| Error::Custom(e.to_string()))?;
  85. builder.proxy(reqwest::Proxy::custom(move |url| {
  86. url.host_str()
  87. .filter(|host| regex.is_match(host))
  88. .map(|_| proxy.clone())
  89. }))
  90. }
  91. // Apply proxy to all requests when no matcher is provided
  92. None => {
  93. builder.proxy(reqwest::Proxy::all(proxy).map_err(|e| Error::Custom(e.to_string()))?)
  94. }
  95. };
  96. self.inner = builder
  97. .build()
  98. .map_err(|e| Error::HttpError(e.status().map(|s| s.as_u16()), e.to_string()))?;
  99. Ok(())
  100. }
  101. /// DNS resolver to get a TXT record from a domain name
  102. #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
  103. async fn resolve_dns_txt(&self, domain: &str) -> Result<Vec<String>, Error> {
  104. let resolver = Resolver::builder_with_config(
  105. ResolverConfig::default(),
  106. TokioConnectionProvider::default(),
  107. )
  108. .build();
  109. Ok(resolver
  110. .txt_lookup(domain)
  111. .await
  112. .map_err(|e| Error::Custom(e.to_string()))?
  113. .into_iter()
  114. .map(|txt| {
  115. txt.txt_data()
  116. .iter()
  117. .map(|bytes| String::from_utf8_lossy(bytes).into_owned())
  118. .collect::<Vec<_>>()
  119. .join("")
  120. })
  121. .collect::<Vec<_>>())
  122. }
  123. async fn http_get<R>(&self, url: Url, auth: Option<AuthToken>) -> Result<R, Error>
  124. where
  125. R: DeserializeOwned,
  126. {
  127. let mut request = self.inner.get(url);
  128. if let Some(auth) = auth {
  129. request = request.header(auth.header_key(), auth.to_string());
  130. }
  131. let response = request
  132. .send()
  133. .await
  134. .map_err(|e| {
  135. Error::HttpError(
  136. e.status().map(|status_code| status_code.as_u16()),
  137. e.to_string(),
  138. )
  139. })?
  140. .text()
  141. .await
  142. .map_err(|e| {
  143. Error::HttpError(
  144. e.status().map(|status_code| status_code.as_u16()),
  145. e.to_string(),
  146. )
  147. })?;
  148. serde_json::from_str::<R>(&response).map_err(|err| {
  149. tracing::warn!("Http Response error: {}", err);
  150. match ErrorResponse::from_json(&response) {
  151. Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
  152. Err(err) => err.into(),
  153. }
  154. })
  155. }
  156. async fn http_post<P, R>(
  157. &self,
  158. url: Url,
  159. auth_token: Option<AuthToken>,
  160. payload: &P,
  161. ) -> Result<R, Error>
  162. where
  163. P: Serialize + ?Sized + Send + Sync,
  164. R: DeserializeOwned,
  165. {
  166. let mut request = self.inner.post(url).json(&payload);
  167. if let Some(auth) = auth_token {
  168. request = request.header(auth.header_key(), auth.to_string());
  169. }
  170. let response = request.send().await.map_err(|e| {
  171. Error::HttpError(
  172. e.status().map(|status_code| status_code.as_u16()),
  173. e.to_string(),
  174. )
  175. })?;
  176. let response = response.text().await.map_err(|e| {
  177. Error::HttpError(
  178. e.status().map(|status_code| status_code.as_u16()),
  179. e.to_string(),
  180. )
  181. })?;
  182. serde_json::from_str::<R>(&response).map_err(|err| {
  183. tracing::warn!("Http Response error: {}", err);
  184. match ErrorResponse::from_json(&response) {
  185. Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
  186. Err(err) => err.into(),
  187. }
  188. })
  189. }
  190. }