Quellcode durchsuchen

Add a simple placeholder parser

Introduced a better parser for SQL placeholders, to be usable for many
databases and to support named paramters but use positional
Cesar Rodas vor 1 Monat
Ursprung
Commit
67f4cf4708

+ 3 - 0
crates/cdk-common/src/database/mod.rs

@@ -150,6 +150,9 @@ pub enum Error {
     /// Unknown Quote
     #[error("Unknown Quote")]
     UnknownQuote,
+    /// Missing Placeholder value
+    #[error("Missing placeholder value {0}")]
+    MissingPlaceholder(String),
     /// Attempt to remove spent proof
     #[error("Attempt to remove spent proof")]
     AttemptRemoveSpentProof,

+ 6 - 6
crates/cdk-sql-base/src/common.rs

@@ -14,7 +14,7 @@ pub async fn migrate<C: DatabaseExecutor>(
                applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
            "#,
-    )
+    )?
     .execute(conn)
     .await?;
 
@@ -38,16 +38,16 @@ pub async fn migrate<C: DatabaseExecutor>(
 
     // Apply each migration if it hasn’t been applied yet
     for (name, sql) in migrations {
-        let is_missing = query("SELECT name FROM migrations WHERE name = :name")
-            .bind(":name", name)
+        let is_missing = query("SELECT name FROM migrations WHERE name = :name")?
+            .bind("name", name)
             .pluck(conn)
             .await?
             .is_none();
 
         if is_missing {
-            query(sql).batch(conn).await?;
-            query(r#"INSERT INTO migrations (name) VALUES (:name)"#)
-                .bind(":name", name)
+            query(sql)?.batch(conn).await?;
+            query(r#"INSERT INTO migrations (name) VALUES (:name)"#)?
+                .bind("name", name)
                 .execute(conn)
                 .await?;
         }

+ 1 - 1
crates/cdk-sql-base/src/database.rs

@@ -30,7 +30,7 @@ pub trait DatabaseExecutor: Debug + Sync + Send {
 /// Database transaction trait
 #[async_trait::async_trait]
 pub trait DatabaseTransaction<'a>: Debug + DatabaseExecutor + Send + Sync {
-    /// Consumes the current transaction comitting the changes
+    /// Consumes the current transaction committing the changes
     async fn commit(self) -> Result<(), Error>;
 
     /// Consumes the transaction rolling back all changes

+ 44 - 44
crates/cdk-sql-base/src/mint/auth/mod.rs

@@ -71,8 +71,8 @@ where
                 ELSE FALSE
             END;
             "#,
-        )
-        .bind(":id", id.to_string())
+        )?
+        .bind("id", id.to_string())
         .execute(&self.inner)
         .await?;
 
@@ -100,15 +100,15 @@ where
             max_order = excluded.max_order,
             derivation_path_index = excluded.derivation_path_index
         "#,
-        )
-        .bind(":id", keyset.id.to_string())
-        .bind(":unit", keyset.unit.to_string())
-        .bind(":active", keyset.active)
-        .bind(":valid_from", keyset.valid_from as i64)
-        .bind(":valid_to", keyset.final_expiry.map(|v| v as i64))
-        .bind(":derivation_path", keyset.derivation_path.to_string())
-        .bind(":max_order", keyset.max_order)
-        .bind(":derivation_path_index", keyset.derivation_path_index)
+        )?
+        .bind("id", keyset.id.to_string())
+        .bind("unit", keyset.unit.to_string())
+        .bind("active", keyset.active)
+        .bind("valid_from", keyset.valid_from as i64)
+        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
+        .bind("derivation_path", keyset.derivation_path.to_string())
+        .bind("max_order", keyset.max_order)
+        .bind("derivation_path_index", keyset.derivation_path_index)
         .execute(&self.inner)
         .await?;
 
@@ -123,12 +123,12 @@ where
                 VALUES
                 (:y, :keyset_id, :secret, :c, :state)
                 "#,
-        )
-        .bind(":y", proof.y()?.to_bytes().to_vec())
-        .bind(":keyset_id", proof.keyset_id.to_string())
-        .bind(":secret", proof.secret.to_string())
-        .bind(":c", proof.c.to_bytes().to_vec())
-        .bind(":state", "UNSPENT".to_string())
+        )?
+        .bind("y", proof.y()?.to_bytes().to_vec())
+        .bind("keyset_id", proof.keyset_id.to_string())
+        .bind("secret", proof.secret.to_string())
+        .bind("c", proof.c.to_bytes().to_vec())
+        .bind("state", "UNSPENT".to_string())
         .execute(&self.inner)
         .await
         {
@@ -142,20 +142,20 @@ where
         y: &PublicKey,
         proofs_state: State,
     ) -> Result<Option<State>, Self::Err> {
-        let current_state = query(r#"SELECT state FROM proof WHERE y = :y"#)
-            .bind(":y", y.to_bytes().to_vec())
+        let current_state = query(r#"SELECT state FROM proof WHERE y = :y"#)?
+            .bind("y", y.to_bytes().to_vec())
             .pluck(&self.inner)
             .await?
             .map(|state| Ok::<_, Error>(column_as_string!(state, State::from_str)))
             .transpose()?;
 
-        query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)
-            .bind(":y", y.to_bytes().to_vec())
+        query(r#"UPDATE proof SET state = :new_state WHERE state = :state AND y = :y"#)?
+            .bind("y", y.to_bytes().to_vec())
             .bind(
-                ":state",
+                "state",
                 current_state.as_ref().map(|state| state.to_string()),
             )
-            .bind(":new_state", proofs_state.to_string())
+            .bind("new_state", proofs_state.to_string())
             .execute(&self.inner)
             .await?;
 
@@ -176,11 +176,11 @@ where
                        VALUES
                        (:y, :amount, :keyset_id, :c)
                    "#,
-            )
-            .bind(":y", message.to_bytes().to_vec())
-            .bind(":amount", u64::from(signature.amount) as i64)
-            .bind(":keyset_id", signature.keyset_id.to_string())
-            .bind(":c", signature.c.to_bytes().to_vec())
+            )?
+            .bind("y", message.to_bytes().to_vec())
+            .bind("amount", u64::from(signature.amount) as i64)
+            .bind("keyset_id", signature.keyset_id.to_string())
+            .bind("c", signature.c.to_bytes().to_vec())
             .execute(&self.inner)
             .await?;
         }
@@ -199,9 +199,9 @@ where
                  (endpoint, auth)
                  VALUES (:endpoint, :auth);
                  "#,
-            )
-            .bind(":endpoint", serde_json::to_string(endpoint)?)
-            .bind(":auth", serde_json::to_string(auth)?)
+            )?
+            .bind("endpoint", serde_json::to_string(endpoint)?)
+            .bind("auth", serde_json::to_string(auth)?)
             .execute(&self.inner)
             .await
             {
@@ -218,9 +218,9 @@ where
         &mut self,
         protected_endpoints: Vec<ProtectedEndpoint>,
     ) -> Result<(), database::Error> {
-        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)
+        query(r#"DELETE FROM protected_endpoints WHERE endpoint IN (:endpoints)"#)?
             .bind_vec(
-                ":endpoints",
+                "endpoints",
                 protected_endpoints
                     .iter()
                     .map(serde_json::to_string)
@@ -259,7 +259,7 @@ where
             WHERE
                 active = 1;
             "#,
-        )
+        )?
         .pluck(&self.db)
         .await?
         .map(|id| Ok::<_, Error>(column_as_string!(id, Id::from_str, Id::from_bytes)))
@@ -281,8 +281,8 @@ where
             FROM
                 keyset
                 WHERE id=:id"#,
-        )
-        .bind(":id", id.to_string())
+        )?
+        .bind("id", id.to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_keyset_info)
@@ -304,7 +304,7 @@ where
             FROM
                 keyset
                 WHERE id=:id"#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -313,8 +313,8 @@ where
     }
 
     async fn get_proofs_states(&self, ys: &[PublicKey]) -> Result<Vec<Option<State>>, Self::Err> {
-        let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)
-            .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        let mut current_states = query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
+            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
             .fetch_all(&self.db)
             .await?
             .into_iter()
@@ -345,9 +345,9 @@ where
                 blind_signature
             WHERE y IN (:y)
             "#,
