//! Stataments mod use std::sync::Arc; use cdk_common::database::Error; use crate::database::DatabaseExecutor; use crate::value::Value; /// The Column type pub type Column = Value; /// Expected response type for a given SQL statement #[derive(Debug, Clone, Copy, Default)] pub enum ExpectedSqlResponse { /// A single row SingleRow, /// All the rows that matches a query #[default] ManyRows, /// How many rows were affected by the query AffectedRows, /// Return the first column of the first row Pluck, /// Batch Batch, } /// Part value #[derive(Debug, Clone)] pub enum PlaceholderValue { /// Value Value(Value), /// Set Set(Vec), } impl From for PlaceholderValue { fn from(value: Value) -> Self { PlaceholderValue::Value(value) } } impl From> for PlaceholderValue { fn from(value: Vec) -> Self { PlaceholderValue::Set(value) } } /// SQL Part #[derive(Debug, Clone)] pub enum SqlPart { /// Raw SQL statement Raw(Arc), /// Placeholder Placeholder(Arc, Option), } /// SQL parser error #[derive(Debug, PartialEq, thiserror::Error)] pub enum SqlParseError { /// Invalid SQL #[error("Unterminated String literal")] UnterminatedStringLiteral, /// Invalid placeholder name #[error("Invalid placeholder name")] InvalidPlaceholder, } /// Rudimentary SQL parser. /// /// This function does not validate the SQL statement, it only extracts the placeholder to be /// database agnostic. pub fn split_sql_parts(input: &str) -> Result, SqlParseError> { let mut parts = Vec::new(); let mut current = String::new(); let mut chars = input.chars().peekable(); while let Some(&c) = chars.peek() { match c { '\'' | '"' => { // Start of string literal let quote = c; current.push(chars.next().unwrap()); let mut closed = false; while let Some(&next) = chars.peek() { current.push(chars.next().unwrap()); if next == quote { if chars.peek() == Some("e) { // Escaped quote (e.g. '' inside strings) current.push(chars.next().unwrap()); } else { closed = true; break; } } } if !closed { return Err(SqlParseError::UnterminatedStringLiteral); } } ':' => { // Flush current raw SQL if !current.is_empty() { parts.push(SqlPart::Raw(current.clone().into())); current.clear(); } chars.next(); // consume ':' let mut name = String::new(); while let Some(&next) = chars.peek() { if next.is_alphanumeric() || next == '_' { name.push(chars.next().unwrap()); } else { break; } } if name.is_empty() { return Err(SqlParseError::InvalidPlaceholder); } parts.push(SqlPart::Placeholder(name.into(), None)); } _ => { current.push(chars.next().unwrap()); } } } if !current.is_empty() { parts.push(SqlPart::Raw(current.into())); } Ok(parts) } /// Sql message #[derive(Debug, Default)] pub struct Statement { /// The SQL statement pub parts: Vec, /// The expected response type pub expected_response: ExpectedSqlResponse, } impl Statement { /// Creates a new statement pub fn new(sql: &str) -> Result { Ok(Self { parts: split_sql_parts(sql)?, ..Default::default() }) } /// Convert Statement into a SQL statement and the list of placeholders /// /// By default it converts the statement into placeholder using $1..$n placeholders which seems /// to be more widely supported, although it can be reimplemented with other formats since part /// is public pub fn to_sql(self) -> Result<(String, Vec), Error> { let mut placeholder_values = Vec::new(); let sql = self .parts .into_iter() .map(|x| match x { SqlPart::Placeholder(name, value) => { match value.ok_or(Error::MissingPlaceholder(name.to_string()))? { PlaceholderValue::Value(value) => { placeholder_values.push(value); Ok::<_, Error>(format!("${}", placeholder_values.len())) } PlaceholderValue::Set(mut values) => { let start_size = placeholder_values.len(); placeholder_values.append(&mut values); let placeholders = (start_size + 1..=placeholder_values.len()) .map(|i| format!("${i}")) .collect::>() .join(", "); Ok(placeholders) } } } SqlPart::Raw(raw) => Ok(raw.trim().to_string()), }) .collect::, _>>()? .join(" "); Ok((sql, placeholder_values)) } /// Binds a given placeholder to a value. #[inline] pub fn bind(mut self, name: C, value: V) -> Self where C: ToString, V: Into, { let name = name.to_string(); let value = value.into(); let value: PlaceholderValue = value.into(); for part in self.parts.iter_mut() { if let SqlPart::Placeholder(part_name, part_value) = part { if **part_name == *name.as_str() { *part_value = Some(value.clone()); } } } self } /// Binds a single variable with a vector. /// /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1, /// :foo2` and binds each value from the value vector accordingly. #[inline] pub fn bind_vec(mut self, name: C, value: Vec) -> Self where C: ToString, V: Into, { let name = name.to_string(); let value: PlaceholderValue = value .into_iter() .map(|x| x.into()) .collect::>() .into(); for part in self.parts.iter_mut() { if let SqlPart::Placeholder(part_name, part_value) = part { if **part_name == *name.as_str() { *part_value = Some(value.clone()); } } } self } /// Executes a query and returns the affected rows pub async fn pluck(self, conn: &C) -> Result, Error> where C: DatabaseExecutor, { conn.pluck(self).await } /// Executes a query and returns the affected rows pub async fn batch(self, conn: &C) -> Result<(), Error> where C: DatabaseExecutor, { conn.batch(self).await } /// Executes a query and returns the affected rows pub async fn execute(self, conn: &C) -> Result where C: DatabaseExecutor, { conn.execute(self).await } /// Runs the query and returns the first row or None pub async fn fetch_one(self, conn: &C) -> Result>, Error> where C: DatabaseExecutor, { conn.fetch_one(self).await } /// Runs the query and returns the first row or None pub async fn fetch_all(self, conn: &C) -> Result>, Error> where C: DatabaseExecutor, { conn.fetch_all(self).await } } /// Creates a new query statement #[inline(always)] pub fn query(sql: &str) -> Result { Statement::new(sql).map_err(|e| Error::Database(Box::new(e))) }