| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- //! HTTP Transport trait with a default implementation
- use std::fmt::Debug;
- use cdk_common::AuthToken;
- #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
- use hickory_resolver::config::ResolverConfig;
- #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
- use hickory_resolver::name_server::TokioConnectionProvider;
- #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
- use hickory_resolver::Resolver;
- use reqwest::Client;
- use serde::de::DeserializeOwned;
- use serde::Serialize;
- use url::Url;
- use super::Error;
- use crate::error::ErrorResponse;
- /// Expected HTTP Transport
- #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
- #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
- pub trait Transport: Default + Send + Sync + Debug + Clone {
- #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
- /// DNS resolver to get a TXT record from a domain name
- async fn resolve_dns_txt(&self, _domain: &str) -> Result<Vec<String>, Error>;
- /// Make the transport to use a given proxy
- fn with_proxy(
- &mut self,
- proxy: Url,
- host_matcher: Option<&str>,
- accept_invalid_certs: bool,
- ) -> Result<(), Error>;
- /// HTTP Get request
- async fn http_get<R>(&self, url: Url, auth: Option<AuthToken>) -> Result<R, Error>
- where
- R: DeserializeOwned;
- /// HTTP Post request
- async fn http_post<P, R>(
- &self,
- url: Url,
- auth_token: Option<AuthToken>,
- payload: &P,
- ) -> Result<R, Error>
- where
- P: Serialize + ?Sized + Send + Sync,
- R: DeserializeOwned;
- }
- /// Async transport for Http
- #[derive(Debug, Clone)]
- pub struct Async {
- inner: Client,
- }
- impl Default for Async {
- fn default() -> Self {
- #[cfg(not(target_arch = "wasm32"))]
- if rustls::crypto::CryptoProvider::get_default().is_none() {
- let _ = rustls::crypto::ring::default_provider().install_default();
- }
- Self {
- inner: Client::new(),
- }
- }
- }
- #[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
- #[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
- impl Transport for Async {
- #[cfg(target_arch = "wasm32")]
- fn with_proxy(
- &mut self,
- _proxy: Url,
- _host_matcher: Option<&str>,
- _accept_invalid_certs: bool,
- ) -> Result<(), Error> {
- panic!("Not supported in wasm");
- }
- #[cfg(not(target_arch = "wasm32"))]
- fn with_proxy(
- &mut self,
- proxy: Url,
- host_matcher: Option<&str>,
- accept_invalid_certs: bool,
- ) -> Result<(), Error> {
- let builder = reqwest::Client::builder().danger_accept_invalid_certs(accept_invalid_certs);
- let builder = match host_matcher {
- Some(pattern) => {
- // When a matcher is provided, only apply the proxy to matched hosts
- let regex = regex::Regex::new(pattern).map_err(|e| Error::Custom(e.to_string()))?;
- builder.proxy(reqwest::Proxy::custom(move |url| {
- url.host_str()
- .filter(|host| regex.is_match(host))
- .map(|_| proxy.clone())
- }))
- }
- // Apply proxy to all requests when no matcher is provided
- None => {
- builder.proxy(reqwest::Proxy::all(proxy).map_err(|e| Error::Custom(e.to_string()))?)
- }
- };
- self.inner = builder
- .build()
- .map_err(|e| Error::HttpError(e.status().map(|s| s.as_u16()), e.to_string()))?;
- Ok(())
- }
- /// DNS resolver to get a TXT record from a domain name
- #[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
- async fn resolve_dns_txt(&self, domain: &str) -> Result<Vec<String>, Error> {
- let resolver = Resolver::builder_with_config(
- ResolverConfig::default(),
- TokioConnectionProvider::default(),
- )
- .build();
- Ok(resolver
- .txt_lookup(domain)
- .await
- .map_err(|e| Error::Custom(e.to_string()))?
- .into_iter()
- .map(|txt| {
- txt.txt_data()
- .iter()
- .map(|bytes| String::from_utf8_lossy(bytes).into_owned())
- .collect::<Vec<_>>()
- .join("")
- })
- .collect::<Vec<_>>())
- }
- async fn http_get<R>(&self, url: Url, auth: Option<AuthToken>) -> Result<R, Error>
- where
- R: DeserializeOwned,
- {
- let mut request = self.inner.get(url);
- if let Some(auth) = auth {
- request = request.header(auth.header_key(), auth.to_string());
- }
- let response = request
- .send()
- .await
- .map_err(|e| {
- Error::HttpError(
- e.status().map(|status_code| status_code.as_u16()),
- e.to_string(),
- )
- })?
- .text()
- .await
- .map_err(|e| {
- Error::HttpError(
- e.status().map(|status_code| status_code.as_u16()),
- e.to_string(),
- )
- })?;
- serde_json::from_str::<R>(&response).map_err(|err| {
- tracing::warn!("Http Response error: {}", err);
- match ErrorResponse::from_json(&response) {
- Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
- Err(err) => err.into(),
- }
- })
- }
- async fn http_post<P, R>(
- &self,
- url: Url,
- auth_token: Option<AuthToken>,
- payload: &P,
- ) -> Result<R, Error>
- where
- P: Serialize + ?Sized + Send + Sync,
- R: DeserializeOwned,
- {
- let mut request = self.inner.post(url).json(&payload);
- if let Some(auth) = auth_token {
- request = request.header(auth.header_key(), auth.to_string());
- }
- let response = request.send().await.map_err(|e| {
- Error::HttpError(
- e.status().map(|status_code| status_code.as_u16()),
- e.to_string(),
- )
- })?;
- let response = response.text().await.map_err(|e| {
- Error::HttpError(
- e.status().map(|status_code| status_code.as_u16()),
- e.to_string(),
- )
- })?;
- serde_json::from_str::<R>(&response).map_err(|err| {
- tracing::warn!("Http Response error: {}", err);
- match ErrorResponse::from_json(&response) {
- Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
- Err(err) => err.into(),
- }
- })
- }
- }
|