-        )
+        )?
         .bind_vec(
-            ":y",
+            "y",
             blinded_messages
                 .iter()
                 .map(|y| y.to_bytes().to_vec())
@@ -378,8 +378,8 @@ where
         protected_endpoint: ProtectedEndpoint,
     ) -> Result<Option<AuthRequired>, Self::Err> {
         Ok(
-            query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)
-                .bind(":endpoint", serde_json::to_string(&protected_endpoint)?)
+            query(r#"SELECT auth FROM protected_endpoints WHERE endpoint = :endpoint"#)?
+                .bind("endpoint", serde_json::to_string(&protected_endpoint)?)
                 .pluck(&self.db)
                 .await?
                 .map(|auth| {
@@ -396,7 +396,7 @@ where
     async fn get_auth_for_endpoints(
         &self,
     ) -> Result<HashMap<ProtectedEndpoint, Option<AuthRequired>>, Self::Err> {
-        Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)
+        Ok(query(r#"SELECT endpoint, auth FROM protected_endpoints"#)?
             .fetch_all(&self.db)
             .await?
             .into_iter()

+ 146 - 146
crates/cdk-sql-base/src/mint/mod.rs

@@ -78,8 +78,8 @@ async fn get_current_states<C>(
 where
     C: DatabaseExecutor + Send + Sync,
 {
-    query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)
-        .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+    query(r#"SELECT y, state FROM proof WHERE y IN (:ys)"#)?
+        .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
         .fetch_all(conn)
         .await?
         .into_iter()
@@ -103,9 +103,9 @@ where
         INSERT INTO config (id, value) VALUES (:id, :value)
             ON CONFLICT(id) DO UPDATE SET value = excluded.value
             "#,
-    )
-    .bind(":id", id.to_owned())
-    .bind(":value", serde_json::to_string(&value)?)
+    )?
+    .bind("id", id.to_owned())
+    .bind("value", serde_json::to_string(&value)?)
     .execute(conn)
     .await?;
 
@@ -139,8 +139,8 @@ where
     where
         R: serde::de::DeserializeOwned,
     {
-        let value = column_as_string!(query(r#"SELECT value FROM config WHERE id = :id LIMIT 1"#)
-            .bind(":id", id.to_owned())
+        let value = column_as_string!(query(r#"SELECT value FROM config WHERE id = :id LIMIT 1"#)?
+            .bind("id", id.to_owned())
             .pluck(&self.db)
             .await?
             .ok_or(Error::UnknownQuoteTTL)?);
@@ -206,16 +206,16 @@ where
             input_fee_ppk = excluded.input_fee_ppk,
             derivation_path_index = excluded.derivation_path_index
         "#,
-        )
-        .bind(":id", keyset.id.to_string())
-        .bind(":unit", keyset.unit.to_string())
-        .bind(":active", keyset.active)
-        .bind(":valid_from", keyset.valid_from as i64)
-        .bind(":valid_to", keyset.final_expiry.map(|v| v as i64))
-        .bind(":derivation_path", keyset.derivation_path.to_string())
-        .bind(":max_order", keyset.max_order)
-        .bind(":input_fee_ppk", keyset.input_fee_ppk as i64)
-        .bind(":derivation_path_index", keyset.derivation_path_index)
+        )?
+        .bind("id", keyset.id.to_string())
+        .bind("unit", keyset.unit.to_string())
+        .bind("active", keyset.active)
+        .bind("valid_from", keyset.valid_from as i64)
+        .bind("valid_to", keyset.final_expiry.map(|v| v as i64))
+        .bind("derivation_path", keyset.derivation_path.to_string())
+        .bind("max_order", keyset.max_order)
+        .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
+        .bind("derivation_path_index", keyset.derivation_path_index)
         .execute(&self.inner)
         .await?;
 
@@ -223,14 +223,14 @@ where
     }
 
     async fn set_active_keyset(&mut self, unit: CurrencyUnit, id: Id) -> Result<(), Error> {
-        query(r#"UPDATE keyset SET active=FALSE WHERE unit IS :unit"#)
-            .bind(":unit", unit.to_string())
+        query(r#"UPDATE keyset SET active=FALSE WHERE unit IS :unit"#)?
+            .bind("unit", unit.to_string())
             .execute(&self.inner)
             .await?;
 
-        query(r#"UPDATE keyset SET active=TRUE WHERE unit IS :unit AND id IS :id"#)
-            .bind(":unit", unit.to_string())
-            .bind(":id", id.to_string())
+        query(r#"UPDATE keyset SET active=TRUE WHERE unit IS :unit AND id IS :id"#)?
+            .bind("unit", unit.to_string())
+            .bind("id", id.to_string())
             .execute(&self.inner)
             .await?;
 
@@ -256,8 +256,8 @@ where
 
     async fn get_active_keyset_id(&self, unit: &CurrencyUnit) -> Result<Option<Id>, Self::Err> {
         Ok(
-            query(r#" SELECT id FROM keyset WHERE active = 1 AND unit IS :unit"#)
-                .bind(":unit", unit.to_string())
+            query(r#" SELECT id FROM keyset WHERE active = 1 AND unit IS :unit"#)?
+                .bind("unit", unit.to_string())
                 .pluck(&self.db)
                 .await?
                 .map(|id| match id {
@@ -270,7 +270,7 @@ where
     }
 
     async fn get_active_keysets(&self) -> Result<HashMap<CurrencyUnit, Id>, Self::Err> {
-        Ok(query(r#"SELECT id, unit FROM keyset WHERE active = 1"#)
+        Ok(query(r#"SELECT id, unit FROM keyset WHERE active = 1"#)?
             .fetch_all(&self.db)
             .await?
             .into_iter()
@@ -298,8 +298,8 @@ where
             FROM
                 keyset
                 WHERE id=:id"#,
-        )
-        .bind(":id", id.to_string())
+        )?
+        .bind("id", id.to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_keyset_info)
@@ -321,7 +321,7 @@ where
             FROM
                 keyset
             "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -349,18 +349,18 @@ where
                     :pubkey, :created_time, :paid_time, :issued_time
                 )
             "#,
-        )
-        .bind(":id", quote.id.to_string())
-        .bind(":amount", u64::from(quote.amount) as i64)
-        .bind(":unit", quote.unit.to_string())
-        .bind(":request", quote.request)
-        .bind(":state", quote.state.to_string())
-        .bind(":expiry", quote.expiry as i64)
-        .bind(":request_lookup_id", quote.request_lookup_id)
-        .bind(":pubkey", quote.pubkey.map(|p| p.to_string()))
-        .bind(":created_time", quote.created_time as i64)
-        .bind(":paid_time", quote.paid_time.map(|t| t as i64))
-        .bind(":issued_time", quote.issued_time.map(|t| t as i64))
+        )?
+        .bind("id", quote.id.to_string())
+        .bind("amount", u64::from(quote.amount) as i64)
+        .bind("unit", quote.unit.to_string())
+        .bind("request", quote.request)
+        .bind("state", quote.state.to_string())
+        .bind("expiry", quote.expiry as i64)
+        .bind("request_lookup_id", quote.request_lookup_id)
+        .bind("pubkey", quote.pubkey.map(|p| p.to_string()))
+        .bind("created_time", quote.created_time as i64)
+        .bind("paid_time", quote.paid_time.map(|t| t as i64))
+        .bind("issued_time", quote.issued_time.map(|t| t as i64))
         .execute(&self.inner)
         .await?;
 
@@ -368,8 +368,8 @@ where
     }
 
     async fn remove_mint_quote(&mut self, quote_id: &Uuid) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM mint_quote WHERE id=:id"#)
-            .bind(":id", quote_id.as_hyphenated().to_string())
+        query(r#"DELETE FROM mint_quote WHERE id=:id"#)?
+            .bind("id", quote_id.as_hyphenated().to_string())
             .execute(&self.inner)
             .await?;
         Ok(())
@@ -385,10 +385,10 @@ where
             AND state = :state
             AND expiry < :current_time
             "#,
-        )
-        .bind(":request_lookup_id", quote.request_lookup_id.to_string())
-        .bind(":state", MeltQuoteState::Unpaid.to_string())
-        .bind(":current_time", current_time as i64)
+        )?
+        .bind("request_lookup_id", quote.request_lookup_id.to_string())
+        .bind("state", MeltQuoteState::Unpaid.to_string())
+        .bind("current_time", current_time as i64)
         .execute(&self.inner)
         .await?;
 
@@ -412,22 +412,22 @@ where
                 :created_time, :paid_time
             )
         "#,
-        )
-        .bind(":id", quote.id.to_string())
-        .bind(":unit", quote.unit.to_string())
-        .bind(":amount", u64::from(quote.amount) as i64)
-        .bind(":request", quote.request)
-        .bind(":fee_reserve", u64::from(quote.fee_reserve) as i64)
-        .bind(":state", quote.state.to_string())
-        .bind(":expiry", quote.expiry as i64)
-        .bind(":payment_preimage", quote.payment_preimage)
-        .bind(":request_lookup_id", quote.request_lookup_id)
+        )?
+        .bind("id", quote.id.to_string())
+        .bind("unit", quote.unit.to_string())
+        .bind("amount", u64::from(quote.amount) as i64)
+        .bind("request", quote.request)
+        .bind("fee_reserve", u64::from(quote.fee_reserve) as i64)
+        .bind("state", quote.state.to_string())
+        .bind("expiry", quote.expiry as i64)
+        .bind("payment_preimage", quote.payment_preimage)
+        .bind("request_lookup_id", quote.request_lookup_id)
         .bind(
-            ":msat_to_pay",
+            "msat_to_pay",
             quote.msat_to_pay.map(|a| u64::from(a) as i64),
         )
-        .bind(":created_time", quote.created_time as i64)
-        .bind(":paid_time", quote.paid_time.map(|t| t as i64))
+        .bind("created_time", quote.created_time as i64)
+        .bind("paid_time", quote.paid_time.map(|t| t as i64))
         .execute(&self.inner)
         .await?;
 
@@ -439,9 +439,9 @@ where
         quote_id: &Uuid,
         new_request_lookup_id: &str,
     ) -> Result<(), Self::Err> {
-        query(r#"UPDATE melt_quote SET request_lookup_id = :new_req_id WHERE id = :id"#)
-            .bind(":new_req_id", new_request_lookup_id.to_owned())
-            .bind(":id", quote_id.as_hyphenated().to_string())
+        query(r#"UPDATE melt_quote SET request_lookup_id = :new_req_id WHERE id = :id"#)?
+            .bind("new_req_id", new_request_lookup_id.to_owned())
+            .bind("id", quote_id.as_hyphenated().to_string())
             .execute(&self.inner)
             .await?;
         Ok(())
@@ -473,9 +473,9 @@ where
                 id=:id
                 AND state != :state
             "#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
-        .bind(":state", state.to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
+        .bind("state", state.to_string())
         .fetch_one(&self.inner)
         .await?
         .map(sql_row_to_melt_quote)
@@ -484,16 +484,16 @@ where
 
         let rec = if state == MeltQuoteState::Paid {
             let current_time = unix_time();
-            query(r#"UPDATE melt_quote SET state = :state, paid_time = :paid_time WHERE id = :id"#)
-                .bind(":state", state.to_string())
-                .bind(":paid_time", current_time as i64)
-                .bind(":id", quote_id.as_hyphenated().to_string())
+            query(r#"UPDATE melt_quote SET state = :state, paid_time = :paid_time WHERE id = :id"#)?
+                .bind("state", state.to_string())
+                .bind("paid_time", current_time as i64)
+                .bind("id", quote_id.as_hyphenated().to_string())
                 .execute(&self.inner)
                 .await
         } else {
-            query(r#"UPDATE melt_quote SET state = :state WHERE id = :id"#)
-                .bind(":state", state.to_string())
-                .bind(":id", quote_id.as_hyphenated().to_string())
+            query(r#"UPDATE melt_quote SET state = :state WHERE id = :id"#)?
+                .bind("state", state.to_string())
+                .bind("id", quote_id.as_hyphenated().to_string())
                 .execute(&self.inner)
                 .await
         };
@@ -518,8 +518,8 @@ where
             DELETE FROM melt_quote
             WHERE id=?
             "#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
         .execute(&self.inner)
         .await?;
 
@@ -548,8 +548,8 @@ where
             FROM
                 mint_quote
             WHERE id = :id"#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
         .fetch_one(&self.inner)
         .await?
         .map(sql_row_to_mint_quote)
@@ -568,17 +568,17 @@ where
         let current_time = unix_time();
 
         let update = match state {
-            MintQuoteState::Paid => query(update_query)
-                .bind(":state", state.to_string())
-                .bind(":current_time", current_time as i64)
-                .bind(":quote_id", quote_id.as_hyphenated().to_string()),
-            MintQuoteState::Issued => query(update_query)
-                .bind(":state", state.to_string())
-                .bind(":current_time", current_time as i64)
-                .bind(":quote_id", quote_id.as_hyphenated().to_string()),
-            _ => query(update_query)
-                .bind(":state", state.to_string())
-                .bind(":quote_id", quote_id.as_hyphenated().to_string()),
+            MintQuoteState::Paid => query(update_query)?
+                .bind("state", state.to_string())
+                .bind("current_time", current_time as i64)
+                .bind("quote_id", quote_id.as_hyphenated().to_string()),
+            MintQuoteState::Issued => query(update_query)?
+                .bind("state", state.to_string())
+                .bind("current_time", current_time as i64)
+                .bind("quote_id", quote_id.as_hyphenated().to_string()),
+            _ => query(update_query)?
+                .bind("state", state.to_string())
+                .bind("quote_id", quote_id.as_hyphenated().to_string()),
         };
 
         match update.execute(&self.inner).await {
@@ -609,8 +609,8 @@ where
             FROM
                 mint_quote
             WHERE id = :id"#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
         .fetch_one(&self.inner)
         .await?
         .map(sql_row_to_mint_quote)
@@ -641,8 +641,8 @@ where
             WHERE
                 id=:id
             "#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
         .fetch_one(&self.inner)
         .await?
         .map(sql_row_to_melt_quote)
@@ -670,8 +670,8 @@ where
             FROM
                 mint_quote
             WHERE request = :request"#,
-        )
-        .bind(":request", request.to_owned())
+        )?
+        .bind("request", request.to_owned())
         .fetch_one(&self.inner)
         .await?
         .map(sql_row_to_mint_quote)
@@ -704,8 +704,8 @@ where
             FROM
                 mint_quote
             WHERE id = :id"#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_mint_quote)
@@ -733,8 +733,8 @@ where
             FROM
                 mint_quote
             WHERE request = :request"#,
-        )
-        .bind(":request", request.to_owned())
+        )?
+        .bind("request", request.to_owned())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_mint_quote)
@@ -762,8 +762,8 @@ where
             FROM
                 mint_quote
             WHERE request_lookup_id = :request_lookup_id"#,
-        )
-        .bind(":request_lookup_id", request_lookup_id.to_owned())
+        )?
+        .bind("request_lookup_id", request_lookup_id.to_owned())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_mint_quote)
@@ -788,7 +788,7 @@ where
                    FROM
                        mint_quote
                   "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -819,8 +819,8 @@ where
                     WHERE
                         state = :state
                   "#,
-        )
-        .bind(":state", state.to_string())
+        )?
+        .bind("state", state.to_string())
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -849,8 +849,8 @@ where
             WHERE
                 id=:id
             "#,
-        )
-        .bind(":id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("id", quote_id.as_hyphenated().to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_melt_quote)
@@ -876,7 +876,7 @@ where
             FROM
                 melt_quote
             "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -901,9 +901,9 @@ where
 
         // Check any previous proof, this query should return None in order to proceed storing
         // Any result here would error
-        match query(r#"SELECT state FROM proof WHERE y IN (:ys) LIMIT 1"#)
+        match query(r#"SELECT state FROM proof WHERE y IN (:ys) LIMIT 1"#)?
             .bind_vec(
-                ":ys",
+                "ys",
                 proofs
                     .iter()
                     .map(|y| y.y().map(|y| y.to_bytes().to_vec()))
@@ -927,19 +927,19 @@ where
                   VALUES
                   (:y, :amount, :keyset_id, :secret, :c, :witness, :state, :quote_id, :created_time)
                   "#,
-            )
-            .bind(":y", proof.y()?.to_bytes().to_vec())
-            .bind(":amount", u64::from(proof.amount) as i64)
-            .bind(":keyset_id", proof.keyset_id.to_string())
-            .bind(":secret", proof.secret.to_string())
-            .bind(":c", proof.c.to_bytes().to_vec())
+            )?
+            .bind("y", proof.y()?.to_bytes().to_vec())
+            .bind("amount", u64::from(proof.amount) as i64)
+            .bind("keyset_id", proof.keyset_id.to_string())
+            .bind("secret", proof.secret.to_string())
+            .bind("c", proof.c.to_bytes().to_vec())
             .bind(
-                ":witness",
+                "witness",
                 proof.witness.map(|w| serde_json::to_string(&w).unwrap()),
             )
-            .bind(":state", "UNSPENT".to_string())
-            .bind(":quote_id", quote_id.map(|q| q.hyphenated().to_string()))
-            .bind(":created_time", current_time as i64)
+            .bind("state", "UNSPENT".to_string())
+            .bind("quote_id", quote_id.map(|q| q.hyphenated().to_string()))
+            .bind("created_time", current_time as i64)
             .execute(&self.inner)
             .await?;
         }
@@ -967,9 +967,9 @@ where
             check_state_transition(*state, new_state)?;
         }
 
-        query(r#"UPDATE proof SET state = :new_state WHERE y IN (:ys)"#)
-            .bind(":new_state", new_state.to_string())
-            .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        query(r#"UPDATE proof SET state = :new_state WHERE y IN (:ys)"#)?
+            .bind("new_state", new_state.to_string())
+            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
             .execute(&self.inner)
             .await?;
 
@@ -985,9 +985,9 @@ where
             r#"
             DELETE FROM proof WHERE y IN (:ys) AND state NOT IN (:exclude_state)
             "#,
-        )
-        .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
-        .bind_vec(":exclude_state", vec![State::Spent.to_string()])
+        )?
+        .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        .bind_vec("exclude_state", vec![State::Spent.to_string()])
         .execute(&self.inner)
         .await?;
 
@@ -1021,8 +1021,8 @@ where
             WHERE
                 y IN (:ys)
             "#,
-        )
-        .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+        )?
+        .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -1055,8 +1055,8 @@ where
             WHERE
                 quote_id = :quote_id
             "#,
-        )
-        .bind(":quote_id", quote_id.as_hyphenated().to_string())
+        )?
+        .bind("quote_id", quote_id.as_hyphenated().to_string())
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -1089,8 +1089,8 @@ where
             WHERE
                 keyset_id=?
             "#,
-        )
-        .bind(":keyset_id", keyset_id.to_string())
+        )?
+        .bind("keyset_id", keyset_id.to_string())
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -1124,21 +1124,21 @@ where
                     VALUES
                     (:blinded_message, :amount, :keyset_id, :c, :quote_id, :dleq_e, :dleq_s, :created_time)
                 "#,
-            )
-            .bind(":blinded_message", message.to_bytes().to_vec())
-            .bind(":amount", u64::from(signature.amount) as i64)
-            .bind(":keyset_id", signature.keyset_id.to_string())
-            .bind(":c", signature.c.to_bytes().to_vec())
-            .bind(":quote_id", quote_id.map(|q| q.hyphenated().to_string()))
+            )?
+            .bind("blinded_message", message.to_bytes().to_vec())
+            .bind("amount", u64::from(signature.amount) as i64)
+            .bind("keyset_id", signature.keyset_id.to_string())
+            .bind("c", signature.c.to_bytes().to_vec())
+            .bind("quote_id", quote_id.map(|q| q.hyphenated().to_string()))
             .bind(
-                ":dleq_e",
+                "dleq_e",
                 signature.dleq.as_ref().map(|dleq| dleq.e.to_secret_hex()),
             )
             .bind(
-                ":dleq_s",
+                "dleq_s",
                 signature.dleq.as_ref().map(|dleq| dleq.s.to_secret_hex()),
             )
-            .bind(":created_time", current_time as i64)
+            .bind("created_time", current_time as i64)
             .execute(&self.inner)
             .await?;
         }
