transport.rs 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. //! HTTP Transport trait with a default implementation
  2. use std::fmt::Debug;
  3. use cdk_common::AuthToken;
  4. use reqwest::Client;
  5. use serde::de::DeserializeOwned;
  6. use serde::Serialize;
  7. use url::Url;
  8. use super::Error;
  9. use crate::error::ErrorResponse;
  10. /// Expected HTTP Transport
  11. #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
  12. #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
  13. pub trait Transport: Default + Send + Sync + Debug + Clone {
  14. /// Make the transport to use a given proxy
  15. fn with_proxy(
  16. &mut self,
  17. proxy: Url,
  18. host_matcher: Option<&str>,
  19. accept_invalid_certs: bool,
  20. ) -> Result<(), Error>;
  21. /// HTTP Get request
  22. async fn http_get<R>(&self, url: Url, auth: Option<AuthToken>) -> Result<R, Error>
  23. where
  24. R: DeserializeOwned;
  25. /// HTTP Post request
  26. async fn http_post<P, R>(
  27. &self,
  28. url: Url,
  29. auth_token: Option<AuthToken>,
  30. payload: &P,
  31. ) -> Result<R, Error>
  32. where
  33. P: Serialize + ?Sized + Send + Sync,
  34. R: DeserializeOwned;
  35. }
  36. /// Async transport for Http
  37. #[derive(Debug, Clone)]
  38. pub struct Async {
  39. inner: Client,
  40. }
  41. impl Default for Async {
  42. fn default() -> Self {
  43. #[cfg(not(target_arch = "wasm32"))]
  44. if rustls::crypto::CryptoProvider::get_default().is_none() {
  45. let _ = rustls::crypto::ring::default_provider().install_default();
  46. }
  47. Self {
  48. inner: Client::new(),
  49. }
  50. }
  51. }
  52. #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
  53. #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
  54. impl Transport for Async {
  55. #[cfg(target_arch = "wasm32")]
  56. fn with_proxy(
  57. &mut self,
  58. _proxy: Url,
  59. _host_matcher: Option<&str>,
  60. _accept_invalid_certs: bool,
  61. ) -> Result<(), Error> {
  62. panic!("Not supported in wasm");
  63. }
  64. #[cfg(not(target_arch = "wasm32"))]
  65. fn with_proxy(
  66. &mut self,
  67. proxy: Url,
  68. host_matcher: Option<&str>,
  69. accept_invalid_certs: bool,
  70. ) -> Result<(), Error> {
  71. let builder = reqwest::Client::builder().danger_accept_invalid_certs(accept_invalid_certs);
  72. let builder = match host_matcher {
  73. Some(pattern) => {
  74. // When a matcher is provided, only apply the proxy to matched hosts
  75. let regex = regex::Regex::new(pattern).map_err(|e| Error::Custom(e.to_string()))?;
  76. builder.proxy(reqwest::Proxy::custom(move |url| {
  77. url.host_str()
  78. .filter(|host| regex.is_match(host))
  79. .map(|_| proxy.clone())
  80. }))
  81. }
  82. // Apply proxy to all requests when no matcher is provided
  83. None => {
  84. builder.proxy(reqwest::Proxy::all(proxy).map_err(|e| Error::Custom(e.to_string()))?)
  85. }
  86. };
  87. self.inner = builder
  88. .build()
  89. .map_err(|e| Error::HttpError(e.status().map(|s| s.as_u16()), e.to_string()))?;
  90. Ok(())
  91. }
  92. async fn http_get<R>(&self, url: Url, auth: Option<AuthToken>) -> Result<R, Error>
  93. where
  94. R: DeserializeOwned,
  95. {
  96. let mut request = self.inner.get(url);
  97. if let Some(auth) = auth {
  98. request = request.header(auth.header_key(), auth.to_string());
  99. }
  100. let response = request
  101. .send()
  102. .await
  103. .map_err(|e| {
  104. Error::HttpError(
  105. e.status().map(|status_code| status_code.as_u16()),
  106. e.to_string(),
  107. )
  108. })?
  109. .text()
  110. .await
  111. .map_err(|e| {
  112. Error::HttpError(
  113. e.status().map(|status_code| status_code.as_u16()),
  114. e.to_string(),
  115. )
  116. })?;
  117. serde_json::from_str::<R>(&response).map_err(|err| {
  118. tracing::warn!("Http Response error: {}", err);
  119. match ErrorResponse::from_json(&response) {
  120. Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
  121. Err(err) => err.into(),
  122. }
  123. })
  124. }
  125. async fn http_post<P, R>(
  126. &self,
  127. url: Url,
  128. auth_token: Option<AuthToken>,
  129. payload: &P,
  130. ) -> Result<R, Error>
  131. where
  132. P: Serialize + ?Sized + Send + Sync,
  133. R: DeserializeOwned,
  134. {
  135. let mut request = self.inner.post(url).json(&payload);
  136. if let Some(auth) = auth_token {
  137. request = request.header(auth.header_key(), auth.to_string());
  138. }
  139. let response = request.send().await.map_err(|e| {
  140. Error::HttpError(
  141. e.status().map(|status_code| status_code.as_u16()),
  142. e.to_string(),
  143. )
  144. })?;
  145. let response = response.text().await.map_err(|e| {
  146. Error::HttpError(
  147. e.status().map(|status_code| status_code.as_u16()),
  148. e.to_string(),
  149. )
  150. })?;
  151. serde_json::from_str::<R>(&response).map_err(|err| {
  152. tracing::warn!("Http Response error: {}", err);
  153. match ErrorResponse::from_json(&response) {
  154. Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
  155. Err(err) => err.into(),
  156. }
  157. })
  158. }
  159. }