Browse Source

feat: add cancel to wait invoice

thesimplekid 5 months ago
parent
commit
7865f3dc17

+ 1 - 0
crates/cdk-cln/Cargo.toml

@@ -16,6 +16,7 @@ cdk = { path = "../cdk", version = "0.4.0", default-features = false, features =
 cln-rpc = "0.2.0"
 futures = { version = "0.3.28", default-features = false }
 tokio = { version = "1", default-features = false }
+tokio-util = { version = "0.7.11", default-features = false }
 tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
 thiserror = "1"
 uuid = { version = "1", features = ["v4"] }

+ 141 - 24
crates/cdk-cln/src/lib.rs

@@ -6,6 +6,7 @@
 use std::path::PathBuf;
 use std::pin::Pin;
 use std::str::FromStr;
+use std::sync::atomic::{AtomicBool, Ordering};
 use std::sync::Arc;
 use std::time::Duration;
 
@@ -25,13 +26,15 @@ use cln_rpc::model::requests::{
     InvoiceRequest, ListinvoicesRequest, ListpaysRequest, PayRequest, WaitanyinvoiceRequest,
 };
 use cln_rpc::model::responses::{
-    ListinvoicesInvoicesStatus, ListpaysPaysStatus, PayStatus, WaitanyinvoiceResponse,
+    ListinvoicesInvoices, ListinvoicesInvoicesStatus, ListpaysPaysStatus, PayStatus,
+    WaitanyinvoiceResponse, WaitanyinvoiceStatus,
 };
 use cln_rpc::model::Request;
 use cln_rpc::primitives::{Amount as CLN_Amount, AmountOrAny};
 use error::Error;
 use futures::{Stream, StreamExt};
 use tokio::sync::Mutex;
+use tokio_util::sync::CancellationToken;
 use uuid::Uuid;
 
 pub mod error;
@@ -44,6 +47,8 @@ pub struct Cln {
     fee_reserve: FeeReserve,
     mint_settings: MintMethodSettings,
     melt_settings: MeltMethodSettings,
+    wait_invoice_cancel_token: CancellationToken,
+    wait_invoice_is_active: Arc<AtomicBool>,
 }
 
 impl Cln {
@@ -62,6 +67,8 @@ impl Cln {
             fee_reserve,
             mint_settings,
             melt_settings,
+            wait_invoice_cancel_token: CancellationToken::new(),
+            wait_invoice_is_active: Arc::new(AtomicBool::new(false)),
         })
     }
 }
@@ -80,43 +87,123 @@ impl MintLightning for Cln {
         }
     }
 