@@ -1162,9 +1162,9 @@ where
                 blind_signature
             WHERE blinded_message IN (:y)
             "#,
-        )
+        )?
         .bind_vec(
-            ":y",
+            "y",
             blinded_messages
                 .iter()
                 .map(|y| y.to_bytes().to_vec())
@@ -1214,9 +1214,9 @@ where
                 blind_signature
             WHERE blinded_message IN (:blinded_message)
             "#,
-        )
+        )?
         .bind_vec(
-            ":blinded_message",
+            "blinded_message",
             blinded_messages
                 .iter()
                 .map(|b_| b_.to_bytes().to_vec())
@@ -1259,8 +1259,8 @@ where
             WHERE
                 keyset_id=:keyset_id
             "#,
-        )
-        .bind(":keyset_id", keyset_id.to_string())
+        )?
+        .bind("keyset_id", keyset_id.to_string())
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -1286,8 +1286,8 @@ where
             WHERE
                 quote_id=:quote_id
             "#,
-        )
-        .bind(":quote_id", quote_id.to_string())
+        )?
+        .bind("quote_id", quote_id.to_string())
         .fetch_all(&self.db)
         .await?
         .into_iter()

+ 182 - 50
crates/cdk-sql-base/src/stmt.rs

@@ -1,5 +1,5 @@
 //! Stataments mod
