|
|
@@ -1,6 +1,8 @@
|
|
|
//! HTTP Transport trait with a default implementation
|
|
|
+use std::collections::HashMap;
|
|
|
use std::fmt::Debug;
|
|
|
|
|
|
+use bitreq::{Client, Proxy, Request, RequestExt};
|
|
|
use cdk_common::AuthToken;
|
|
|
#[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
|
|
|
use hickory_resolver::config::ResolverConfig;
|
|
|
@@ -8,7 +10,7 @@ use hickory_resolver::config::ResolverConfig;
|
|
|
use hickory_resolver::name_server::TokioConnectionProvider;
|
|
|
#[cfg(all(feature = "bip353", not(target_arch = "wasm32")))]
|
|
|
use hickory_resolver::Resolver;
|
|
|
-use reqwest::Client;
|
|
|
+use regex::Regex;
|
|
|
use serde::de::DeserializeOwned;
|
|
|
use serde::Serialize;
|
|
|
use url::Url;
|
|
|
@@ -53,10 +55,56 @@ pub trait Transport: Default + Send + Sync + Debug + Clone {
|
|
|
R: serde::de::DeserializeOwned;
|
|
|
}
|
|
|
|
|
|
-/// Async transport for Http
|
|
|
#[derive(Debug, Clone)]
|
|
|
+struct ProxyWrapper {
|
|
|
+ proxy: Proxy,
|
|
|
+ _accept_invalid_certs: bool,
|
|
|
+}
|
|
|
+
|
|
|
+/// Async transport for Http
|
|
|
+#[derive(Clone)]
|
|
|
pub struct Async {
|
|
|
- inner: Client,
|
|
|
+ client: Client,
|
|
|
+ proxy_per_url: HashMap<String, (Regex, ProxyWrapper)>,
|
|
|
+ all_proxy: Option<ProxyWrapper>,
|
|
|
+}
|
|
|
+
|
|
|
+impl Async {
|
|
|
+ fn prepare_request(&self, req: Request, url: Url, auth: Option<AuthToken>) -> Request {
|
|
|
+ let proxy = {
|
|
|
+ let url = url.to_string();
|
|
|
+ let mut proxy = None;
|
|
|
+ for (pattern, proxy_wrapper) in self.proxy_per_url.values() {
|
|
|
+ if pattern.is_match(&url) {
|
|
|
+ proxy = Some(proxy_wrapper.proxy.clone());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if proxy.is_some() {
|
|
|
+ proxy
|
|
|
+ } else {
|
|
|
+ self.all_proxy.as_ref().map(|x| x.proxy.clone())
|
|
|
+ }
|
|
|
+ };
|
|
|
+
|
|
|
+ let request = if let Some(proxy) = proxy {
|
|
|
+ req.with_proxy(proxy)
|
|
|
+ } else {
|
|
|
+ req
|
|
|
+ };
|
|
|
+
|
|
|
+ if let Some(auth) = auth {
|
|
|
+ request.with_header(auth.header_key(), auth.to_string())
|
|
|
+ } else {
|
|
|
+ request
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl Debug for Async {
|
|
|
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
|
+ write!(f, "HTTP Async client")
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
impl Default for Async {
|
|
|
@@ -67,7 +115,9 @@ impl Default for Async {
|
|
|
}
|
|
|
|
|
|
Self {
|
|
|
- inner: Client::new(),
|
|
|
+ client: Client::new(10),
|
|
|
+ proxy_per_url: HashMap::new(),
|
|
|
+ all_proxy: None,
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -92,27 +142,23 @@ impl Transport for Async {
|
|
|
host_matcher: Option<&str>,
|
|
|
accept_invalid_certs: bool,
|
|
|
) -> Result<(), Error> {
|
|
|
- let builder = reqwest::Client::builder().danger_accept_invalid_certs(accept_invalid_certs);
|
|
|
-
|
|
|
- let builder = match host_matcher {
|
|
|
- Some(pattern) => {
|
|
|
- // When a matcher is provided, only apply the proxy to matched hosts
|
|
|
- let regex = regex::Regex::new(pattern).map_err(|e| Error::Custom(e.to_string()))?;
|
|
|
- builder.proxy(reqwest::Proxy::custom(move |url| {
|
|
|
- url.host_str()
|
|
|
- .filter(|host| regex.is_match(host))
|
|
|
- .map(|_| proxy.clone())
|
|
|
- }))
|
|
|
- }
|
|
|
- // Apply proxy to all requests when no matcher is provided
|
|
|
- None => {
|
|
|
- builder.proxy(reqwest::Proxy::all(proxy).map_err(|e| Error::Custom(e.to_string()))?)
|
|
|
- }
|
|
|
+ let proxy = ProxyWrapper {
|
|
|
+ proxy: bitreq::Proxy::new_http(proxy).map_err(|_| Error::Internal)?,
|
|
|
+ _accept_invalid_certs: accept_invalid_certs,
|
|
|
};
|
|
|
+ if let Some((key, pattern)) = host_matcher
|
|
|
+ .map(|pattern| {
|
|
|
+ regex::Regex::new(pattern)
|
|
|
+ .map(|regex| (pattern.to_owned(), regex))
|
|
|
+ .map_err(|e| Error::Custom(e.to_string()))
|
|
|
+ })
|
|
|
+ .transpose()?
|
|
|
+ {
|
|
|
+ self.proxy_per_url.insert(key, (pattern, proxy));
|
|
|
+ } else {
|
|
|
+ self.all_proxy = Some(proxy);
|
|
|
+ }
|
|
|
|
|
|
- self.inner = builder
|
|
|
- .build()
|
|
|
- .map_err(|e| Error::HttpError(e.status().map(|s| s.as_u16()), e.to_string()))?;
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
@@ -144,33 +190,22 @@ impl Transport for Async {
|
|
|
where
|
|
|
R: DeserializeOwned,
|
|
|
{
|
|
|
- let mut request = self.inner.get(url);
|
|
|
+ let response = self
|
|
|
+ .prepare_request(bitreq::get(url.clone()), url, auth)
|
|
|
+ .send_async_with_client(&self.client)
|
|
|
+ .await
|
|
|
+ .map_err(|e| Error::HttpError(None, e.to_string()))?;
|
|
|
|
|
|
- if let Some(auth) = auth {
|
|
|
- request = request.header(auth.header_key(), auth.to_string());
|
|
|
+ if response.status_code != 200 {
|
|
|
+ return Err(Error::HttpError(
|
|
|
+ Some(response.status_code as u16),
|
|
|
+ "".to_string(),
|
|
|
+ ));
|
|
|
}
|
|
|
|
|
|
- let response = request
|
|
|
- .send()
|
|
|
- .await
|
|
|
- .map_err(|e| {
|
|
|
- Error::HttpError(
|
|
|
- e.status().map(|status_code| status_code.as_u16()),
|
|
|
- e.to_string(),
|
|
|
- )
|
|
|
- })?
|
|
|
- .text()
|
|
|
- .await
|
|
|
- .map_err(|e| {
|
|
|
- Error::HttpError(
|
|
|
- e.status().map(|status_code| status_code.as_u16()),
|
|
|
- e.to_string(),
|
|
|
- )
|
|
|
- })?;
|
|
|
-
|
|
|
- serde_json::from_str::<R>(&response).map_err(|err| {
|
|
|
+ serde_json::from_slice::<R>(response.as_bytes()).map_err(|err| {
|
|
|
tracing::warn!("Http Response error: {}", err);
|
|
|
- match ErrorResponse::from_json(&response) {
|
|
|
+ match ErrorResponse::from_slice(response.as_bytes()) {
|
|
|
Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
|
|
|
Err(err) => err.into(),
|
|
|
}
|
|
|
@@ -187,30 +222,27 @@ impl Transport for Async {
|
|
|
P: Serialize + ?Sized + Send + Sync,
|
|
|
R: DeserializeOwned,
|
|
|
{
|
|
|
- let mut request = self.inner.post(url).json(&payload);
|
|
|
-
|
|
|
- if let Some(auth) = auth_token {
|
|
|
- request = request.header(auth.header_key(), auth.to_string());
|
|
|
- }
|
|
|
-
|
|
|
- let response = request.send().await.map_err(|e| {
|
|
|
- Error::HttpError(
|
|
|
- e.status().map(|status_code| status_code.as_u16()),
|
|
|
- e.to_string(),
|
|
|
+ let response = self
|
|
|
+ .prepare_request(bitreq::post(url.clone()), url, auth_token)
|
|
|
+ .with_body(serde_json::to_string(payload).map_err(Error::SerdeJsonError)?)
|
|
|
+ .with_header(
|
|
|
+ "Content-Type".to_string(),
|
|
|
+ "application/json; charset=UTF-8".to_string(),
|
|
|
)
|
|
|
- })?;
|
|
|
+ .send_async_with_client(&self.client)
|
|
|
+ .await
|
|
|
+ .map_err(|e| Error::HttpError(None, e.to_string()))?;
|
|
|
|
|
|
- let response = response.text().await.map_err(|e| {
|
|
|
- Error::HttpError(
|
|
|
- e.status().map(|status_code| status_code.as_u16()),
|
|
|
- e.to_string(),
|
|
|
- )
|
|
|
- })?;
|
|
|
+ if response.status_code != 200 {
|
|
|
+ return Err(Error::HttpError(
|
|
|
+ Some(response.status_code as u16),
|
|
|
+ "".to_string(),
|
|
|
+ ));
|
|
|
+ }
|
|
|
|
|
|
- serde_json::from_str::<R>(&response).map_err(|err| {
|
|
|
+ serde_json::from_slice::<R>(response.as_bytes()).map_err(|err| {
|
|
|
tracing::warn!("Http Response error: {}", err);
|
|
|
- tracing::debug!("{:?}", response);
|
|
|
- match ErrorResponse::from_json(&response) {
|
|
|
+ match ErrorResponse::from_slice(response.as_bytes()) {
|
|
|
Ok(ok) => <ErrorResponse as Into<Error>>::into(ok),
|
|
|
Err(err) => err.into(),
|
|
|
}
|