Explorar o código

Minor improvement in the SQLite crate

This is a minor improvement over the SQLite crate, which performs fewer SQL
statements and fetches multiple results instead.

This will also remove some redundant commit() and rollback calls. Rollback
already happens on drop, and there is no need for a commit if the database
hasn't changed, as the transaction is used as a locking mechanism in this
context.
Cesar Rodas hai 3 meses
pai
achega
af2fe580f4
Modificáronse 1 ficheiros con 135 adicións e 149 borrados
  1. 135 149
      crates/cdk-sqlite/src/mint/mod.rs

+ 135 - 149
crates/cdk-sqlite/src/mint/mod.rs

@@ -781,39 +781,35 @@ VALUES (?, ?, ?, ?, ?, ?, ?, ?);
     async fn get_proofs_by_ys(&self, ys: &[PublicKey]) -> Result<Vec<Option<Proof>>, Self::Err> {
         let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 
-        let mut proofs = Vec::with_capacity(ys.len());
-
-        for y in ys {
-            let rec = sqlx::query(
-                r#"
-SELECT *
-FROM proof
-WHERE y=?;
-        "#,
-            )
-            .bind(y.to_bytes().to_vec())
-            .fetch_one(&mut transaction)
-            .await;
-
-            match rec {
-                Ok(rec) => {
-                    proofs.push(Some(sqlite_row_to_proof(rec)?));
-                }
-                Err(err) => match err {
-                    sqlx::Error::RowNotFound => proofs.push(None),
-                    _ => {
-                        if let Err(err) = transaction.rollback().await {
-                            tracing::error!("Could not rollback sql transaction: {}", err);
-                        }
-                        return Err(Error::SQLX(err).into());
-                    }
-                },
-            };
-        }
-
-        transaction.commit().await.map_err(Error::from)?;
+        let sql = format!(
+            "SELECT * FROM proof WHERE y IN ({})",
+            "?,".repeat(ys.len()).trim_end_matches(',')
+        );
+
+        let mut proofs = ys
+            .iter()
+            .fold(sqlx::query(&sql), |query, y| {
+                query.bind(y.to_bytes().to_vec())
+            })
+            .fetch_all(&mut transaction)
+            .await
+            .map_err(|err| {
+                tracing::error!("SQLite could not get state of proof: {err:?}");
+                Error::SQLX(err)
+            })?
+            .into_iter()
+            .map(|row| {
+                PublicKey::from_slice(row.get("y"))
+                    .map_err(Error::from)
+                    .and_then(|y| {
+                        sqlite_row_to_proof(row)
+                            .map_err(Error::from)
+                            .map(|proof| (y, proof))
+                    })
+            })
+            .collect::<Result<HashMap<_, _>, _>>()?;
 
-        Ok(proofs)
+        Ok(ys.iter().map(|y| proofs.remove(y)).collect())
     }
 
     async fn get_proof_ys_by_quote_id(&self, quote_id: &str) -> Result<Vec<PublicKey>, Self::Err> {
@@ -862,41 +858,36 @@ WHERE quote_id=?;
     async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
         let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 
-        let mut states = Vec::with_capacity(ys.len());
-
-        for y in ys {
-            let rec = sqlx::query(
-                r#"
-SELECT state
-FROM proof
-WHERE y=?;
-        "#,
-            )
-            .bind(y.to_bytes().to_vec())
-            .fetch_one(&mut transaction)
-            .await;
-
-            match rec {
-                Ok(rec) => {
-                    let state: String = rec.get("state");
-                    let state = State::from_str(&state).map_err(Error::from)?;
-                    states.push(Some(state));
-                }
-                Err(err) => match err {
-                    sqlx::Error::RowNotFound => states.push(None),
-                    _ => {
-                        if let Err(err) = transaction.rollback().await {
-                            tracing::error!("Could not rollback sql transaction: {}", err);
-                        }
-                        return Err(Error::SQLX(err).into());
-                    }
-                },
-            };
-        }
-
-        transaction.commit().await.map_err(Error::from)?;
+        let sql = format!(
+            "SELECT y, state FROM proof WHERE y IN ({})",
+            "?,".repeat(ys.len()).trim_end_matches(',')
+        );
+
+        let mut current_states = ys
+            .iter()
+            .fold(sqlx::query(&sql), |query, y| {
+                query.bind(y.to_bytes().to_vec())
+            })
+            .fetch_all(&mut transaction)
+            .await
+            .map_err(|err| {
+                tracing::error!("SQLite could not get state of proof: {err:?}");
+                Error::SQLX(err)
+            })?
+            .into_iter()
+            .map(|row| {
+                PublicKey::from_slice(row.get("y"))
+                    .map_err(Error::from)
+                    .and_then(|y| {
+                        let state: String = row.get("state");
+                        State::from_str(&state)
+                            .map_err(Error::from)
+                            .map(|state| (y, state))
+                    })
+            })
+            .collect::<Result<HashMap<_, _>, _>>()?;
 
-        Ok(states)
+        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
     }
 
     async fn get_proofs_by_keyset_id(
@@ -948,68 +939,57 @@ WHERE keyset_id=?;
     ) -> Result<Vec<Option<State>>, Self::Err> {
         let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 
-        let mut states = Vec::with_capacity(ys.len());
-
-        let proofs_state = proofs_state.to_string();
-        for y in ys {
-            let current_state;
-            let y = y.to_bytes().to_vec();
-            let rec = sqlx::query(
-                r#"
-SELECT state
-FROM proof
-WHERE y=?;
-        "#,
+        let sql = format!(
+            "SELECT y, state FROM proof WHERE y IN ({})",
+            "?,".repeat(ys.len()).trim_end_matches(',')
+        );
+
+        let mut current_states = ys
+            .iter()
+            .fold(sqlx::query(&sql), |query, y| {
+                query.bind(y.to_bytes().to_vec())
+            })
+            .fetch_all(&mut transaction)
+            .await
+            .map_err(|err| {
+                tracing::error!("SQLite could not get state of proof: {err:?}");
+                Error::SQLX(err)
+            })?
+            .into_iter()
+            .map(|row| {
+                PublicKey::from_slice(row.get("y"))
+                    .map_err(Error::from)
+                    .and_then(|y| {
+                        let state: String = row.get("state");
+                        State::from_str(&state)
+                            .map_err(Error::from)
+                            .map(|state| (y, state))
+                    })
+            })
+            .collect::<Result<HashMap<_, _>, _>>()?;
+
+        let update_sql = format!(
+            "UPDATE proof SET state = ? WHERE state != ? AND y IN ({})",
+            "?,".repeat(ys.len()).trim_end_matches(',')
+        );
+
+        ys.iter()
+            .fold(
+                sqlx::query(&update_sql)
+                    .bind(proofs_state.to_string())
+                    .bind(State::Spent.to_string()),
+                |query, y| query.bind(y.to_bytes().to_vec()),
             )
-            .bind(&y)
-            .fetch_one(&mut transaction)
-            .await;
-
-            match rec {
-                Ok(rec) => {
-                    let state: String = rec.get("state");
-                    current_state = Some(State::from_str(&state).map_err(Error::from)?);
-                }
-                Err(err) => match err {
-                    sqlx::Error::RowNotFound => {
-                        current_state = None;
-                    }
-                    _ => {
-                        tracing::error!("SQLite could not get state of proof");
-                        if let Err(err) = transaction.rollback().await {
-                            tracing::error!("Could not rollback sql transaction: {}", err);
-                        }
-                        return Err(Error::SQLX(err).into());
-                    }
-                },
-            };
-
-            states.push(current_state);
-
-            if current_state != Some(State::Spent) {
-                let res = sqlx::query(
-                    r#"
-        UPDATE proof SET state = ? WHERE y = ?
-        "#,
-                )
-                .bind(&proofs_state)
-                .bind(y)
-                .execute(&mut transaction)
-                .await;
-
-                if let Err(err) = res {
-                    tracing::error!("SQLite could not update proof state");
-                    if let Err(err) = transaction.rollback().await {
-                        tracing::error!("Could not rollback sql transaction: {}", err);
-                    }
-                    return Err(Error::SQLX(err).into());
-                }
-            }
-        }
+            .execute(&mut transaction)
+            .await
+            .map_err(|err| {
+                tracing::error!("SQLite could not update proof state: {err:?}");
+                Error::SQLX(err)
+            })?;
 
         transaction.commit().await.map_err(Error::from)?;
 
-        Ok(states)
+        Ok(ys.iter().map(|y| current_states.remove(y)).collect())
     }
 
     async fn add_blind_signatures(
@@ -1057,32 +1037,38 @@ VALUES (?, ?, ?, ?, ?, ?, ?);
     ) -> Result<Vec<Option<BlindSignature>>, Self::Err> {
         let mut transaction = self.pool.begin().await.map_err(Error::from)?;
 
-        let mut signatures = Vec::with_capacity(blinded_messages.len());
-
-        for message in blinded_messages {
-            let rec = sqlx::query(
-                r#"
-SELECT *
-FROM blind_signature
-WHERE y=?;
-        "#,
-            )
-            .bind(message.to_bytes().to_vec())
-            .fetch_one(&mut transaction)
-            .await;
-
-            if let Ok(row) = rec {
-                let blinded = sqlite_row_to_blind_signature(row)?;
-
-                signatures.push(Some(blinded));
-            } else {
-                signatures.push(None);
-            }
-        }
-
-        transaction.commit().await.map_err(Error::from)?;
+        let sql = format!(
+            "SELECT * FROM blind_signature WHERE y IN ({})",
+            "?,".repeat(blinded_messages.len()).trim_end_matches(',')
+        );
+
+        let mut blinded_signatures = blinded_messages
+            .iter()
+            .fold(sqlx::query(&sql), |query, y| {
+                query.bind(y.to_bytes().to_vec())
+            })
+            .fetch_all(&mut transaction)
+            .await
+            .map_err(|err| {
+                tracing::error!("SQLite could not get state of proof: {err:?}");
+                Error::SQLX(err)
+            })?
+            .into_iter()
+            .map(|row| {
+                PublicKey::from_slice(row.get("y"))
+                    .map_err(Error::from)
+                    .and_then(|y| {
+                        sqlite_row_to_blind_signature(row)
+                            .map_err(Error::from)
+                            .map(|blinded| (y, blinded))
+                    })
+            })
+            .collect::<Result<HashMap<_, _>, _>>()?;
 
-        Ok(signatures)
+        Ok(blinded_messages
+            .iter()
+            .map(|y| blinded_signatures.remove(y))
+            .collect())
     }
 
     async fn get_blind_signatures_for_keyset(