-use std::collections::HashMap;
+use std::sync::Arc;
 
 use cdk_common::database::Error;
 
@@ -25,28 +25,173 @@ pub enum ExpectedSqlResponse {
     Batch,
 }
 
+/// Part value
+#[derive(Debug, Clone)]
+pub enum PlaceholderValue {
+    /// Value
+    Value(Value),
+    /// Set
+    Set(Vec<Value>),
+}
+
+impl From<Value> for PlaceholderValue {
+    fn from(value: Value) -> Self {
+        PlaceholderValue::Value(value)
+    }
+}
+
+impl From<Vec<Value>> for PlaceholderValue {
+    fn from(value: Vec<Value>) -> Self {
+        PlaceholderValue::Set(value)
+    }
+}
+
+/// SQL Part
+#[derive(Debug, Clone)]
+pub enum SqlPart {
+    /// Raw SQL statement
+    Raw(Arc<str>),
+    /// Placeholder
+    Placeholder(Arc<str>, Option<PlaceholderValue>),
+}
+
+/// SQL parser error
+#[derive(Debug, PartialEq, thiserror::Error)]
+pub enum SqlParseError {
+    /// Invalid SQL
+    #[error("Unterminated String literal")]
+    UnterminatedStringLiteral,
+    /// Invalid placeholder name
+    #[error("Invalid placeholder name")]
+    InvalidPlaceholder,
+}
+
+/// Rudimentary SQL parser.
+///
+/// This function does not validate the SQL statement, it only extracts the placeholder to be
+/// database agnostic.
+pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
+    let mut parts = Vec::new();
+    let mut current = String::new();
+    let mut chars = input.chars().peekable();
+
+    while let Some(&c) = chars.peek() {
+        match c {
+            '\'' | '"' => {
+                // Start of string literal
+                let quote = c;
+                current.push(chars.next().unwrap());
+
+                let mut closed = false;
+                while let Some(&next) = chars.peek() {
+                    current.push(chars.next().unwrap());
+
+                    if next == quote {
+                        if chars.peek() == Some(&quote) {
+                            // Escaped quote (e.g. '' inside strings)
+                            current.push(chars.next().unwrap());
+                        } else {
+                            closed = true;
+                            break;
+                        }
+                    }
+                }
+
+                if !closed {
+                    return Err(SqlParseError::UnterminatedStringLiteral);
+                }
+            }
+
+            ':' => {
+                // Flush current raw SQL
+                if !current.is_empty() {
+                    parts.push(SqlPart::Raw(current.clone().into()));
+                    current.clear();
+                }
+
+                chars.next(); // consume ':'
+                let mut name = String::new();
+
+                while let Some(&next) = chars.peek() {
+                    if next.is_alphanumeric() || next == '_' {
+                        name.push(chars.next().unwrap());
+                    } else {
+                        break;
+                    }
+                }
+
+                if name.is_empty() {
+                    return Err(SqlParseError::InvalidPlaceholder);
+                }
+
+                parts.push(SqlPart::Placeholder(name.into(), None));
+            }
+
+            _ => {
+                current.push(chars.next().unwrap());
+            }
+        }
+    }
+
+    if !current.is_empty() {
+        parts.push(SqlPart::Raw(current.into()));
+    }
+
+    Ok(parts)
+}
+
 /// Sql message
 #[derive(Debug, Default)]
 pub struct Statement {
     /// The SQL statement
-    pub sql: String,
-    /// The list of arguments for the placeholders. It only supports named arguments for simplicity
-    /// sake
-    pub args: HashMap<String, Value>,
+    pub parts: Vec<SqlPart>,
     /// The expected response type
     pub expected_response: ExpectedSqlResponse,
 }
 
 impl Statement {
     /// Creates a new statement
-    pub fn new<T>(sql: T) -> Self
-    where
-        T: ToString,
-    {
-        Self {
-            sql: sql.to_string(),
+    pub fn new(sql: &str) -> Result<Self, SqlParseError> {
+        Ok(Self {
+            parts: split_sql_parts(sql)?,
             ..Default::default()
-        }
+        })
+    }
+
+    /// Convert Statement into a SQL statement and the list of placeholders
+    ///
+    /// By default it converts the statement into placeholder using $1..$n placeholders which seems
+    /// to be more widely supported, although it can be reimplemented with other formats since part
+    /// is public
+    pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
+        let mut placeholder_values = Vec::new();
+        let sql = self
+            .parts
+            .into_iter()
+            .map(|x| match x {
+                SqlPart::Placeholder(name, value) => {
+                    match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
+                        PlaceholderValue::Value(value) => {
+                            placeholder_values.push(value);
+                            Ok::<_, Error>(format!("${}", placeholder_values.len()))
+                        }
+                        PlaceholderValue::Set(mut values) => {
+                            let start_size = placeholder_values.len();
+                            placeholder_values.append(&mut values);
+                            let placeholders = (start_size + 1..=placeholder_values.len())
+                                .map(|i| format!("${i}"))
+                                .collect::<Vec<_>>()
+                                .join(", ");
+                            Ok(placeholders)
+                        }
+                    }
+                }
+                SqlPart::Raw(raw) => Ok(raw.to_string()),
+            })
+            .collect::<Result<Vec<String>, _>>()?
+            .join("");
+
+        Ok((sql, placeholder_values))
     }
 
     /// Binds a given placeholder to a value.
