server.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. use std::net::SocketAddr;
  2. use std::path::PathBuf;
  3. use std::pin::Pin;
  4. use std::str::FromStr;
  5. use std::sync::Arc;
  6. use std::time::Duration;
  7. use cdk_common::payment::{IncomingPaymentOptions, MintPayment};
  8. use cdk_common::CurrencyUnit;
  9. use futures::{Stream, StreamExt};
  10. use lightning::offers::offer::Offer;
  11. use serde_json::Value;
  12. use tokio::sync::{mpsc, Notify};
  13. use tokio::task::JoinHandle;
  14. use tokio::time::{sleep, Instant};
  15. use tokio_stream::wrappers::ReceiverStream;
  16. use tonic::transport::{Certificate, Identity, Server, ServerTlsConfig};
  17. use tonic::{async_trait, Request, Response, Status};
  18. use tracing::instrument;
  19. use super::cdk_payment_processor_server::{CdkPaymentProcessor, CdkPaymentProcessorServer};
  20. use crate::error::Error;
  21. use crate::proto::*;
  22. type ResponseStream =
  23. Pin<Box<dyn Stream<Item = Result<WaitIncomingPaymentResponse, Status>> + Send>>;
  24. /// Payment Processor
  25. #[derive(Clone)]
  26. pub struct PaymentProcessorServer {
  27. inner: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
  28. socket_addr: SocketAddr,
  29. shutdown: Arc<Notify>,
  30. handle: Option<Arc<JoinHandle<anyhow::Result<()>>>>,
  31. }
  32. impl PaymentProcessorServer {
  33. /// Create new [`PaymentProcessorServer`]
  34. pub fn new(
  35. payment_processor: Arc<dyn MintPayment<Err = cdk_common::payment::Error> + Send + Sync>,
  36. addr: &str,
  37. port: u16,
  38. ) -> anyhow::Result<Self> {
  39. let socket_addr = SocketAddr::new(addr.parse()?, port);
  40. Ok(Self {
  41. inner: payment_processor,
  42. socket_addr,
  43. shutdown: Arc::new(Notify::new()),
  44. handle: None,
  45. })
  46. }
  47. /// Start fake wallet grpc server
  48. pub async fn start(&mut self, tls_dir: Option<PathBuf>) -> anyhow::Result<()> {
  49. tracing::info!("Starting RPC server {}", self.socket_addr);
  50. let server = match tls_dir {
  51. Some(tls_dir) => {
  52. tracing::info!("TLS configuration found, starting secure server");
  53. // Check for server.pem
  54. let server_pem_path = tls_dir.join("server.pem");
  55. if !server_pem_path.exists() {
  56. let err_msg = format!(
  57. "TLS certificate file not found: {}",
  58. server_pem_path.display()
  59. );
  60. tracing::error!("{}", err_msg);
  61. return Err(anyhow::anyhow!(err_msg));
  62. }
  63. // Check for server.key
  64. let server_key_path = tls_dir.join("server.key");
  65. if !server_key_path.exists() {
  66. let err_msg = format!("TLS key file not found: {}", server_key_path.display());
  67. tracing::error!("{}", err_msg);
  68. return Err(anyhow::anyhow!(err_msg));
  69. }
  70. // Check for ca.pem
  71. let ca_pem_path = tls_dir.join("ca.pem");
  72. if !ca_pem_path.exists() {
  73. let err_msg =
  74. format!("CA certificate file not found: {}", ca_pem_path.display());
  75. tracing::error!("{}", err_msg);
  76. return Err(anyhow::anyhow!(err_msg));
  77. }
  78. let cert = std::fs::read_to_string(&server_pem_path)?;
  79. let key = std::fs::read_to_string(&server_key_path)?;
  80. let client_ca_cert = std::fs::read_to_string(&ca_pem_path)?;
  81. let client_ca_cert = Certificate::from_pem(client_ca_cert);
  82. let server_identity = Identity::from_pem(cert, key);
  83. let tls_config = ServerTlsConfig::new()
  84. .identity(server_identity)
  85. .client_ca_root(client_ca_cert);
  86. Server::builder()
  87. .tls_config(tls_config)?
  88. .add_service(CdkPaymentProcessorServer::new(self.clone()))
  89. }
  90. None => {
  91. tracing::warn!("No valid TLS configuration found, starting insecure server");
  92. Server::builder().add_service(CdkPaymentProcessorServer::new(self.clone()))
  93. }
  94. };
  95. let shutdown = self.shutdown.clone();
  96. let addr = self.socket_addr;
  97. self.handle = Some(Arc::new(tokio::spawn(async move {
  98. let server = server.serve_with_shutdown(addr, async {
  99. shutdown.notified().await;
  100. });
  101. server.await?;
  102. Ok(())
  103. })));
  104. Ok(())
  105. }
  106. /// Stop fake wallet grpc server
  107. pub async fn stop(&self) -> anyhow::Result<()> {
  108. const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
  109. if let Some(handle) = &self.handle {
  110. tracing::info!("Initiating server shutdown");
  111. self.shutdown.notify_waiters();
  112. let start = Instant::now();
  113. while !handle.is_finished() {
  114. if start.elapsed() >= SHUTDOWN_TIMEOUT {
  115. tracing::error!(
  116. "Server shutdown timed out after {} seconds, aborting handle",
  117. SHUTDOWN_TIMEOUT.as_secs()
  118. );
  119. handle.abort();
  120. break;
  121. }
  122. sleep(Duration::from_millis(100)).await;
  123. }
  124. if handle.is_finished() {
  125. tracing::info!("Server shutdown completed successfully");
  126. }
  127. } else {
  128. tracing::info!("No server handle found, nothing to stop");
  129. }
  130. Ok(())
  131. }
  132. }
  133. impl Drop for PaymentProcessorServer {
  134. fn drop(&mut self) {
  135. tracing::debug!("Dropping payment process server");
  136. self.shutdown.notify_one();
  137. }
  138. }
  139. #[async_trait]
  140. impl CdkPaymentProcessor for PaymentProcessorServer {
  141. async fn get_settings(
  142. &self,
  143. _request: Request<EmptyRequest>,
  144. ) -> Result<Response<SettingsResponse>, Status> {
  145. let settings: Value = self
  146. .inner
  147. .get_settings()
  148. .await
  149. .map_err(|_| Status::internal("Could not get settings"))?;
  150. Ok(Response::new(SettingsResponse {
  151. inner: settings.to_string(),
  152. }))
  153. }
  154. async fn create_payment(
  155. &self,
  156. request: Request<CreatePaymentRequest>,
  157. ) -> Result<Response<CreatePaymentResponse>, Status> {
  158. let CreatePaymentRequest { unit, options } = request.into_inner();
  159. let unit = CurrencyUnit::from_str(&unit)
  160. .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
  161. let options = options.ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
  162. let proto_options = match options
  163. .options
  164. .ok_or_else(|| Status::invalid_argument("Missing options"))?
  165. {
  166. incoming_payment_options::Options::Bolt11(opts) => {
  167. IncomingPaymentOptions::Bolt11(cdk_common::payment::Bolt11IncomingPaymentOptions {
  168. description: opts.description,
  169. amount: opts.amount.into(),
  170. unix_expiry: opts.unix_expiry,
  171. })
  172. }
  173. incoming_payment_options::Options::Bolt12(opts) => IncomingPaymentOptions::Bolt12(
  174. Box::new(cdk_common::payment::Bolt12IncomingPaymentOptions {
  175. description: opts.description,
  176. amount: opts.amount.map(Into::into),
  177. unix_expiry: opts.unix_expiry,
  178. }),
  179. ),
  180. };
  181. let invoice_response = self
  182. .inner
  183. .create_incoming_payment_request(&unit, proto_options)
  184. .await
  185. .map_err(|_| Status::internal("Could not create invoice"))?;
  186. Ok(Response::new(invoice_response.into()))
  187. }
  188. async fn get_payment_quote(
  189. &self,
  190. request: Request<PaymentQuoteRequest>,
  191. ) -> Result<Response<PaymentQuoteResponse>, Status> {
  192. let request = request.into_inner();
  193. let unit = CurrencyUnit::from_str(&request.unit)
  194. .map_err(|_| Status::invalid_argument("Invalid currency unit"))?;
  195. let options = match request.request_type() {
  196. OutgoingPaymentRequestType::Bolt11Invoice => {
  197. let bolt11: cdk_common::Bolt11Invoice =
  198. request.request.parse().map_err(Error::Invoice)?;
  199. cdk_common::payment::OutgoingPaymentOptions::Bolt11(Box::new(
  200. cdk_common::payment::Bolt11OutgoingPaymentOptions {
  201. bolt11,
  202. max_fee_amount: None,
  203. timeout_secs: None,
  204. melt_options: request.options.map(Into::into),
  205. },
  206. ))
  207. }
  208. OutgoingPaymentRequestType::Bolt12Offer => {
  209. // Parse offer to verify it's valid, but store as string
  210. let _: Offer = request.request.parse().map_err(|_| Error::Bolt12Parse)?;
  211. cdk_common::payment::OutgoingPaymentOptions::Bolt12(Box::new(
  212. cdk_common::payment::Bolt12OutgoingPaymentOptions {
  213. offer: Offer::from_str(&request.request).unwrap(),
  214. max_fee_amount: None,
  215. timeout_secs: None,
  216. melt_options: request.options.map(Into::into),
  217. },
  218. ))
  219. }
  220. };
  221. let payment_quote = self
  222. .inner
  223. .get_payment_quote(&unit, options)
  224. .await
  225. .map_err(|err| {
  226. tracing::error!("Could not get payment quote: {}", err);
  227. Status::internal("Could not get quote")
  228. })?;
  229. Ok(Response::new(payment_quote.into()))
  230. }
  231. async fn make_payment(
  232. &self,
  233. request: Request<MakePaymentRequest>,
  234. ) -> Result<Response<MakePaymentResponse>, Status> {
  235. let request = request.into_inner();
  236. let options = request
  237. .payment_options
  238. .ok_or_else(|| Status::invalid_argument("Missing payment options"))?;
  239. let (unit, payment_options) = match options
  240. .options
  241. .ok_or_else(|| Status::invalid_argument("Missing options"))?
  242. {
  243. outgoing_payment_variant::Options::Bolt11(opts) => {
  244. let bolt11: cdk_common::Bolt11Invoice =
  245. opts.bolt11.parse().map_err(Error::Invoice)?;
  246. let payment_options = cdk_common::payment::OutgoingPaymentOptions::Bolt11(
  247. Box::new(cdk_common::payment::Bolt11OutgoingPaymentOptions {
  248. bolt11,
  249. max_fee_amount: opts.max_fee_amount.map(Into::into),
  250. timeout_secs: opts.timeout_secs,
  251. melt_options: opts.melt_options.map(Into::into),
  252. }),
  253. );
  254. (CurrencyUnit::Msat, payment_options)
  255. }
  256. outgoing_payment_variant::Options::Bolt12(opts) => {
  257. let offer = Offer::from_str(&opts.offer)
  258. .map_err(|_| Error::Bolt12Parse)
  259. .unwrap();
  260. let payment_options = cdk_common::payment::OutgoingPaymentOptions::Bolt12(
  261. Box::new(cdk_common::payment::Bolt12OutgoingPaymentOptions {
  262. offer,
  263. max_fee_amount: opts.max_fee_amount.map(Into::into),
  264. timeout_secs: opts.timeout_secs,
  265. melt_options: opts.melt_options.map(Into::into),
  266. }),
  267. );
  268. (CurrencyUnit::Msat, payment_options)
  269. }
  270. };
  271. let pay_response = self
  272. .inner
  273. .make_payment(&unit, payment_options)
  274. .await
  275. .map_err(|err| {
  276. tracing::error!("Could not make payment: {}", err);
  277. match err {
  278. cdk_common::payment::Error::InvoiceAlreadyPaid => {
  279. Status::already_exists("Payment request already paid")
  280. }
  281. cdk_common::payment::Error::InvoicePaymentPending => {
  282. Status::already_exists("Payment request pending")
  283. }
  284. _ => Status::internal("Could not pay invoice"),
  285. }
  286. })?;
  287. Ok(Response::new(pay_response.into()))
  288. }
  289. async fn check_incoming_payment(
  290. &self,
  291. request: Request<CheckIncomingPaymentRequest>,
  292. ) -> Result<Response<CheckIncomingPaymentResponse>, Status> {
  293. let request = request.into_inner();
  294. let payment_identifier = request
  295. .request_identifier
  296. .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
  297. .try_into()
  298. .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
  299. let check_responses = self
  300. .inner
  301. .check_incoming_payment_status(&payment_identifier)
  302. .await
  303. .map_err(|_| Status::internal("Could not check incoming payment status"))?;
  304. Ok(Response::new(CheckIncomingPaymentResponse {
  305. payments: check_responses.into_iter().map(|r| r.into()).collect(),
  306. }))
  307. }
  308. async fn check_outgoing_payment(
  309. &self,
  310. request: Request<CheckOutgoingPaymentRequest>,
  311. ) -> Result<Response<MakePaymentResponse>, Status> {
  312. let request = request.into_inner();
  313. let payment_identifier = request
  314. .request_identifier
  315. .ok_or_else(|| Status::invalid_argument("Missing request identifier"))?
  316. .try_into()
  317. .map_err(|_| Status::invalid_argument("Invalid request identifier"))?;
  318. let check_response = self
  319. .inner
  320. .check_outgoing_payment(&payment_identifier)
  321. .await
  322. .map_err(|_| Status::internal("Could not check outgoing payment status"))?;
  323. Ok(Response::new(check_response.into()))
  324. }
  325. type WaitIncomingPaymentStream = ResponseStream;
  326. #[allow(clippy::incompatible_msrv)]
  327. #[instrument(skip_all)]
  328. async fn wait_incoming_payment(
  329. &self,
  330. _request: Request<EmptyRequest>,
  331. ) -> Result<Response<Self::WaitIncomingPaymentStream>, Status> {
  332. tracing::debug!("Server waiting for payment stream");
  333. let (tx, rx) = mpsc::channel(128);
  334. let shutdown_clone = self.shutdown.clone();
  335. let ln = self.inner.clone();
  336. tokio::spawn(async move {
  337. loop {
  338. tokio::select! {
  339. _ = shutdown_clone.notified() => {
  340. tracing::info!("Shutdown signal received, stopping task");
  341. ln.cancel_wait_invoice();
  342. break;
  343. }
  344. result = ln.wait_payment_event() => {
  345. match result {
  346. Ok(mut stream) => {
  347. while let Some(event) = stream.next().await {
  348. match event {
  349. cdk_common::payment::Event::PaymentReceived(payment_response) => {
  350. match tx.send(Result::<_, Status>::Ok(payment_response.into()))
  351. .await
  352. {
  353. Ok(_) => {
  354. // Response was queued to be sent to client
  355. }
  356. Err(item) => {
  357. tracing::error!("Error adding incoming payment to stream: {}", item);
  358. break;
  359. }
  360. }
  361. }
  362. }
  363. }
  364. }
  365. Err(err) => {
  366. tracing::warn!("Could not get invoice stream: {}", err);
  367. tokio::time::sleep(std::time::Duration::from_secs(5)).await;
  368. }
  369. }
  370. }
  371. }
  372. }
  373. });
  374. let output_stream = ReceiverStream::new(rx);
  375. Ok(Response::new(
  376. Box::pin(output_stream) as Self::WaitIncomingPaymentStream
  377. ))
  378. }
  379. }