Sfoglia il codice sorgente

Add Cache to SQL stmt

The cache will store the placeholders and if possible the RAW SQL with position
placeholders, to avoid repetitive computations
Cesar Rodas 2 mesi fa
parent
commit
034af74013
2 ha cambiato i file con 73 aggiunte e 7 eliminazioni
  1. 1 0
      crates/cdk-sql-common/Cargo.toml
  2. 72 7
      crates/cdk-sql-common/src/stmt.rs

+ 1 - 0
crates/cdk-sql-common/Cargo.toml

@@ -28,3 +28,4 @@ serde.workspace = true
 serde_json.workspace = true
 lightning-invoice.workspace = true
 uuid.workspace = true
+once_cell.workspace = true

+ 72 - 7
crates/cdk-sql-common/src/stmt.rs

@@ -1,7 +1,9 @@
 //! Stataments mod
-use std::sync::Arc;
+use std::collections::HashMap;
+use std::sync::{Arc, RwLock};
 
 use cdk_common::database::Error;
+use once_cell::sync::Lazy;
 
 use crate::database::DatabaseExecutor;
 use crate::value::Value;
@@ -140,9 +142,14 @@ pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
     Ok(parts)
 }
 
+type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
+
 /// Sql message
 #[derive(Debug, Default)]
 pub struct Statement {
+    cache: Arc<RwLock<Cache>>,
+    cached_sql: Option<Arc<str>>,
+    sql: Option<String>,
     /// The SQL statement
     pub parts: Vec<SqlPart>,
     /// The expected response type
@@ -151,11 +158,35 @@ pub struct Statement {
 
 impl Statement {
     /// Creates a new statement
-    pub fn new(sql: &str) -> Result<Self, SqlParseError> {
-        Ok(Self {
-            parts: split_sql_parts(sql)?,
-            ..Default::default()
-        })
+    fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
+        let parsed = cache
+            .read()
+            .map(|cache| cache.get(sql).cloned())
+            .ok()
+            .flatten();
+
+        if let Some((parts, cached_sql)) = parsed {
+            Ok(Self {
+                parts,
+                cached_sql,
+                sql: None,
+                cache,
+                ..Default::default()
+            })
+        } else {
+            let parts = split_sql_parts(sql)?;
+
+            let _ = cache.write().map(|mut cache| {
+                cache.insert(sql.to_owned(), (parts.clone(), None));
+            });
+
+            Ok(Self {
+                parts,
+                sql: Some(sql.to_owned()),
+                cache,
+                ..Default::default()
+            })
+        }
     }
 
     /// Convert Statement into a SQL statement and the list of placeholders
@@ -164,7 +195,29 @@ impl Statement {
     /// to be more widely supported, although it can be reimplemented with other formats since part
     /// is public
     pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
+        if let Some(cached_sql) = self.cached_sql {
+            let sql = cached_sql.to_string();
+            let values = self
+                .parts
+                .into_iter()
+                .map(|x| match x {
+                    SqlPart::Placeholder(name, value) => {
+                        match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
+                            PlaceholderValue::Value(value) => Ok(vec![value]),
+                            PlaceholderValue::Set(values) => Ok(values),
+                        }
+                    }
+                    SqlPart::Raw(_) => Ok(vec![]),
+                })
+                .collect::<Result<Vec<_>, Error>>()?
+                .into_iter()
+                .flatten()
+                .collect::<Vec<_>>();
+            return Ok((sql, values));
+        }
+
         let mut placeholder_values = Vec::new();
+        let mut can_be_cached = true;
         let sql = self
             .parts
             .into_iter()
@@ -176,6 +229,7 @@ impl Statement {
                             Ok::<_, Error>(format!("${}", placeholder_values.len()))
                         }
                         PlaceholderValue::Set(mut values) => {
+                            can_be_cached = false;
                             let start_size = placeholder_values.len();
                             placeholder_values.append(&mut values);
                             let placeholders = (start_size + 1..=placeholder_values.len())
@@ -191,6 +245,16 @@ impl Statement {
             .collect::<Result<Vec<String>, _>>()?
             .join(" ");
 
+        if can_be_cached {
+            if let Some(original_sql) = self.sql {
+                let _ = self.cache.write().map(|mut cache| {
+                    if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
+                        *cached_sql = Some(sql.clone().into());
+                    }
+                });
+            }
+        }
+
         Ok((sql, placeholder_values))
     }
 
@@ -288,5 +352,6 @@ impl Statement {
 /// Creates a new query statement
 #[inline(always)]
 pub fn query(sql: &str) -> Result<Statement, Error> {
-    Statement::new(sql).map_err(|e| Error::Database(Box::new(e)))
+    static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
+    Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
 }