@@ -56,7 +201,18 @@ impl Statement {
         C: ToString,
         V: Into<Value>,
     {
-        self.args.insert(name.to_string(), value.into());
+        let name = name.to_string();
+        let value = value.into();
+        let value: PlaceholderValue = value.into();
+
+        for part in self.parts.iter_mut() {
+            if let SqlPart::Placeholder(part_name, part_value) = part {
+                if **part_name == *name.as_str() {
+                    *part_value = Some(value.clone());
+                }
+            }
+        }
+
         self
     }
 
@@ -70,42 +226,21 @@ impl Statement {
         C: ToString,
         V: Into<Value>,
     {
-        let mut new_sql = String::with_capacity(self.sql.len());
-        let target = name.to_string();
-        let mut i = 0;
-
-        let placeholders = value
+        let name = name.to_string();
+        let value: PlaceholderValue = value
             .into_iter()
-            .enumerate()
-            .map(|(key, value)| {
-                let key = format!("{target}{key}");
-                self.args.insert(key.clone(), value.into());
-                key
-            })
-            .collect::<Vec<_>>()
-            .join(",");
-
-        while let Some(pos) = self.sql[i..].find(&target) {
-            let abs_pos = i + pos;
-            let after = abs_pos + target.len();
-            let is_word_boundary = self.sql[after..]
-                .chars()
-                .next()
-                .map_or(true, |c| !c.is_alphanumeric() && c != '_');
-
-            if is_word_boundary {
-                new_sql.push_str(&self.sql[i..abs_pos]);
-                new_sql.push_str(&placeholders);
-                i = after;
-            } else {
-                new_sql.push_str(&self.sql[i..=abs_pos]);
-                i = abs_pos + 1;
+            .map(|x| x.into())
+            .collect::<Vec<Value>>()
+            .into();
+
+        for part in self.parts.iter_mut() {
+            if let SqlPart::Placeholder(part_name, part_value) = part {
+                if **part_name == *name.as_str() {
+                    *part_value = Some(value.clone());
+                }
             }
         }
 
-        new_sql.push_str(&self.sql[i..]);
-
-        self.sql = new_sql;
         self
     }
 
@@ -152,9 +287,6 @@ impl Statement {
 
 /// Creates a new query statement
 #[inline(always)]
-pub fn query<T>(sql: T) -> Statement
-where
-    T: ToString,
-{
-    Statement::new(sql)
+pub fn query(sql: &str) -> Result<Statement, Error> {
+    Statement::new(sql).map_err(|e| Error::Database(Box::new(e)))
 }

+ 107 - 224
crates/cdk-sql-base/src/wallet/mod.rs

@@ -147,20 +147,20 @@ ON CONFLICT(mint_url) DO UPDATE SET
     tos_url = excluded.tos_url
 ;
         "#,
-        )
-        .bind(":mint_url", mint_url.to_string())
-        .bind(":name", name)
-        .bind(":pubkey", pubkey)
-        .bind(":version", version)
-        .bind(":description", description)
-        .bind(":description_long", description_long)
-        .bind(":contact", contact)
-        .bind(":nuts", nuts)
-        .bind(":icon_url", icon_url)
-        .bind(":urls", urls)
-        .bind(":motd", motd)
-        .bind(":mint_time", time.map(|v| v as i64))
-        .bind(":tos_url", tos_url)
+        )?
+        .bind("mint_url", mint_url.to_string())
+        .bind("name", name)
+        .bind("pubkey", pubkey)
+        .bind("version", version)
+        .bind("description", description)
+        .bind("description_long", description_long)
+        .bind("contact", contact)
+        .bind("nuts", nuts)
+        .bind("icon_url", icon_url)
+        .bind("urls", urls)
+        .bind("motd", motd)
+        .bind("mint_time", time.map(|v| v as i64))
+        .bind("tos_url", tos_url)
         .execute(&self.db)
         .await?;
 
@@ -169,8 +169,8 @@ ON CONFLICT(mint_url) DO UPDATE SET
 
     #[instrument(skip(self))]
     async fn remove_mint(&self, mint_url: MintUrl) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM mint WHERE mint_url=:mint_url"#)
-            .bind(":mint_url", mint_url.to_string())
+        query(r#"DELETE FROM mint WHERE mint_url=:mint_url"#)?
+            .bind("mint_url", mint_url.to_string())
             .execute(&self.db)
             .await?;
 
@@ -198,8 +198,8 @@ ON CONFLICT(mint_url) DO UPDATE SET
                 mint
             WHERE mint_url = :mint_url
             "#,
-        )
-        .bind(":mint_url", mint_url.to_string())
+        )?
+        .bind("mint_url", mint_url.to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_mint_info)
@@ -227,7 +227,7 @@ ON CONFLICT(mint_url) DO UPDATE SET
                 FROM
                     mint
                 "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -259,9 +259,9 @@ ON CONFLICT(mint_url) DO UPDATE SET
             "#
             );
 
-            query(&str_query)
-                .bind(":new_mint_url", new_mint_url.to_string())
-                .bind(":old_mint_url", old_mint_url.to_string())
+            query(&str_query)?
+                .bind("new_mint_url", new_mint_url.to_string())
+                .bind("old_mint_url", old_mint_url.to_string())
                 .execute(&self.db)
                 .await?;
         }
@@ -289,13 +289,13 @@ ON CONFLICT(mint_url) DO UPDATE SET
         input_fee_ppk = excluded.input_fee_ppk,
         final_expiry = excluded.final_expiry;
     "#,
-            )
-            .bind(":mint_url", mint_url.to_string())
-            .bind(":id", keyset.id.to_string())
-            .bind(":unit", keyset.unit.to_string())
-            .bind(":active", keyset.active)
-            .bind(":input_fee_ppk", keyset.input_fee_ppk as i64)
-            .bind(":final_expiry", keyset.final_expiry.map(|v| v as i64))
+            )?
+            .bind("mint_url", mint_url.to_string())
+            .bind("id", keyset.id.to_string())
+            .bind("unit", keyset.unit.to_string())
+            .bind("active", keyset.active)
+            .bind("input_fee_ppk", keyset.input_fee_ppk as i64)
+            .bind("final_expiry", keyset.final_expiry.map(|v| v as i64))
             .execute(&self.db)
             .await?;
         }
