123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292 |
//! Statements module
- use std::sync::Arc;
- use cdk_common::database::Error;
- use crate::database::DatabaseExecutor;
- use crate::value::Value;
/// The Column type
///
/// A plain alias for [`Value`]; rows returned by the fetch methods are vectors of `Column`.
pub type Column = Value;
/// Expected response type for a given SQL statement
///
/// The default variant is [`ExpectedSqlResponse::ManyRows`].
#[derive(Debug, Clone, Copy, Default)]
pub enum ExpectedSqlResponse {
    /// A single row
    SingleRow,
    /// All the rows that match a query (this is the default)
    #[default]
    ManyRows,
    /// How many rows were affected by the query
    AffectedRows,
    /// Return the first column of the first row
    Pluck,
    /// Batch
    ///
    /// NOTE(review): presumably a batch of statements executed without returning rows —
    /// confirm against the executor implementation.
    Batch,
}
/// The value bound to a placeholder of a statement
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value
    Value(Value),
    /// A list of values bound to one placeholder; `Statement::to_sql` expands it into one
    /// positional placeholder per element
    Set(Vec<Value>),
}
- impl From<Value> for PlaceholderValue {
- fn from(value: Value) -> Self {
- PlaceholderValue::Value(value)
- }
- }
- impl From<Vec<Value>> for PlaceholderValue {
- fn from(value: Vec<Value>) -> Self {
- PlaceholderValue::Set(value)
- }
- }
/// SQL Part
///
/// One piece of a statement as produced by `split_sql_parts`.
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// Raw SQL text, stored exactly as it appeared in the input
    Raw(Arc<str>),
    /// Placeholder: its name (without the leading `:`) and the value bound to it, which is
    /// `None` until a value has been bound
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
/// SQL parser error
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A quoted string literal was opened but never closed
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A `:` was not followed by a valid placeholder name
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
- /// Rudimentary SQL parser.
- ///
- /// This function does not validate the SQL statement, it only extracts the placeholder to be
- /// database agnostic.
- pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
- let mut parts = Vec::new();
- let mut current = String::new();
- let mut chars = input.chars().peekable();
- while let Some(&c) = chars.peek() {
- match c {
- '\'' | '"' => {
- // Start of string literal
- let quote = c;
- current.push(chars.next().unwrap());
- let mut closed = false;
- while let Some(&next) = chars.peek() {
- current.push(chars.next().unwrap());
- if next == quote {
- if chars.peek() == Some("e) {
- // Escaped quote (e.g. '' inside strings)
- current.push(chars.next().unwrap());
- } else {
- closed = true;
- break;
- }
- }
- }
- if !closed {
- return Err(SqlParseError::UnterminatedStringLiteral);
- }
- }
- ':' => {
- // Flush current raw SQL
- if !current.is_empty() {
- parts.push(SqlPart::Raw(current.clone().into()));
- current.clear();
- }
- chars.next(); // consume ':'
- let mut name = String::new();
- while let Some(&next) = chars.peek() {
- if next.is_alphanumeric() || next == '_' {
- name.push(chars.next().unwrap());
- } else {
- break;
- }
- }
- if name.is_empty() {
- return Err(SqlParseError::InvalidPlaceholder);
- }
- parts.push(SqlPart::Placeholder(name.into(), None));
- }
- _ => {
- current.push(chars.next().unwrap());
- }
- }
- }
- if !current.is_empty() {
- parts.push(SqlPart::Raw(current.into()));
- }
- Ok(parts)
- }
/// A parsed SQL statement together with the kind of response expected from it
#[derive(Debug, Default)]
pub struct Statement {
    /// The parts the SQL statement was split into, in their original order
    pub parts: Vec<SqlPart>,
    /// The expected response type
    pub expected_response: ExpectedSqlResponse,
}
- impl Statement {
- /// Creates a new statement
- pub fn new(sql: &str) -> Result<Self, SqlParseError> {
- Ok(Self {
- parts: split_sql_parts(sql)?,
- ..Default::default()
- })
- }
- /// Convert Statement into a SQL statement and the list of placeholders
- ///
- /// By default it converts the statement into placeholder using $1..$n placeholders which seems
- /// to be more widely supported, although it can be reimplemented with other formats since part
- /// is public
- pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
- let mut placeholder_values = Vec::new();
- let sql = self
- .parts
- .into_iter()
- .map(|x| match x {
- SqlPart::Placeholder(name, value) => {
- match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
- PlaceholderValue::Value(value) => {
- placeholder_values.push(value);
- Ok::<_, Error>(format!("${}", placeholder_values.len()))
- }
- PlaceholderValue::Set(mut values) => {
- let start_size = placeholder_values.len();
- placeholder_values.append(&mut values);
- let placeholders = (start_size + 1..=placeholder_values.len())
- .map(|i| format!("${i}"))
- .collect::<Vec<_>>()
- .join(", ");
- Ok(placeholders)
- }
- }
- }
- SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
- })
- .collect::<Result<Vec<String>, _>>()?
- .join(" ");
- Ok((sql, placeholder_values))
- }
- /// Binds a given placeholder to a value.
- #[inline]
- pub fn bind<C, V>(mut self, name: C, value: V) -> Self
- where
- C: ToString,
- V: Into<Value>,
- {
- let name = name.to_string();
- let value = value.into();
- let value: PlaceholderValue = value.into();
- for part in self.parts.iter_mut() {
- if let SqlPart::Placeholder(part_name, part_value) = part {
- if **part_name == *name.as_str() {
- *part_value = Some(value.clone());
- }
- }
- }
- self
- }
- /// Binds a single variable with a vector.
- ///
- /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
- /// :foo2` and binds each value from the value vector accordingly.
- #[inline]
- pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
- where
- C: ToString,
- V: Into<Value>,
- {
- let name = name.to_string();
- let value: PlaceholderValue = value
- .into_iter()
- .map(|x| x.into())
- .collect::<Vec<Value>>()
- .into();
- for part in self.parts.iter_mut() {
- if let SqlPart::Placeholder(part_name, part_value) = part {
- if **part_name == *name.as_str() {
- *part_value = Some(value.clone());
- }
- }
- }
- self
- }
- /// Executes a query and returns the affected rows
- pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
- where
- C: DatabaseExecutor,
- {
- conn.pluck(self).await
- }
- /// Executes a query and returns the affected rows
- pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
- where
- C: DatabaseExecutor,
- {
- conn.batch(self).await
- }
- /// Executes a query and returns the affected rows
- pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
- where
- C: DatabaseExecutor,
- {
- conn.execute(self).await
- }
- /// Runs the query and returns the first row or None
- pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
- where
- C: DatabaseExecutor,
- {
- conn.fetch_one(self).await
- }
- /// Runs the query and returns the first row or None
- pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
- where
- C: DatabaseExecutor,
- {
- conn.fetch_all(self).await
- }
- }
- /// Creates a new query statement
- #[inline(always)]
- pub fn query(sql: &str) -> Result<Statement, Error> {
- Statement::new(sql).map_err(|e| Error::Database(Box::new(e)))
- }
|