Browse Source

Merge remote-tracking branch 'origin/main' into feature/wallet-db-transactions

Cesar Rodas 2 months ago
parent
commit
656995bfb9

+ 3 - 0
crates/cdk-common/src/database/wallet.rs

@@ -160,6 +160,9 @@ pub trait Database: Debug {
         spending_conditions: Option<Vec<SpendingConditions>>,
     ) -> Result<Vec<ProofInfo>, Self::Err>;
 
+    /// Get proofs by Y values
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, Self::Err>;
+
     /// Get balance
     async fn get_balance(
         &self,

+ 60 - 0
crates/cdk-ffi/src/database.rs

@@ -66,6 +66,9 @@ pub trait WalletDatabase: Send + Sync {
         spending_conditions: Option<Vec<SpendingConditions>>,
     ) -> Result<Vec<ProofInfo>, FfiError>;
 
+    /// Get proofs by Y values
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, FfiError>;
+
     /// Get balance efficiently using SQL aggregation
     async fn get_balance(
         &self,
@@ -434,6 +437,48 @@ impl CdkWalletDatabase for WalletDatabaseBridge {
         cdk_result
     }
 
+    async fn get_proofs_by_ys(
+        &self,
+        ys: Vec<cdk::nuts::PublicKey>,
+    ) -> Result<Vec<cdk::types::ProofInfo>, Self::Err> {
+        let ffi_ys: Vec<PublicKey> = ys.into_iter().map(Into::into).collect();
+
+        let result = self
+            .ffi_db
+            .get_proofs_by_ys(ffi_ys)
+            .await
+            .map_err(|e| cdk::cdk_database::Error::Database(e.to_string().into()))?;
+
+        // Convert back to CDK ProofInfo
+        let cdk_result: Result<Vec<cdk::types::ProofInfo>, cdk::cdk_database::Error> = result
+            .into_iter()
+            .map(|info| {
+                Ok(cdk::types::ProofInfo {
+                    proof: info.proof.try_into().map_err(|e: FfiError| {
+                        cdk::cdk_database::Error::Database(e.to_string().into())
+                    })?,
+                    y: info.y.try_into().map_err(|e: FfiError| {
+                        cdk::cdk_database::Error::Database(e.to_string().into())
+                    })?,
+                    mint_url: info.mint_url.try_into().map_err(|e: FfiError| {
+                        cdk::cdk_database::Error::Database(e.to_string().into())
+                    })?,
+                    state: info.state.into(),
+                    spending_condition: info
+                        .spending_condition
+                        .map(|sc| sc.try_into())
+                        .transpose()
+                        .map_err(|e: FfiError| {
+                            cdk::cdk_database::Error::Database(e.to_string().into())
+                        })?,
+                    unit: info.unit.into(),
+                })
+            })
+            .collect();
+
+        cdk_result
+    }
+
     async fn get_balance(
         &self,
         mint_url: Option<cdk::mint_url::MintUrl>,
@@ -887,6 +932,21 @@ where
         })
     }
 
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, FfiError> {
+        let cdk_ys: Vec<cdk::nuts::PublicKey> = ys
+            .into_iter()
+            .map(|y| y.try_into())
+            .collect::<Result<Vec<_>, FfiError>>()?;
+
+        let result = self
+            .inner
+            .get_proofs_by_ys(cdk_ys)
+            .await
+            .map_err(|e| FfiError::Database { msg: e.to_string() })?;
+
+        Ok(result.into_iter().map(Into::into).collect())
+    }
+
     async fn get_mint(&self, mint_url: MintUrl) -> Result<Option<MintInfo>, FfiError> {
         let cdk_mint_url = mint_url.try_into()?;
         let result = self

+ 24 - 0
crates/cdk-ffi/src/multi_mint_wallet.rs

@@ -508,6 +508,30 @@ impl MultiMintWallet {
         Ok(transactions.into_iter().map(Into::into).collect())
     }
 
+    /// Get proofs for a transaction by transaction ID
+    ///
+    /// This retrieves all proofs associated with a transaction. If `mint_url` is provided,
+    /// it will only check that specific mint's wallet. Otherwise, it searches across all
+    /// wallets to find which mint the transaction belongs to.
+    ///
+    /// # Arguments
+    ///
+    /// * `id` - The transaction ID
+    /// * `mint_url` - Optional mint URL to check directly, avoiding iteration over all wallets
+    pub async fn get_proofs_for_transaction(
+        &self,
+        id: TransactionId,
+        mint_url: Option<MintUrl>,
+    ) -> Result<Vec<Proof>, FfiError> {
+        let cdk_id = id.try_into()?;
+        let cdk_mint_url = mint_url.map(|url| url.try_into()).transpose()?;
+        let proofs = self
+            .inner
+            .get_proofs_for_transaction(cdk_id, cdk_mint_url)
+            .await?;
+        Ok(proofs.into_iter().map(Into::into).collect())
+    }
+
     /// Check all mint quotes and mint if paid
     pub async fn check_all_mint_quotes(
         &self,

+ 5 - 1
crates/cdk-ffi/src/postgres.rs

@@ -6,7 +6,7 @@ use cdk_postgres::PgConnectionPool;
 
 use crate::{
     CurrencyUnit, FfiError, FfiWalletSQLDatabase, Id, KeySetInfo, Keys, MeltQuote, MintInfo,
-    MintQuote, MintUrl, ProofInfo, ProofState, SpendingConditions, Transaction,
+    MintQuote, MintUrl, ProofInfo, ProofState, PublicKey, SpendingConditions, Transaction,
     TransactionDirection, TransactionId, WalletDatabase, WalletDatabaseTransactionWrapper,
 };
 
@@ -66,6 +66,10 @@ impl WalletDatabase for WalletPostgresDatabase {
         self.inner.begin_db_transaction().await
     }
 
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, FfiError> {
+        self.inner.get_proofs_by_ys(ys).await
+    }
+
     async fn get_mint(&self, mint_url: MintUrl) -> Result<Option<MintInfo>, FfiError> {
         self.inner.get_mint(mint_url).await
     }

+ 5 - 1
crates/cdk-ffi/src/sqlite.rs

@@ -6,7 +6,7 @@ use cdk_sqlite::SqliteConnectionManager;
 
 use crate::{
     CurrencyUnit, FfiError, FfiWalletSQLDatabase, Id, KeySetInfo, Keys, MeltQuote, MintInfo,
-    MintQuote, MintUrl, ProofInfo, ProofState, SpendingConditions, Transaction,
+    MintQuote, MintUrl, ProofInfo, ProofState, PublicKey, SpendingConditions, Transaction,
     TransactionDirection, TransactionId, WalletDatabase,
 };
 
@@ -124,6 +124,10 @@ impl WalletDatabase for WalletSqliteDatabase {
             .await
     }
 
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, FfiError> {
+        self.inner.get_proofs_by_ys(ys).await
+    }
+
     async fn get_balance(
         &self,
         mint_url: Option<MintUrl>,

+ 13 - 0
crates/cdk-ffi/src/wallet.rs

@@ -367,6 +367,19 @@ impl Wallet {
         Ok(transaction.map(Into::into))
     }
 
+    /// Get proofs for a transaction by transaction ID
+    ///
+    /// This retrieves all proofs associated with a transaction by looking up
+    /// the transaction's Y values and fetching the corresponding proofs.
+    pub async fn get_proofs_for_transaction(
+        &self,
+        id: TransactionId,
+    ) -> Result<Vec<Proof>, FfiError> {
+        let cdk_id = id.try_into()?;
+        let proofs = self.inner.get_proofs_for_transaction(cdk_id).await?;
+        Ok(proofs.into_iter().map(Into::into).collect())
+    }
+
     /// Revert a transaction
     pub async fn revert_transaction(&self, id: TransactionId) -> Result<(), FfiError> {
         let cdk_id = id.try_into()?;

+ 22 - 0
crates/cdk-redb/src/wallet/mod.rs

@@ -408,6 +408,28 @@ impl WalletDatabase for WalletRedbDatabase {
         Ok(proofs)
     }
 
+    #[instrument(skip(self, ys))]
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, Self::Err> {
+        if ys.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let read_txn = self.db.begin_read().map_err(Error::from)?;
+        let table = read_txn.open_table(PROOFS_TABLE).map_err(Error::from)?;
+
+        let mut proofs = Vec::new();
+
+        for y in ys {
+            if let Some(proof) = table.get(y.to_bytes().as_slice()).map_err(Error::from)? {
+                let proof_info =
+                    serde_json::from_str::<ProofInfo>(proof.value()).map_err(Error::from)?;
+                proofs.push(proof_info);
+            }
+        }
+
+        Ok(proofs)
+    }
+
     async fn get_balance(
         &self,
         mint_url: Option<MintUrl>,

+ 35 - 0
crates/cdk-sql-common/src/wallet/mod.rs

@@ -1090,6 +1090,41 @@ where
         get_proofs_inner(&*conn, mint_url, unit, state, spending_conditions, false).await
     }
 
+    #[instrument(skip(self, ys))]
+    async fn get_proofs_by_ys(&self, ys: Vec<PublicKey>) -> Result<Vec<ProofInfo>, Self::Err> {
+        if ys.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let conn = self.pool.get().map_err(|e| Error::Database(Box::new(e)))?;
+        Ok(query(
+            r#"
+            SELECT
+                amount,
+                unit,
+                keyset_id,
+                secret,
+                c,
+                witness,
+                dleq_e,
+                dleq_s,
+                dleq_r,
+                y,
+                mint_url,
+                state,
+                spending_condition
+            FROM proof
+            WHERE y IN (:ys)
+        "#,
+        )?
+        .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        .fetch_all(&*conn)
+        .await?
+        .into_iter()
+        .filter_map(|row| sql_row_to_proof_info(row).ok())
+        .collect::<Vec<_>>())
+    }
+
     async fn get_balance(
         &self,
         mint_url: Option<MintUrl>,

+ 89 - 0
crates/cdk-sqlite/src/wallet/mod.rs

@@ -190,4 +190,93 @@ mod tests {
         }
         tx.commit().await.expect("commit");
     }
+
+    #[tokio::test]
+    async fn test_get_proofs_by_ys() {
+        use cdk_common::common::ProofInfo;
+        use cdk_common::mint_url::MintUrl;
+        use cdk_common::nuts::{CurrencyUnit, Id, Proof, SecretKey};
+        use cdk_common::Amount;
+
+        // Create a temporary database
+        let path = std::env::temp_dir().to_path_buf().join(format!(
+            "cdk-test-proofs-by-ys-{}.sqlite",
+            uuid::Uuid::new_v4()
+        ));
+
+        #[cfg(feature = "sqlcipher")]
+        let db = WalletSqliteDatabase::new((path, "password".to_string()))
+            .await
+            .unwrap();
+
+        #[cfg(not(feature = "sqlcipher"))]
+        let db = WalletSqliteDatabase::new(path).await.unwrap();
+
+        // Create multiple proofs
+        let keyset_id = Id::from_str("00deadbeef123456").unwrap();
+        let mint_url = MintUrl::from_str("https://example.com").unwrap();
+
+        let mut proof_infos = vec![];
+        let mut expected_ys = vec![];
+
+        // Generate valid public keys using SecretKey
+        for _i in 0..5 {
+            let secret = Secret::generate();
+
+            // Generate a valid public key from a secret key
+            let secret_key = SecretKey::generate();
+            let c = secret_key.public_key();
+
+            let proof = Proof::new(Amount::from(64), keyset_id, secret, c);
+
+            let proof_info =
+                ProofInfo::new(proof, mint_url.clone(), State::Unspent, CurrencyUnit::Sat).unwrap();
+
+            expected_ys.push(proof_info.y);
+            proof_infos.push(proof_info);
+        }
+
+        // Store all proofs in the database
+        db.update_proofs(proof_infos.clone(), vec![]).await.unwrap();
+
+        // Test 1: Retrieve all proofs by their Y values
+        let retrieved_proofs = db.get_proofs_by_ys(expected_ys.clone()).await.unwrap();
+
+        assert_eq!(retrieved_proofs.len(), 5);
+        for retrieved_proof in &retrieved_proofs {
+            assert!(expected_ys.contains(&retrieved_proof.y));
+        }
+
+        // Test 2: Retrieve subset of proofs (first 3)
+        let subset_ys = expected_ys[0..3].to_vec();
+        let subset_proofs = db.get_proofs_by_ys(subset_ys.clone()).await.unwrap();
+
+        assert_eq!(subset_proofs.len(), 3);
+        for retrieved_proof in &subset_proofs {
+            assert!(subset_ys.contains(&retrieved_proof.y));
+        }
+
+        // Test 3: Retrieve with non-existent Y values
+        let non_existent_secret_key = SecretKey::generate();
+        let non_existent_y = non_existent_secret_key.public_key();
+        let mixed_ys = vec![expected_ys[0], non_existent_y, expected_ys[1]];
+        let mixed_proofs = db.get_proofs_by_ys(mixed_ys).await.unwrap();
+
+        // Should only return the 2 that exist
+        assert_eq!(mixed_proofs.len(), 2);
+
+        // Test 4: Empty input returns empty result
+        let empty_result = db.get_proofs_by_ys(vec![]).await.unwrap();
+        assert_eq!(empty_result.len(), 0);
+
+        // Test 5: Verify retrieved proof data matches original
+        let single_y = vec![expected_ys[2]];
+        let single_proof = db.get_proofs_by_ys(single_y).await.unwrap();
+
+        assert_eq!(single_proof.len(), 1);
+        assert_eq!(single_proof[0].y, proof_infos[2].y);
+        assert_eq!(single_proof[0].proof.amount, proof_infos[2].proof.amount);
+        assert_eq!(single_proof[0].mint_url, proof_infos[2].mint_url);
+        assert_eq!(single_proof[0].state, proof_infos[2].state);
+    }
 }

+ 1 - 2
crates/cdk/examples/p2pk.rs

@@ -27,8 +27,7 @@ async fn main() -> Result<(), Error> {
     let seed = random::<[u8; 64]>();
 
     // Define the mint URL and currency unit
-    // let mint_url = "https://fake.thesimplekid.dev";
-    let mint_url = "https://testnut.cashu.space";
+    let mint_url = "https://fake.thesimplekid.dev";
     let unit = CurrencyUnit::Sat;
     let amount = Amount::from(100);
 

+ 45 - 1
crates/cdk/src/wallet/multi_mint_wallet.rs

@@ -12,7 +12,7 @@ use anyhow::Result;
 use cdk_common::database;
 use cdk_common::database::WalletDatabase;
 use cdk_common::task::spawn;
-use cdk_common::wallet::{MeltQuote, Transaction, TransactionDirection};
+use cdk_common::wallet::{MeltQuote, Transaction, TransactionDirection, TransactionId};
 use tokio::sync::RwLock;
 use tracing::instrument;
 use zeroize::Zeroize;
@@ -563,6 +563,50 @@ impl MultiMintWallet {
         Ok(transactions)
     }
 
+    /// Get proofs for a transaction by transaction ID
+    ///
+    /// This retrieves all proofs associated with a transaction. If `mint_url` is provided,
+    /// it will only check that specific mint's wallet. Otherwise, it searches across all
+    /// wallets to find which mint the transaction belongs to.
+    ///
+    /// # Arguments
+    ///
+    /// * `id` - The transaction ID
+    /// * `mint_url` - Optional mint URL to check directly, avoiding iteration over all wallets
+    #[instrument(skip(self))]
+    pub async fn get_proofs_for_transaction(
+        &self,
+        id: TransactionId,
+        mint_url: Option<MintUrl>,
+    ) -> Result<Proofs, Error> {
+        let wallets = self.wallets.read().await;
+
+        // If mint_url is provided, try that wallet directly
+        if let Some(mint_url) = mint_url {
+            if let Some(wallet) = wallets.get(&mint_url) {
+                // Verify the transaction exists in this wallet
+                if wallet.get_transaction(id).await?.is_some() {
+                    return wallet.get_proofs_for_transaction(id).await;
+                }
+            }
+            // Transaction not found in specified mint
+            return Err(Error::TransactionNotFound);
+        }
+
+        // No mint_url provided, search across all wallets
+        for (mint_url, wallet) in wallets.iter() {
+            if let Some(transaction) = wallet.get_transaction(id).await? {
+                // Verify the transaction belongs to this wallet's mint
+                if &transaction.mint_url == mint_url {
+                    return wallet.get_proofs_for_transaction(id).await;
+                }
+            }
+        }
+
+        // Transaction not found in any wallet
+        Err(Error::TransactionNotFound)
+    }
+
     /// Get total balance across all wallets (since all wallets use the same currency unit)
     #[instrument(skip(self))]
     pub async fn total_balance(&self) -> Result<Amount, Error> {

+ 23 - 0
crates/cdk/src/wallet/transactions.rs

@@ -1,4 +1,5 @@
 use cdk_common::wallet::{Transaction, TransactionDirection, TransactionId};
+use cdk_common::Proofs;
 
 use crate::{Error, Wallet};
 
@@ -29,6 +30,28 @@ impl Wallet {
         Ok(transaction)
     }
 
+    /// Get proofs for a transaction by transaction ID
+    ///
+    /// This retrieves all proofs associated with a transaction by looking up
+    /// the transaction's Y values and fetching the corresponding proofs.
+    pub async fn get_proofs_for_transaction(&self, id: TransactionId) -> Result<Proofs, Error> {
+        let transaction = self
+            .localstore
+            .get_transaction(id)
+            .await?
+            .ok_or(Error::TransactionNotFound)?;
+
+        let proofs = self
+            .localstore
+            .get_proofs_by_ys(transaction.ys)
+            .await?
+            .into_iter()
+            .map(|p| p.proof)
+            .collect();
+
+        Ok(proofs)
+    }
+
     /// Revert a transaction
     pub async fn revert_transaction(&self, id: TransactionId) -> Result<(), Error> {
         let tx = self