Browse Source

refactor: get spent and pending proofs by list of secrets or ys

thesimplekid 9 months ago
parent
commit
ed007c475e

+ 61 - 20
crates/cdk-redb/src/mint/mod.rs

@@ -518,32 +518,54 @@ impl MintDatabase for MintRedbDatabase {
         Ok(())
     }
 
-    async fn get_spent_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err> {
+    async fn get_spent_proofs_by_ys(
+        &self,
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
         let db = self.db.lock().await;
         let read_txn = db.begin_read().map_err(Error::from)?;
         let table = read_txn
             .open_table(SPENT_PROOFS_TABLE)
             .map_err(Error::from)?;
 
-        match table.get(y.to_bytes()).map_err(Error::from)? {
-            Some(proof) => Ok(serde_json::from_str(proof.value()).map_err(Error::from)?),
-            None => Ok(None),
+        let mut proofs = Vec::with_capacity(ys.len());
+
+        for y in ys {
+            match table.get(y.to_bytes()).map_err(Error::from)? {
+                Some(proof) => proofs.push(Some(
+                    serde_json::from_str(proof.value()).map_err(Error::from)?,
+                )),
+                None => proofs.push(None),
+            }
         }
+
+        Ok(proofs)
     }
 
-    async fn get_spent_proof_by_secret(&self, secret: &Secret) -> Result<Option<Proof>, Self::Err> {
+    async fn get_spent_proofs_by_secrets(
+        &self,
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
         let db = self.db.lock().await;
         let read_txn = db.begin_read().map_err(Error::from)?;
         let table = read_txn
             .open_table(SPENT_PROOFS_TABLE)
             .map_err(Error::from)?;
 
-        let y: PublicKey = hash_to_curve(&secret.to_bytes())?;
+        let mut proofs = Vec::with_capacity(secrets.len());
 
-        match table.get(y.to_bytes()).map_err(Error::from)? {
-            Some(proof) => Ok(serde_json::from_str(proof.value()).map_err(Error::from)?),
-            None => Ok(None),
+        for secret in secrets {
+            let y: PublicKey = hash_to_curve(&secret.to_bytes())?;
+
+            match table.get(y.to_bytes()).map_err(Error::from)? {
+                Some(proof) => proofs.push(Some(
+                    serde_json::from_str(proof.value()).map_err(Error::from)?,
+                )),
+                None => proofs.push(None),
+            }
         }
+
+        Ok(proofs)
     }
 
     async fn add_pending_proofs(&self, proofs: Vec<Proof>) -> Result<(), Self::Err> {
@@ -569,35 +591,54 @@ impl MintDatabase for MintRedbDatabase {
         Ok(())
     }
 
-    async fn get_pending_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err> {
+    async fn get_pending_proofs_by_ys(
+        &self,
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
         let db = self.db.lock().await;
         let read_txn = db.begin_read().map_err(Error::from)?;
         let table = read_txn
             .open_table(PENDING_PROOFS_TABLE)
             .map_err(Error::from)?;
 
-        match table.get(y.to_bytes()).map_err(Error::from)? {
-            Some(proof) => Ok(serde_json::from_str(proof.value()).map_err(Error::from)?),
-            None => Ok(None),
+        let mut proofs = Vec::with_capacity(ys.len());
+
+        for y in ys {
+            match table.get(y.to_bytes()).map_err(Error::from)? {
+                Some(proof) => proofs.push(Some(
+                    serde_json::from_str(proof.value()).map_err(Error::from)?,
+                )),
+                None => proofs.push(None),
+            }
         }
+
+        Ok(proofs)
     }
 
-    async fn get_pending_proof_by_secret(
+    async fn get_pending_proofs_by_secrets(
         &self,
-        secret: &Secret,
-    ) -> Result<Option<Proof>, Self::Err> {
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
         let db = self.db.lock().await;
         let read_txn = db.begin_read().map_err(Error::from)?;
         let table = read_txn
             .open_table(PENDING_PROOFS_TABLE)
             .map_err(Error::from)?;
 
-        let secret_hash = hash_to_curve(&secret.to_bytes())?;
+        let mut proofs = Vec::with_capacity(secrets.len());
 
-        match table.get(secret_hash.to_bytes()).map_err(Error::from)? {
-            Some(proof) => Ok(serde_json::from_str(proof.value()).map_err(Error::from)?),
-            None => Ok(None),
+        for secret in secrets {
+            let y: PublicKey = hash_to_curve(&secret.to_bytes())?;
+
+            match table.get(y.to_bytes()).map_err(Error::from)? {
+                Some(proof) => proofs.push(Some(
+                    serde_json::from_str(proof.value()).map_err(Error::from)?,
+                )),
+                None => proofs.push(None),
+            }
         }
+
+        Ok(proofs)
     }
 
     async fn remove_pending_proofs(&self, secrets: Vec<&Secret>) -> Result<(), Self::Err> {

+ 106 - 63
crates/cdk-sqlite/src/mint/mod.rs

@@ -506,51 +506,75 @@ VALUES (?, ?, ?, ?, ?, ?, ?);
         transaction.commit().await.map_err(Error::from)?;
         Ok(())
     }
-    async fn get_spent_proof_by_secret(&self, secret: &Secret) -> Result<Option<Proof>, Self::Err> {
-        let rec = sqlx::query(
-            r#"
+    async fn get_spent_proofs_by_secrets(
+        &self,
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
+
+        let mut proofs = Vec::with_capacity(secrets.len());
+
+        for secret in secrets {
+            let rec = sqlx::query(
+                r#"
 SELECT *
 FROM proof
 WHERE secret=?
 AND state="SPENT";
         "#,
-        )
-        .bind(secret.to_string())
-        .fetch_one(&self.pool)
-        .await;
+            )
+            .bind(secret.to_string())
+            .fetch_one(&mut transaction)
+            .await;
 
-        let rec = match rec {
-            Ok(rec) => rec,
-            Err(err) => match err {
-                sqlx::Error::RowNotFound => return Ok(None),
-                _ => return Err(Error::SQLX(err).into()),
-            },
-        };
+            match rec {
+                Ok(rec) => {
+                    proofs.push(Some(sqlite_row_to_proof(rec)?));
+                }
+                Err(err) => match err {
+                    sqlx::Error::RowNotFound => proofs.push(None),
+                    _ => return Err(Error::SQLX(err).into()),
+                },
+            };
+        }
+        transaction.commit().await.map_err(Error::from)?;
 
-        Ok(Some(sqlite_row_to_proof(rec)?))
+        Ok(proofs)
     }
-    async fn get_spent_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err> {
-        let rec = sqlx::query(
-            r#"
+    async fn get_spent_proofs_by_ys(
+        &self,
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
+
+        let mut proofs = Vec::with_capacity(ys.len());
+        for y in ys {
+            let rec = sqlx::query(
+                r#"
 SELECT *
 FROM proof
 WHERE y=?
 AND state="SPENT";
         "#,
-        )
-        .bind(y.to_bytes().to_vec())
-        .fetch_one(&self.pool)
-        .await;
+            )
+            .bind(y.to_bytes().to_vec())
+            .fetch_one(&mut transaction)
+            .await;
 
-        let rec = match rec {
-            Ok(rec) => rec,
-            Err(err) => match err {
-                sqlx::Error::RowNotFound => return Ok(None),
-                _ => return Err(Error::SQLX(err).into()),
-            },
-        };
+            match rec {
+                Ok(rec) => {
+                    proofs.push(Some(sqlite_row_to_proof(rec)?));
+                }
+                Err(err) => match err {
+                    sqlx::Error::RowNotFound => proofs.push(None),
+                    _ => return Err(Error::SQLX(err).into()),
+                },
+            };
+        }
+
+        transaction.commit().await.map_err(Error::from)?;
 
-        Ok(Some(sqlite_row_to_proof(rec)?))
+        Ok(proofs)
     }
 
     async fn add_pending_proofs(&self, proofs: Proofs) -> Result<(), Self::Err> {
@@ -578,53 +602,72 @@ VALUES (?, ?, ?, ?, ?, ?, ?);
 
         Ok(())
     }
-    async fn get_pending_proof_by_secret(
+    async fn get_pending_proofs_by_secrets(
         &self,
-        secret: &Secret,
-    ) -> Result<Option<Proof>, Self::Err> {
-        let rec = sqlx::query(
-            r#"
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
+
+        let mut proofs = Vec::with_capacity(secrets.len());
+
+        for secret in secrets {
+            let rec = sqlx::query(
+                r#"
 SELECT *
 FROM proof
 WHERE secret=?
 AND state="PENDING";
         "#,
-        )
-        .bind(secret.to_string())
-        .fetch_one(&self.pool)
-        .await;
+            )
+            .bind(secret.to_string())
+            .fetch_one(&mut transaction)
+            .await;
+            match rec {
+                Ok(rec) => {
+                    proofs.push(Some(sqlite_row_to_proof(rec)?));
+                }
+                Err(err) => match err {
+                    sqlx::Error::RowNotFound => proofs.push(None),
+                    _ => return Err(Error::SQLX(err).into()),
+                },
+            };
+        }
+        transaction.commit().await.map_err(Error::from)?;
+        Ok(proofs)
+    }
+    async fn get_pending_proofs_by_ys(
+        &self,
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 
-        let rec = match rec {
-            Ok(rec) => rec,
-            Err(err) => match err {
-                sqlx::Error::RowNotFound => return Ok(None),
-                _ => return Err(Error::SQLX(err).into()),
-            },
-        };
+        let mut proofs = Vec::with_capacity(ys.len());
 
-        Ok(Some(sqlite_row_to_proof(rec)?))
-    }
-    async fn get_pending_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err> {
-        let rec = sqlx::query(
-            r#"
+        for y in ys {
+            let rec = sqlx::query(
+                r#"
 SELECT *
 FROM proof
 WHERE y=?
 AND state="PENDING";
         "#,
-        )
-        .bind(y.to_bytes().to_vec())
-        .fetch_one(&self.pool)
-        .await;
+            )
+            .bind(y.to_bytes().to_vec())
+            .fetch_one(&mut transaction)
+            .await;
 
-        let rec = match rec {
-            Ok(rec) => rec,
-            Err(err) => match err {
-                sqlx::Error::RowNotFound => return Ok(None),
-                _ => return Err(Error::SQLX(err).into()),
-            },
-        };
-        Ok(Some(sqlite_row_to_proof(rec)?))
+            match rec {
+                Ok(rec) => {
+                    proofs.push(Some(sqlite_row_to_proof(rec)?));
+                }
+                Err(err) => match err {
+                    sqlx::Error::RowNotFound => proofs.push(None),
+                    _ => return Err(Error::SQLX(err).into()),
+                },
+            };
+        }
+
+        Ok(proofs)
     }
     async fn remove_pending_proofs(&self, secrets: Vec<&Secret>) -> Result<(), Self::Err> {
         let mut transaction = self.pool.begin().await.map_err(Error::from)?;

+ 63 - 21
crates/cdk/src/cdk_database/mint_memory.rs

@@ -223,17 +223,40 @@ impl MintDatabase for MintMemoryDatabase {
         Ok(())
     }
 
-    async fn get_spent_proof_by_secret(&self, secret: &Secret) -> Result<Option<Proof>, Self::Err> {
-        Ok(self
-            .spent_proofs
-            .read()
-            .await
-            .get(&hash_to_curve(&secret.to_bytes())?.to_bytes())
-            .cloned())
+    async fn get_spent_proofs_by_secrets(
+        &self,
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let spent_proofs = self.spent_proofs.read().await;
+
+        let mut proofs = Vec::with_capacity(secrets.len());
+
+        for secret in secrets {
+            let y = hash_to_curve(&secret.to_bytes())?;
+
+            let proof = spent_proofs.get(&y.to_bytes()).cloned();
+
+            proofs.push(proof);
+        }
+
+        Ok(proofs)
     }
 
-    async fn get_spent_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err> {
-        Ok(self.spent_proofs.read().await.get(&y.to_bytes()).cloned())
+    async fn get_spent_proofs_by_ys(
+        &self,
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let spent_proofs = self.spent_proofs.read().await;
+
+        let mut proofs = Vec::with_capacity(ys.len());
+
+        for y in ys {
+            let proof = spent_proofs.get(&y.to_bytes()).cloned();
+
+            proofs.push(proof);
+        }
+
+        Ok(proofs)
     }
 
     async fn add_pending_proofs(&self, pending_proofs: Proofs) -> Result<(), Self::Err> {
@@ -245,21 +268,40 @@ impl MintDatabase for MintMemoryDatabase {
         Ok(())
     }
 
-    async fn get_pending_proof_by_secret(
+    async fn get_pending_proofs_by_secrets(
         &self,
-        secret: &Secret,
-    ) -> Result<Option<Proof>, Self::Err> {
-        let secret_point = hash_to_curve(&secret.to_bytes())?;
-        Ok(self
-            .pending_proofs
-            .read()
-            .await
-            .get(&secret_point.to_bytes())
-            .cloned())
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let spent_proofs = self.pending_proofs.read().await;
+
+        let mut proofs = Vec::with_capacity(secrets.len());
+
+        for secret in secrets {
+            let y = hash_to_curve(&secret.to_bytes())?;
+
+            let proof = spent_proofs.get(&y.to_bytes()).cloned();
+
+            proofs.push(proof);
+        }
+
+        Ok(proofs)
     }
 
-    async fn get_pending_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err> {
-        Ok(self.pending_proofs.read().await.get(&y.to_bytes()).cloned())
+    async fn get_pending_proofs_by_ys(
+        &self,
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err> {
+        let spent_proofs = self.pending_proofs.read().await;
+
+        let mut proofs = Vec::with_capacity(ys.len());
+
+        for y in ys {
+            let proof = spent_proofs.get(&y.to_bytes()).cloned();
+
+            proofs.push(proof);
+        }
+
+        Ok(proofs)
     }
 
     async fn remove_pending_proofs(&self, secrets: Vec<&Secret>) -> Result<(), Self::Err> {

+ 19 - 10
crates/cdk/src/cdk_database/mod.rs

@@ -219,20 +219,29 @@ pub trait MintDatabase {
 
     /// Add spent [`Proofs`]
     async fn add_spent_proofs(&self, proof: Proofs) -> Result<(), Self::Err>;
-    /// Get spent [`Proof`] by secret
-    async fn get_spent_proof_by_secret(&self, secret: &Secret) -> Result<Option<Proof>, Self::Err>;
-    /// Get spent [`Proof`] by y
-    async fn get_spent_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err>;
+    /// Get spent [`Proofs`] by secrets
+    async fn get_spent_proofs_by_secrets(
+        &self,
+        secret: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err>;
+    /// Get spent [`Proofs`] by ys
+    async fn get_spent_proofs_by_ys(
+        &self,
+        y: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err>;
 
     /// Add pending [`Proofs`]
     async fn add_pending_proofs(&self, proof: Proofs) -> Result<(), Self::Err>;
-    /// Get pending [`Proof`] by secret
-    async fn get_pending_proof_by_secret(
+    /// Get pending [`Proofs`] by secrets
+    async fn get_pending_proofs_by_secrets(
+        &self,
+        secrets: &[Secret],
+    ) -> Result<Vec<Option<Proof>>, Self::Err>;
+    /// Get pending [`Proofs`] by ys
+    async fn get_pending_proofs_by_ys(
         &self,
-        secret: &Secret,
-    ) -> Result<Option<Proof>, Self::Err>;
-    /// Get pending [`Proof`] by y
-    async fn get_pending_proof_by_y(&self, y: &PublicKey) -> Result<Option<Proof>, Self::Err>;
+        ys: &[PublicKey],
+    ) -> Result<Vec<Option<Proof>>, Self::Err>;
     /// Remove pending [`Proofs`]
     async fn remove_pending_proofs(&self, secret: Vec<&Secret>) -> Result<(), Self::Err>;
 

+ 88 - 32
crates/cdk/src/mint/mod.rs

@@ -606,15 +606,43 @@ impl Mint {
 
         let proof_count = swap_request.inputs.len();
 
-        let secrets: HashSet<[u8; 33]> = swap_request
+        let secrets: Vec<PublicKey> = swap_request
             .inputs
             .iter()
             .flat_map(|p| hash_to_curve(&p.secret.to_bytes()))
-            .map(|p| p.to_bytes())
             .collect();
 
+        let pending_proofs: Proofs = self
+            .localstore
+            .get_pending_proofs_by_ys(&secrets)
+            .await?
+            .into_iter()
+            .flatten()
+            .collect();
+
+        if !pending_proofs.is_empty() {
+            return Err(Error::TokenPending);
+        }
+
+        let spent_proofs: Proofs = self
+            .localstore
+            .get_spent_proofs_by_ys(&secrets)
+            .await?
+            .into_iter()
+            .flatten()
+            .collect();
+
+        if !spent_proofs.is_empty() {
+            return Err(Error::TokenAlreadySpent);
+        }
+
         // Check that there are no duplicate proofs in request
-        if secrets.len().ne(&proof_count) {
+        if secrets
+            .iter()
+            .collect::<HashSet<&PublicKey>>()
+            .len()
+            .ne(&proof_count)
+        {
             return Err(Error::DuplicateProofs);
         }
 
@@ -709,16 +737,6 @@ impl Mint {
             }
         }
 
-        let y: PublicKey = hash_to_curve(&proof.secret.to_bytes())?;
-
-        if self.localstore.get_spent_proof_by_y(&y).await?.is_some() {
-            return Err(Error::TokenAlreadySpent);
-        }
-
-        if self.localstore.get_pending_proof_by_y(&y).await?.is_some() {
-            return Err(Error::TokenPending);
-        }
-
         self.ensure_keyset_loaded(&proof.keyset_id).await?;
         let keysets = self.keysets.read().await;
         let keyset = keysets.get(&proof.keyset_id).ok_or(Error::UnknownKeySet)?;
@@ -739,13 +757,28 @@ impl Mint {
     ) -> Result<CheckStateResponse, Error> {
         let mut states = Vec::with_capacity(check_state.ys.len());
 
-        for y in &check_state.ys {
-            let state = if self.localstore.get_spent_proof_by_y(y).await?.is_some() {
-                State::Spent
-            } else if self.localstore.get_pending_proof_by_y(y).await?.is_some() {
-                State::Pending
-            } else {
-                State::Unspent
+        let spent_proofs = self
+            .localstore
+            .get_spent_proofs_by_ys(&check_state.ys)
+            .await?;
+        let pending_proofs = self
+            .localstore
+            .get_pending_proofs_by_ys(&check_state.ys)
+            .await?;
+
+        for ((spent, pending), y) in spent_proofs
+            .iter()
+            .zip(&pending_proofs)
+            .zip(&check_state.ys)
+        {
+            let state = match (spent, pending) {
+                (None, None) => State::Unspent,
+                (Some(_), None) => State::Spent,
+                (None, Some(_)) => State::Pending,
+                (Some(_), Some(_)) => {
+                    tracing::error!("Proof should not be both pending and spent. Assuming Spent");
+                    State::Spent
+                }
             };
 
             states.push(ProofState {
@@ -763,6 +796,41 @@ impl Mint {
         &self,
         melt_request: &MeltBolt11Request,
     ) -> Result<MeltQuote, Error> {
+        let secrets: Vec<PublicKey> = melt_request
+            .inputs
+            .iter()
+            .flat_map(|p| hash_to_curve(&p.secret.to_bytes()))
+            .collect();
+
+        // Ensure proofs are unique and not being double spent
+        if melt_request.inputs.len() != secrets.iter().collect::<HashSet<_>>().len() {
+            return Err(Error::DuplicateProofs);
+        }
+
+        let pending_proofs: Proofs = self
+            .localstore
+            .get_pending_proofs_by_ys(&secrets)
+            .await?
+            .into_iter()
+            .flatten()
+            .collect();
+
+        if !pending_proofs.is_empty() {
+            return Err(Error::TokenPending);
+        }
+
+        let spent_proofs: Proofs = self
+            .localstore
+            .get_spent_proofs_by_ys(&secrets)
+            .await?
+            .into_iter()
+            .flatten()
+            .collect();
+
+        if !spent_proofs.is_empty() {
+            return Err(Error::TokenAlreadySpent);
+        }
+
         for proof in &melt_request.inputs {
             self.verify_proof(proof).await?;
         }
@@ -858,18 +926,6 @@ impl Mint {
             return Err(Error::MultipleUnits);
         }
 
-        let secrets: HashSet<[u8; 33]> = melt_request
-            .inputs
-            .iter()
-            .flat_map(|p| hash_to_curve(&p.secret.to_bytes()))
-            .map(|p| p.to_bytes())
-            .collect();
-
-        // Ensure proofs are unique and not being double spent
-        if melt_request.inputs.len().ne(&secrets.len()) {
-            return Err(Error::DuplicateProofs);
-        }
-
         // Add proofs to pending
         self.localstore
             .add_pending_proofs(melt_request.inputs.clone())