+    /// Is wait invoice active
+    fn is_wait_invoice_active(&self) -> bool {
+        self.wait_invoice_is_active.load(Ordering::SeqCst)
+    }
+
+    /// Cancel wait invoice
+    fn cancel_wait_invoice(&self) {
+        self.wait_invoice_cancel_token.cancel()
+    }
+
+    #[allow(clippy::incompatible_msrv)]
+    // Clippy thinks select is not stable but it compiles fine on MSRV (1.63.0)
     async fn wait_any_invoice(
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err> {
         let last_pay_index = self.get_last_pay_index().await?;
         let cln_client = cln_rpc::ClnRpc::new(&self.rpc_socket).await?;
 
-        Ok(futures::stream::unfold(
-            (cln_client, last_pay_index),
-            |(mut cln_client, mut last_pay_idx)| async move {
+        let stream = futures::stream::unfold(
+            (
+                cln_client,
+                last_pay_index,
+                self.wait_invoice_cancel_token.clone(),
+                Arc::clone(&self.wait_invoice_is_active),
+            ),
+            |(mut cln_client, mut last_pay_idx, cancel_token, is_active)| async move {
+                // Set the stream as active
+                is_active.store(true, Ordering::SeqCst);
+
                 loop {
-                    let invoice_res = cln_client
-                        .call(cln_rpc::Request::WaitAnyInvoice(WaitanyinvoiceRequest {
+                    tokio::select! {
+                        _ = cancel_token.cancelled() => {
+                            // Set the stream as inactive
+                            is_active.store(false, Ordering::SeqCst);
+                            // End the stream
+                            return None;
+                        }
+                        result = cln_client.call(cln_rpc::Request::WaitAnyInvoice(WaitanyinvoiceRequest {
                             timeout: None,
                             lastpay_index: last_pay_idx,
-                        }))
-                        .await;
-
-                    let invoice: WaitanyinvoiceResponse = match invoice_res {
-                        Ok(invoice) => invoice,
-                        Err(e) => {
-                            tracing::warn!("Error fetching invoice: {e}");
-                            // Let's not spam CLN with requests on failure
-                            tokio::time::sleep(Duration::from_secs(1)).await;
-                            // Retry same request
-                            continue;
+                        })) => {
+                            match result {
+                                Ok(invoice) => {
+
+                                        // Try to convert the invoice to WaitanyinvoiceResponse
+                            let wait_any_response_result: Result<WaitanyinvoiceResponse, _> =
+                                invoice.try_into();
+
+                            let wait_any_response = match wait_any_response_result {
+                                Ok(response) => response,
+                                Err(e) => {
+                                    tracing::warn!(
+                                        "Failed to parse WaitAnyInvoice response: {:?}",
+                                        e
+                                    );
+                                    // Continue to the next iteration without panicking
+                                    continue;
+                                }
+                            };
+
+                            // Check the status of the invoice
+                            // We only want to yield invoices that have been paid
+                            match wait_any_response.status {
+                                WaitanyinvoiceStatus::PAID => (),
+                                WaitanyinvoiceStatus::EXPIRED => continue,
+                            }
+
+                            last_pay_idx = wait_any_response.pay_index;
+
+                            let payment_hash = wait_any_response.payment_hash.to_string();
+
+                            let request_look_up = match wait_any_response.bolt12 {
+                                // If it is a bolt12 payment we need to get the offer_id as this is what we use as the request look up.
+                                // Since this is not returned in the wait any response,
+                                // we need to do a second query for it.
+                                Some(_) => {
+                                    match fetch_invoice_by_payment_hash(
+                                        &mut cln_client,
+                                        &payment_hash,
+                                    )
+                                    .await
+                                    {
+                                        Ok(Some(invoice)) => {
+                                            if let Some(local_offer_id) = invoice.local_offer_id {
+                                                local_offer_id.to_string()
+                                            } else {
+                                                continue;
+                                            }
+                                        }
+                                        Ok(None) => continue,
+                                        Err(e) => {
+                                            tracing::warn!(
+                                                "Error fetching invoice by payment hash: {e}"
+                                            );
+                                            continue;
+                                        }
+                                    }
+                                }
+                                None => payment_hash,
+                            };
+
+                            return Some((request_look_up, (cln_client, last_pay_idx, cancel_token, is_active)));
+                                }
+                                Err(e) => {
+                                    tracing::warn!("Error fetching invoice: {e}");
+                                    tokio::time::sleep(Duration::from_secs(1)).await;
+                                    continue;
+                                }
+                            }
                         }
                     }
-                    .try_into()
-                    .expect("Wrong response from CLN");
-
-                    last_pay_idx = invoice.pay_index;
-
-                    break Some((invoice.payment_hash.to_string(), (cln_client, last_pay_idx)));
                 }
             },
         )
-        .boxed())
+        .boxed();
+
+        Ok(stream)
     }
 
     async fn get_payment_quote(
@@ -425,3 +512,33 @@ fn cln_pays_status_to_mint_state(status: ListpaysPaysStatus) -> MeltQuoteState {
         ListpaysPaysStatus::FAILED => MeltQuoteState::Failed,
     }
 }
+
+async fn fetch_invoice_by_payment_hash(
+    cln_client: &mut cln_rpc::ClnRpc,
+    payment_hash: &str,
+) -> Result<Option<ListinvoicesInvoices>, Error> {
+    match cln_client
+        .call(cln_rpc::Request::ListInvoices(ListinvoicesRequest {
+            payment_hash: Some(payment_hash.to_string()),
+            index: None,
+            invstring: None,
+            label: None,
+            limit: None,
+            offer_id: None,
+            start: None,
+        }))
+        .await
+    {
+        Ok(cln_rpc::Response::ListInvoices(invoice_response)) => {
+            Ok(invoice_response.invoices.first().cloned())
+        }
+        Ok(_) => {
+            tracing::warn!("CLN returned an unexpected response type");
+            Err(Error::WrongClnResponse)
+        }
+        Err(e) => {
+            tracing::warn!("Error fetching invoice: {e}");
+            Err(Error::from(e))
+        }
+    }
+}

+ 8 - 0
crates/cdk-fake-wallet/src/lib.rs

@@ -112,6 +112,14 @@ impl MintLightning for FakeWallet {
         }
     }
 
+    fn is_wait_invoice_active(&self) -> bool {
+        todo!()
+    }
+
+    fn cancel_wait_invoice(&self) {
+        todo!()
+    }
+
     async fn wait_any_invoice(
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err> {

+ 8 - 0
crates/cdk-lnbits/src/lib.rs

@@ -80,6 +80,14 @@ impl MintLightning for LNbits {
         }
     }
 
+    fn is_wait_invoice_active(&self) -> bool {
+        todo!()
+    }
+
+    fn cancel_wait_invoice(&self) {
+        todo!()
+    }
+
     async fn wait_any_invoice(
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err> {

+ 8 - 0
crates/cdk-lnd/src/lib.rs

@@ -88,6 +88,14 @@ impl MintLightning for Lnd {
         }
     }
 
+    fn is_wait_invoice_active(&self) -> bool {
+        todo!()
+    }
+
+    fn cancel_wait_invoice(&self) {
+        todo!()
+    }
+
     async fn wait_any_invoice(
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err> {

+ 8 - 0
crates/cdk-phoenixd/src/lib.rs

@@ -86,6 +86,14 @@ impl MintLightning for Phoenixd {
         }
     }
 
+    fn is_wait_invoice_active(&self) -> bool {
+        todo!()
+    }
+
+    fn cancel_wait_invoice(&self) {
+        todo!()
+    }
+
     async fn wait_any_invoice(
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err> {

+ 8 - 0
crates/cdk-strike/src/lib.rs

@@ -78,6 +78,14 @@ impl MintLightning for Strike {
         }
     }
 
+    fn is_wait_invoice_active(&self) -> bool {
+        todo!()
+    }
+
+    fn cancel_wait_invoice(&self) {
+        todo!()
+    }
+
     async fn wait_any_invoice(
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err> {

+ 6 - 0
crates/cdk/src/cdk_lightning/mod.rs

@@ -85,6 +85,12 @@ pub trait MintLightning {
         &self,
     ) -> Result<Pin<Box<dyn Stream<Item = String> + Send>>, Self::Err>;
 
+    /// Is wait invoice active
+    fn is_wait_invoice_active(&self) -> bool;
+
+    /// Cancel wait invoice
+    fn cancel_wait_invoice(&self);
+
     /// Check the status of an incoming payment
     async fn check_incoming_invoice_status(
         &self,

+ 8 - 5
crates/cdk/src/mint/mod.rs

@@ -190,15 +190,17 @@ impl Mint {
         let mut join_set = JoinSet::new();
 
         for (key, ln) in self.ln.iter() {
-            let mint = Arc::clone(&mint_arc);
-            let ln = Arc::clone(ln);
-            let shutdown = Arc::clone(&shutdown);
-            let key = *key;
-            join_set.spawn(async move {
+            if !ln.is_wait_invoice_active() {
+                let mint = Arc::clone(&mint_arc);
+                let ln = Arc::clone(ln);
+                let shutdown = Arc::clone(&shutdown);
+                let key = *key;
+                join_set.spawn(async move {
             loop {
                 tokio::select! {
                     _ = shutdown.notified() => {
                         tracing::info!("Shutdown signal received, stopping task for {:?}", key);
+                        ln.cancel_wait_invoice();
                         break;
                     }
                     result = ln.wait_any_invoice() => {
@@ -219,6 +221,7 @@ impl Mint {
                 }
             }
         });
+            }
         }
 
         // Spawn a task to manage the JoinSet