client.rs 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. //! GRPC Client
  2. use std::path::Path;
  3. use std::str::FromStr;
  4. use std::sync::Arc;
  5. use cdk_common::util::hex;
  6. use hyper_rustls::HttpsConnectorBuilder;
  7. use hyper_util::client::legacy::connect::HttpConnector;
  8. use hyper_util::client::legacy::Client as HyperClient;
  9. use hyper_util::rt::TokioExecutor;
  10. use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
  11. use rustls::crypto::ring::default_provider;
  12. use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
  13. use rustls::{ClientConfig, DigitallySignedStruct, Error as TLSError, SignatureScheme};
  14. use tokio::fs;
  15. use tonic::body::Body;
  16. use tonic::codegen::InterceptedService;
  17. use tonic::metadata::MetadataValue;
  18. use tonic::service::Interceptor;
  19. use tonic::{Request, Status};
  20. use crate::{lnrpc, routerrpc, Error};
  21. /// Custom certificate verifier for LND's self-signed certificates
  22. #[derive(Debug)]
  23. pub(crate) struct LndCertVerifier {
  24. certs: Vec<Vec<u8>>,
  25. provider: Arc<rustls::crypto::CryptoProvider>,
  26. }
  27. impl LndCertVerifier {
  28. pub(crate) async fn load(path: impl AsRef<Path>) -> Result<Self, Error> {
  29. let provider = default_provider();
  30. let contents = fs::read(path).await.map_err(|_| Error::ReadFile)?;
  31. let mut reader = std::io::Cursor::new(contents);
  32. // Parse PEM certificates
  33. let certs: Vec<CertificateDer<'static>> =
  34. rustls_pemfile::certs(&mut reader).flatten().collect();
  35. Ok(LndCertVerifier {
  36. certs: certs.into_iter().map(|c| c.to_vec()).collect(),
  37. provider: Arc::new(provider),
  38. })
  39. }
  40. }
  41. impl ServerCertVerifier for LndCertVerifier {
  42. fn verify_server_cert(
  43. &self,
  44. end_entity: &CertificateDer<'_>,
  45. intermediates: &[CertificateDer<'_>],
  46. _server_name: &ServerName,
  47. _ocsp_response: &[u8],
  48. _now: UnixTime,
  49. ) -> Result<ServerCertVerified, TLSError> {
  50. let mut certs = intermediates
  51. .iter()
  52. .map(|c| c.as_ref().to_vec())
  53. .collect::<Vec<Vec<u8>>>();
  54. certs.push(end_entity.as_ref().to_vec());
  55. certs.sort();
  56. let mut our_certs = self.certs.clone();
  57. our_certs.sort();
  58. if self.certs.len() != certs.len() {
  59. return Err(TLSError::General(format!(
  60. "Mismatched number of certificates (Expected: {}, Presented: {})",
  61. self.certs.len(),
  62. certs.len()
  63. )));
  64. }
  65. for (c, p) in our_certs.iter().zip(certs.iter()) {
  66. if p != c {
  67. return Err(TLSError::General(
  68. "Server certificates do not match ours".to_string(),
  69. ));
  70. }
  71. }
  72. Ok(ServerCertVerified::assertion())
  73. }
  74. fn verify_tls12_signature(
  75. &self,
  76. message: &[u8],
  77. cert: &CertificateDer<'_>,
  78. dss: &DigitallySignedStruct,
  79. ) -> Result<HandshakeSignatureValid, TLSError> {
  80. rustls::crypto::verify_tls12_signature(
  81. message,
  82. cert,
  83. dss,
  84. &self.provider.signature_verification_algorithms,
  85. )
  86. .map(|_| HandshakeSignatureValid::assertion())
  87. }
  88. fn verify_tls13_signature(
  89. &self,
  90. message: &[u8],
  91. cert: &CertificateDer<'_>,
  92. dss: &DigitallySignedStruct,
  93. ) -> Result<HandshakeSignatureValid, TLSError> {
  94. rustls::crypto::verify_tls13_signature(
  95. message,
  96. cert,
  97. dss,
  98. &self.provider.signature_verification_algorithms,
  99. )
  100. .map(|_| HandshakeSignatureValid::assertion())
  101. }
  102. fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
  103. self.provider
  104. .signature_verification_algorithms
  105. .supported_schemes()
  106. }
  107. }
  108. pub type RouterClient = routerrpc::router_client::RouterClient<
  109. InterceptedService<
  110. HyperClient<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
  111. MacaroonInterceptor,
  112. >,
  113. >;
  114. /// The client returned by `connect` function
  115. #[derive(Clone)]
  116. pub struct Client {
  117. lightning: lnrpc::lightning_client::LightningClient<
  118. InterceptedService<
  119. HyperClient<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
  120. MacaroonInterceptor,
  121. >,
  122. >,
  123. router: RouterClient,
  124. }
  125. /// Supplies requests with macaroon
  126. #[derive(Clone)]
  127. pub struct MacaroonInterceptor {
  128. macaroon: String,
  129. }
  130. impl Interceptor for MacaroonInterceptor {
  131. fn call(&mut self, mut request: Request<()>) -> Result<Request<()>, Status> {
  132. request.metadata_mut().insert(
  133. "macaroon",
  134. MetadataValue::from_str(&self.macaroon)
  135. .map_err(|e| Status::internal(format!("Invalid macaroon: {e}")))?,
  136. );
  137. Ok(request)
  138. }
  139. }
  140. async fn load_macaroon(path: impl AsRef<Path>) -> Result<String, Error> {
  141. let macaroon = fs::read(path).await.map_err(|_| Error::ReadFile)?;
  142. Ok(hex::encode(macaroon))
  143. }
  144. pub async fn connect<P: AsRef<Path>>(
  145. address: &str,
  146. cert_path: P,
  147. macaroon_path: P,
  148. ) -> Result<Client, Error> {
  149. if rustls::crypto::CryptoProvider::get_default().is_none() {
  150. let _ = rustls::crypto::ring::default_provider().install_default();
  151. }
  152. let config = ClientConfig::builder()
  153. .dangerous()
  154. .with_custom_certificate_verifier(Arc::new(LndCertVerifier::load(cert_path).await?))
  155. .with_no_client_auth();
  156. // Create HTTPS connector
  157. let https = HttpsConnectorBuilder::new()
  158. .with_tls_config(config)
  159. .https_only()
  160. .enable_http2()
  161. .build();
  162. // Create hyper client
  163. let client = HyperClient::builder(TokioExecutor::new())
  164. .http2_only(true)
  165. .build(https);
  166. // Load macaroon
  167. let macaroon = load_macaroon(macaroon_path).await?;
  168. // Create service with macaroon interceptor
  169. let service = InterceptedService::new(client, MacaroonInterceptor { macaroon });
  170. // Create URI for the service
  171. let address = address
  172. .trim_start_matches("http://")
  173. .trim_start_matches("https://");
  174. let uri = http::Uri::from_str(&format!("https://{address}"))
  175. .map_err(|e| Error::InvalidConfig(format!("Invalid URI: {e}")))?;
  176. // Create LND client
  177. let lightning =
  178. lnrpc::lightning_client::LightningClient::with_origin(service.clone(), uri.clone());
  179. let router = RouterClient::with_origin(service, uri);
  180. Ok(Client { lightning, router })
  181. }
  182. impl Client {
  183. pub fn lightning(
  184. &mut self,
  185. ) -> &mut lnrpc::lightning_client::LightningClient<
  186. InterceptedService<
  187. HyperClient<hyper_rustls::HttpsConnector<HttpConnector>, Body>,
  188. MacaroonInterceptor,
  189. >,
  190. > {
  191. &mut self.lightning
  192. }
  193. pub fn router(&mut self) -> &mut RouterClient {
  194. &mut self.router
  195. }
  196. }