@@ -320,8 +320,8 @@ ON CONFLICT(mint_url) DO UPDATE SET
                 keyset
             WHERE mint_url = :mint_url
             "#,
-        )
-        .bind(":mint_url", mint_url.to_string())
+        )?
+        .bind("mint_url", mint_url.to_string())
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -348,8 +348,8 @@ ON CONFLICT(mint_url) DO UPDATE SET
                 keyset
             WHERE id = :id
             "#,
-        )
-        .bind(":id", keyset_id.to_string())
+        )?
+        .bind("id", keyset_id.to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_keyset)
@@ -374,15 +374,15 @@ ON CONFLICT(id) DO UPDATE SET
     secret_key = excluded.secret_key
 ;
         "#,
-        )
-        .bind(":id", quote.id.to_string())
-        .bind(":mint_url", quote.mint_url.to_string())
-        .bind(":amount", u64::from(quote.amount) as i64)
-        .bind(":unit", quote.unit.to_string())
-        .bind(":request", quote.request)
-        .bind(":state", quote.state.to_string())
-        .bind(":expiry", quote.expiry as i64)
-        .bind(":secret_key", quote.secret_key.map(|p| p.to_string()))
+        )?
+        .bind("id", quote.id.to_string())
+        .bind("mint_url", quote.mint_url.to_string())
+        .bind("amount", u64::from(quote.amount) as i64)
+        .bind("unit", quote.unit.to_string())
+        .bind("request", quote.request)
+        .bind("state", quote.state.to_string())
+        .bind("expiry", quote.expiry as i64)
+        .bind("secret_key", quote.secret_key.map(|p| p.to_string()))
         .execute(&self.db)
         .await?;
 
@@ -407,8 +407,8 @@ ON CONFLICT(id) DO UPDATE SET
             WHERE
                 id = :id
             "#,
-        )
-        .bind(":id", quote_id.to_string())
+        )?
+        .bind("id", quote_id.to_string())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_mint_quote)
@@ -431,7 +431,7 @@ ON CONFLICT(id) DO UPDATE SET
             FROM
                 mint_quote
             "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -441,8 +441,8 @@ ON CONFLICT(id) DO UPDATE SET
 
     #[instrument(skip(self))]
     async fn remove_mint_quote(&self, quote_id: &str) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM mint_quote WHERE id=:id"#)
-            .bind(":id", quote_id.to_string())
+        query(r#"DELETE FROM mint_quote WHERE id=:id"#)?
+            .bind("id", quote_id.to_string())
             .execute(&self.db)
             .await?;
 
@@ -466,14 +466,14 @@ ON CONFLICT(id) DO UPDATE SET
     expiry = excluded.expiry
 ;
         "#,
-        )
-        .bind(":id", quote.id.to_string())
-        .bind(":unit", quote.unit.to_string())
-        .bind(":amount", u64::from(quote.amount) as i64)
-        .bind(":request", quote.request)
-        .bind(":fee_reserve", u64::from(quote.fee_reserve) as i64)
-        .bind(":state", quote.state.to_string())
-        .bind(":expiry", quote.expiry as i64)
+        )?
+        .bind("id", quote.id.to_string())
+        .bind("unit", quote.unit.to_string())
+        .bind("amount", u64::from(quote.amount) as i64)
+        .bind("request", quote.request)
+        .bind("fee_reserve", u64::from(quote.fee_reserve) as i64)
+        .bind("state", quote.state.to_string())
+        .bind("expiry", quote.expiry as i64)
         .execute(&self.db)
         .await?;
 
@@ -498,8 +498,8 @@ ON CONFLICT(id) DO UPDATE SET
             WHERE
                 id=:id
             "#,
-        )
-        .bind(":id", quote_id.to_owned())
+        )?
+        .bind("id", quote_id.to_owned())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_melt_quote)
@@ -508,8 +508,8 @@ ON CONFLICT(id) DO UPDATE SET
 
     #[instrument(skip(self))]
     async fn remove_melt_quote(&self, quote_id: &str) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM melt_quote WHERE id=:id"#)
-            .bind(":id", quote_id.to_owned())
+        query(r#"DELETE FROM melt_quote WHERE id=:id"#)?
+            .bind("id", quote_id.to_owned())
             .execute(&self.db)
             .await?;
 
@@ -530,10 +530,10 @@ ON CONFLICT(id) DO UPDATE SET
             ON CONFLICT(id) DO UPDATE SET
                 keys = excluded.keys
         "#,
-        )
-        .bind(":id", keyset.id.to_string())
+        )?
+        .bind("id", keyset.id.to_string())
         .bind(
-            ":keys",
+            "keys",
             serde_json::to_string(&keyset.keys).map_err(Error::from)?,
         )
         .execute(&self.db)
@@ -551,8 +551,8 @@ ON CONFLICT(id) DO UPDATE SET
             FROM key
             WHERE id = :id
             "#,
-        )
-        .bind(":id", keyset_id.to_string())
+        )?
+        .bind("id", keyset_id.to_string())
         .pluck(&self.db)
         .await?
         .map(|keys| {
@@ -564,8 +564,8 @@ ON CONFLICT(id) DO UPDATE SET
 
     #[instrument(skip(self))]
     async fn remove_keys(&self, id: &Id) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM key WHERE id = :id"#)
-            .bind(":id", id.to_string())
+        query(r#"DELETE FROM key WHERE id = :id"#)?
+            .bind("id", id.to_string())
             .pluck(&self.db)
             .await?;
 
@@ -600,46 +600,46 @@ ON CONFLICT(id) DO UPDATE SET
         dleq_r = excluded.dleq_r
     ;
             "#,
-            )
-            .bind(":y", proof.y.to_bytes().to_vec())
-            .bind(":mint_url", proof.mint_url.to_string())
-            .bind(":state",proof.state.to_string())
+            )?
+            .bind("y", proof.y.to_bytes().to_vec())
+            .bind("mint_url", proof.mint_url.to_string())
+            .bind("state",proof.state.to_string())
             .bind(
-                ":spending_condition",
+                "spending_condition",
                 proof
                     .spending_condition
                     .map(|s| serde_json::to_string(&s).ok()),
             )
-            .bind(":unit", proof.unit.to_string())
-            .bind(":amount", u64::from(proof.proof.amount) as i64)
-            .bind(":keyset_id", proof.proof.keyset_id.to_string())
-            .bind(":secret", proof.proof.secret.to_string())
-            .bind(":c", proof.proof.c.to_bytes().to_vec())
+            .bind("unit", proof.unit.to_string())
+            .bind("amount", u64::from(proof.proof.amount) as i64)
+            .bind("keyset_id", proof.proof.keyset_id.to_string())
+            .bind("secret", proof.proof.secret.to_string())
+            .bind("c", proof.proof.c.to_bytes().to_vec())
             .bind(
-                ":witness",
+                "witness",
                 proof
                     .proof
                     .witness
                     .map(|w| serde_json::to_string(&w).unwrap()),
             )
             .bind(
-                ":dleq_e",
+                "dleq_e",
                 proof.proof.dleq.as_ref().map(|dleq| dleq.e.to_secret_bytes().to_vec()),
             )
             .bind(
-                ":dleq_s",
+                "dleq_s",
                 proof.proof.dleq.as_ref().map(|dleq| dleq.s.to_secret_bytes().to_vec()),
             )
             .bind(
-                ":dleq_r",
+                "dleq_r",
                 proof.proof.dleq.as_ref().map(|dleq| dleq.r.to_secret_bytes().to_vec()),
             )
             .execute(&self.db).await?;
         }
 
-        query(r#"DELETE FROM proof WHERE y IN (:ys)"#)
+        query(r#"DELETE FROM proof WHERE y IN (:ys)"#)?
             .bind_vec(
-                ":ys",
+                "ys",
                 removed_ys.iter().map(|y| y.to_bytes().to_vec()).collect(),
             )
             .execute(&self.db)
@@ -674,7 +674,7 @@ ON CONFLICT(id) DO UPDATE SET
                 spending_condition
             FROM proof
         "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -691,9 +691,9 @@ ON CONFLICT(id) DO UPDATE SET
     }
 
     async fn update_proofs_state(&self, ys: Vec<PublicKey>, state: State) -> Result<(), Self::Err> {
-        query("UPDATE proof SET state = :state WHERE y IN (:ys)")
-            .bind_vec(":ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
-            .bind(":state", state.to_string())
+        query("UPDATE proof SET state = :state WHERE y IN (:ys)")?
+            .bind_vec("ys", ys.iter().map(|y| y.to_bytes().to_vec()).collect())
+            .bind("state", state.to_string())
             .execute(&self.db)
             .await?;
 
@@ -708,9 +708,9 @@ ON CONFLICT(id) DO UPDATE SET
             SET counter=counter+:count
             WHERE id=:id
             "#,
-        )
-        .bind(":count", count)
-        .bind(":id", keyset_id.to_string())
+        )?
+        .bind("count", count)
+        .bind("id", keyset_id.to_string())
         .execute(&self.db)
         .await?;
 
@@ -728,8 +728,8 @@ ON CONFLICT(id) DO UPDATE SET
             WHERE
                 id=:id
             "#,
-        )
-        .bind(":id", keyset_id.to_string())
+        )?
+        .bind("id", keyset_id.to_string())
         .pluck(&self.db)
         .await?
         .map(|n| Ok::<_, Error>(column_as_number!(n)))
@@ -767,18 +767,18 @@ ON CONFLICT(id) DO UPDATE SET
     metadata = excluded.metadata
 ;
         "#,
-        )
-        .bind(":id", transaction.id().as_slice().to_vec())
-        .bind(":mint_url", mint_url)
-        .bind(":direction", direction)
-        .bind(":unit", unit)
-        .bind(":amount", amount)
-        .bind(":fee", fee)
-        .bind(":ys", ys)
-        .bind(":timestamp", transaction.timestamp as i64)
-        .bind(":memo", transaction.memo)
+        )?
+        .bind("id", transaction.id().as_slice().to_vec())
+        .bind("mint_url", mint_url)
+        .bind("direction", direction)
+        .bind("unit", unit)
+        .bind("amount", amount)
+        .bind("fee", fee)
+        .bind("ys", ys)
+        .bind("timestamp", transaction.timestamp as i64)
+        .bind("memo", transaction.memo)
         .bind(
-            ":metadata",
+            "metadata",
             serde_json::to_string(&transaction.metadata).map_err(Error::from)?,
         )
         .execute(&self.db)
@@ -809,8 +809,8 @@ ON CONFLICT(id) DO UPDATE SET
             WHERE
                 id = :id
             "#,
-        )
-        .bind(":id", transaction_id.as_slice().to_vec())
+        )?
+        .bind("id", transaction_id.as_slice().to_vec())
         .fetch_one(&self.db)
         .await?
         .map(sql_row_to_transaction)
@@ -839,7 +839,7 @@ ON CONFLICT(id) DO UPDATE SET
             FROM
                 transactions
             "#,
-        )
+        )?
         .fetch_all(&self.db)
         .await?
         .into_iter()
@@ -857,8 +857,8 @@ ON CONFLICT(id) DO UPDATE SET
 
     #[instrument(skip(self))]
     async fn remove_transaction(&self, transaction_id: TransactionId) -> Result<(), Self::Err> {
-        query(r#"DELETE FROM transactions WHERE id=:id"#)
-            .bind(":id", transaction_id.as_slice().to_vec())
+        query(r#"DELETE FROM transactions WHERE id=:id"#)?
+            .bind("id", transaction_id.as_slice().to_vec())
             .execute(&self.db)
             .await?;
 
@@ -1079,120 +1079,3 @@ fn sql_row_to_transaction(row: Vec<Column>) -> Result<Transaction, Error> {
         .unwrap_or_default(),
     })
 }
-
-#[cfg(test)]
-mod tests {
-    use cdk_common::database::WalletDatabase;
-    use cdk_common::nuts::{ProofDleq, State};
-    use cdk_common::secret::Secret;
-
-    use crate::SQLWalletDatabase;
-
-    #[tokio::test]
-    #[cfg(feature = "sqlcipher")]
-    async fn test_sqlcipher() {
-        use cdk_common::mint_url::MintUrl;
-        use cdk_common::MintInfo;
-
-        use super::*;
-        let path = std::env::temp_dir()
-            .to_path_buf()
-            .join(format!("cdk-test-{}.sqlite", uuid::Uuid::new_v4()));
-        let db = SQLWalletDatabase::new(path, "password".to_string())
-            .await
-            .unwrap();
-
-        let mint_info = MintInfo::new().description("test");
-        let mint_url = MintUrl::from_str("https://mint.xyz").unwrap();
-
-        db.add_mint(mint_url.clone(), Some(mint_info.clone()))
-            .await
-            .unwrap();
-
-        let res = db.get_mint(mint_url).await.unwrap();
-        assert_eq!(mint_info, res.clone().unwrap());
-        assert_eq!("test", &res.unwrap().description.unwrap());
-    }
-
-    #[tokio::test]
-    async fn test_proof_with_dleq() {
-        use std::str::FromStr;
-
-        use cdk_common::common::ProofInfo;
-        use cdk_common::mint_url::MintUrl;
-        use cdk_common::nuts::{CurrencyUnit, Id, Proof, PublicKey, SecretKey};
-        use cdk_common::Amount;
-
-        // Create a temporary database
-        let path = std::env::temp_dir()
-            .to_path_buf()
-            .join(format!("cdk-test-dleq-{}.sqlite", uuid::Uuid::new_v4()));
-
-        #[cfg(feature = "sqlcipher")]
-        let db = SQLWalletDatabase::new(path, "password".to_string())
-            .await
-            .unwrap();
-
-        #[cfg(not(feature = "sqlcipher"))]
-        let db = SQLWalletDatabase::new(path).await.unwrap();
-
-        // Create a proof with DLEQ
-        let keyset_id = Id::from_str("00deadbeef123456").unwrap();
-        let mint_url = MintUrl::from_str("https://example.com").unwrap();
-        let secret = Secret::new("test_secret_for_dleq");
-
-        // Create DLEQ components
-        let e = SecretKey::generate();
-        let s = SecretKey::generate();
-        let r = SecretKey::generate();
-
-        let dleq = ProofDleq::new(e.clone(), s.clone(), r.clone());
-
-        let mut proof = Proof::new(
-            Amount::from(64),
-            keyset_id,
-            secret,
-            PublicKey::from_hex(
-                "02deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef",
-            )
-            .unwrap(),
-        );
-
-        // Add DLEQ to the proof
-        proof.dleq = Some(dleq);
-
-        // Create ProofInfo
-        let proof_info =
-            ProofInfo::new(proof, mint_url.clone(), State::Unspent, CurrencyUnit::Sat).unwrap();
-
-        // Store the proof in the database
-        db.update_proofs(vec![proof_info.clone()], vec![])
-            .await
-            .unwrap();
-
-        // Retrieve the proof from the database
-        let retrieved_proofs = db
-            .get_proofs(
-                Some(mint_url),
-                Some(CurrencyUnit::Sat),
-                Some(vec![State::Unspent]),
-                None,
-            )
-            .await
-            .unwrap();
-
-        // Verify we got back exactly one proof
-        assert_eq!(retrieved_proofs.len(), 1);
-
-        // Verify the DLEQ data was preserved
-        let retrieved_proof = &retrieved_proofs[0];
-        assert!(retrieved_proof.proof.dleq.is_some());
-
-        let retrieved_dleq = retrieved_proof.proof.dleq.as_ref().unwrap();
-
-        // Verify DLEQ components match what we stored
-        assert_eq!(retrieved_dleq.e.to_string(), e.to_string());
-        assert_eq!(retrieved_dleq.s.to_string(), s.to_string());
-        assert_eq!(retrieved_dleq.r.to_string(), r.to_string());
-    }
-}

+ 18 - 34
crates/cdk-sqlite/src/mint/async_rusqlite.rs

@@ -97,8 +97,8 @@ enum SqliteError {
     #[error(transparent)]
     Sqlite(#[from] rusqlite::Error),
 
-    #[error("Invalid usage")]
-    InvalidUsage,
+    #[error(transparent)]
+    Inner(#[from] Error),
 
     #[error(transparent)]
     Pool(#[from] pool::Error<rusqlite::Error>),
@@ -123,37 +123,22 @@ impl From<SqliteError> for Error {
 
 /// Process a query
 #[inline(always)]
-fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, SqliteError> {
+fn process_query(conn: &Connection, statement: InnerStatement) -> Result<DbResponse, SqliteError> {
     let start = Instant::now();
-    let mut args = sql.args;
-    let mut stmt = conn.prepare_cached(&sql.sql)?;
-    let total_parameters = stmt.parameter_count();
-    let total_args = args.len();
-
-    for index in 1..=total_parameters {
-        let value = if let Some(value) = stmt.parameter_name(index).map(|name| {
-            args.remove(name)
-                .ok_or(ConversionError::MissingParameter(name.to_owned()))
-        }) {
-            value?
-        } else {
-            continue;
-        };
-
-        stmt.raw_bind_parameter(index, to_sqlite(value))?;
+    let expected_response = statement.expected_response;
+    let (sql, placeholder_values) = statement.to_sql()?;
+
+    let mut stmt = conn.prepare_cached(&sql)?;
+    for (i, value) in placeholder_values.into_iter().enumerate() {
+        stmt.raw_bind_parameter(i + 1, to_sqlite(value))?;
     }
 
     let columns = stmt.column_count();
 
-    let to_return = match sql.expected_response {
+    let to_return = match expected_response {
         ExpectedSqlResponse::AffectedRows => DbResponse::AffectedRows(stmt.raw_execute()?),
         ExpectedSqlResponse::Batch => {
-            if total_args > 0 {
-                return Err(SqliteError::InvalidUsage);
-            }
-
-            conn.execute_batch(&sql.sql)?;
-
+            conn.execute_batch(&sql)?;
             DbResponse::Ok
         }
         ExpectedSqlResponse::ManyRows => {
@@ -195,7 +180,7 @@ fn process_query(conn: &Connection, sql: InnerStatement) -> Result<DbResponse, S
     let duration = start.elapsed();
 
     if duration.as_millis() > SLOW_QUERY_THRESHOLD_MS {
-        tracing::warn!("[SLOW QUERY] Took {} ms: {}", duration.as_millis(), sql.sql);
+        tracing::warn!("[SLOW QUERY] Took {} ms: {}", duration.as_millis(), sql);
     }
 
     Ok(to_return)
@@ -230,7 +215,6 @@ fn rusqlite_spawn_worker_threads(
         let inflight_requests = inflight_requests.clone();
         spawn(move || loop {
             while let Ok((conn, sql, reply_to)) = rx.lock().expect("failed to acquire").recv() {
-                tracing::trace!("Execute query: {}", sql.sql);
                 let result = process_query(&conn, sql);
                 let _ = match result {
                     Ok(ok) => reply_to.send(ok),
@@ -290,7 +274,7 @@ fn rusqlite_worker_manager(
     while let Some(request) = receiver.blocking_recv() {
         inflight_requests.fetch_add(1, Ordering::Relaxed);
         match request {
-            DbRequest::Sql(sql, reply_to) => {
+            DbRequest::Sql(statement, reply_to) => {
                 let conn = match pool.get() {
                     Ok(conn) => conn,
                     Err(err) => {
@@ -301,7 +285,7 @@ fn rusqlite_worker_manager(
                     }
                 };
 
-                let _ = send_sql_to_thread.send((conn, sql, reply_to));
+                let _ = send_sql_to_thread.send((conn, statement, reply_to));
                 continue;
             }
             DbRequest::Begin(reply_to) => {
@@ -375,9 +359,9 @@ fn rusqlite_worker_manager(
                         DbRequest::Begin(reply_to) => {
                             let _ = reply_to.send(DbResponse::Unexpected);
                         }
-                        DbRequest::Sql(sql, reply_to) => {
-                            tracing::trace!("Tx {}: SQL {}", tx_id, sql.sql);
-                            let _ = match process_query(&tx, sql) {
+                        DbRequest::Sql(statement, reply_to) => {
+                            tracing::trace!("Tx {}: SQL {:?}", tx_id, statement);
+                            let _ = match process_query(&tx, statement) {
                                 Ok(ok) => reply_to.send(ok),
                                 Err(err) => {
                                     tracing::error!(
@@ -584,7 +568,7 @@ pub struct Transaction<'conn> {
     _marker: PhantomData<&'conn ()>,
 }
 
-impl<'conn> Transaction<'conn> {
+impl Transaction<'_> {
     fn get_queue_sender(&self) -> &mpsc::Sender<DbRequest> {
         &self.db_sender
     }

+ 33 - 15
crates/cdk-sqlite/src/wallet/mod.rs

@@ -6,7 +6,7 @@ use std::sync::Arc;
 use cdk_common::database::Error;
 use cdk_sql_base::database::DatabaseExecutor;
 use cdk_sql_base::pool::{Pool, PooledResource};
-use cdk_sql_base::stmt::{Column, Statement};
+use cdk_sql_base::stmt::{Column, SqlPart, Statement};
 use cdk_sql_base::SQLWalletDatabase;
 use rusqlite::CachedStatement;
 
@@ -24,15 +24,15 @@ impl SimpleAsyncRusqlite {
         &self,
         conn: &'a PooledResource<SqliteConnectionManager>,
         statement: Statement,
-    ) -> rusqlite::Result<CachedStatement<'a>> {
-        let mut stmt = conn.prepare_cached(&statement.sql)?;
-        for (name, value) in statement.args {
-            let index = stmt
-                .parameter_index(&name)
-                .map_err(|_| rusqlite::Error::InvalidColumnName(name.clone()))?
-                .ok_or(rusqlite::Error::InvalidColumnName(name))?;
-
-            stmt.raw_bind_parameter(index, to_sqlite(value))?;
+    ) -> Result<CachedStatement<'a>, Error> {
+        let (sql, placeholder_values) = statement.to_sql()?;
+        let mut stmt = conn
+            .prepare_cached(&sql)
+            .map_err(|e| Error::Database(Box::new(e)))?;
+
+        for (i, value) in placeholder_values.into_iter().enumerate() {
+            stmt.raw_bind_parameter(i + 1, to_sqlite(value))
+                .map_err(|e| Error::Database(Box::new(e)))?;
         }
 
         Ok(stmt)
@@ -109,10 +109,28 @@ impl DatabaseExecutor for SimpleAsyncRusqlite {
             .map_err(|e| Error::Database(Box::new(e)))
     }
 
-    async fn batch(&self, statement: Statement) -> Result<(), Error> {
+    async fn batch(&self, mut statement: Statement) -> Result<(), Error> {
         let conn = self.0.get().map_err(|e| Error::Database(Box::new(e)))?;
 
-        conn.execute_batch(&statement.sql)
+        let sql = {
+            let part = statement
+                .parts
+                .pop()
+                .ok_or(Error::Internal("Empty SQL".to_owned()))?;
+
+            if !statement.parts.is_empty() || matches!(part, SqlPart::Placeholder(_, _)) {
+                return Err(Error::Internal(
+                    "Invalid usage, batch does not support placeholders".to_owned(),
+                ));
+            }
+            if let SqlPart::Raw(sql) = part {
+                sql
+            } else {
+                unreachable!()
+            }
+        };
+
+        conn.execute_batch(&sql)
             .map_err(|e| Error::Database(Box::new(e)))
     }
 }
@@ -164,6 +182,8 @@ pub type WalletSqliteDatabase = SQLWalletDatabase<SimpleAsyncRusqlite>;
 
 #[cfg(test)]
 mod tests {
+    use std::str::FromStr;
+
     use cdk_common::database::WalletDatabase;
     use cdk_common::nuts::{ProofDleq, State};
     use cdk_common::secret::Secret;
@@ -180,7 +200,7 @@ mod tests {
         let path = std::env::temp_dir()
             .to_path_buf()
             .join(format!("cdk-test-{}.sqlite", uuid::Uuid::new_v4()));
-        let db = WalletSqliteDatabase::new(path, "password".to_string())
+        let db = WalletSqliteDatabase::new((path, "password".to_string()))
             .await
             .unwrap();
 
@@ -198,8 +218,6 @@ mod tests {
 
     #[tokio::test]
     async fn test_proof_with_dleq() {
-        use std::str::FromStr;
-
         use cdk_common::common::ProofInfo;
         use cdk_common::mint_url::MintUrl;
         use cdk_common::nuts::{CurrencyUnit, Id, Proof, PublicKey, SecretKey};