stmt.rs 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
//! SQL statements: parsing of named placeholders, value binding and execution helpers.
  2. use std::sync::Arc;
  3. use cdk_common::database::Error;
  4. use crate::database::DatabaseExecutor;
  5. use crate::value::Value;
  6. /// The Column type
  7. pub type Column = Value;
  8. /// Expected response type for a given SQL statement
  9. #[derive(Debug, Clone, Copy, Default)]
  10. pub enum ExpectedSqlResponse {
  11. /// A single row
  12. SingleRow,
  13. /// All the rows that matches a query
  14. #[default]
  15. ManyRows,
  16. /// How many rows were affected by the query
  17. AffectedRows,
  18. /// Return the first column of the first row
  19. Pluck,
  20. /// Batch
  21. Batch,
  22. }
  23. /// Part value
  24. #[derive(Debug, Clone)]
  25. pub enum PlaceholderValue {
  26. /// Value
  27. Value(Value),
  28. /// Set
  29. Set(Vec<Value>),
  30. }
  31. impl From<Value> for PlaceholderValue {
  32. fn from(value: Value) -> Self {
  33. PlaceholderValue::Value(value)
  34. }
  35. }
  36. impl From<Vec<Value>> for PlaceholderValue {
  37. fn from(value: Vec<Value>) -> Self {
  38. PlaceholderValue::Set(value)
  39. }
  40. }
  41. /// SQL Part
  42. #[derive(Debug, Clone)]
  43. pub enum SqlPart {
  44. /// Raw SQL statement
  45. Raw(Arc<str>),
  46. /// Placeholder
  47. Placeholder(Arc<str>, Option<PlaceholderValue>),
  48. }
  49. /// SQL parser error
  50. #[derive(Debug, PartialEq, thiserror::Error)]
  51. pub enum SqlParseError {
  52. /// Invalid SQL
  53. #[error("Unterminated String literal")]
  54. UnterminatedStringLiteral,
  55. /// Invalid placeholder name
  56. #[error("Invalid placeholder name")]
  57. InvalidPlaceholder,
  58. }
  59. /// Rudimentary SQL parser.
  60. ///
  61. /// This function does not validate the SQL statement, it only extracts the placeholder to be
  62. /// database agnostic.
  63. pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
  64. let mut parts = Vec::new();
  65. let mut current = String::new();
  66. let mut chars = input.chars().peekable();
  67. while let Some(&c) = chars.peek() {
  68. match c {
  69. '\'' | '"' => {
  70. // Start of string literal
  71. let quote = c;
  72. current.push(chars.next().unwrap());
  73. let mut closed = false;
  74. while let Some(&next) = chars.peek() {
  75. current.push(chars.next().unwrap());
  76. if next == quote {
  77. if chars.peek() == Some(&quote) {
  78. // Escaped quote (e.g. '' inside strings)
  79. current.push(chars.next().unwrap());
  80. } else {
  81. closed = true;
  82. break;
  83. }
  84. }
  85. }
  86. if !closed {
  87. return Err(SqlParseError::UnterminatedStringLiteral);
  88. }
  89. }
  90. ':' => {
  91. // Flush current raw SQL
  92. if !current.is_empty() {
  93. parts.push(SqlPart::Raw(current.clone().into()));
  94. current.clear();
  95. }
  96. chars.next(); // consume ':'
  97. let mut name = String::new();
  98. while let Some(&next) = chars.peek() {
  99. if next.is_alphanumeric() || next == '_' {
  100. name.push(chars.next().unwrap());
  101. } else {
  102. break;
  103. }
  104. }
  105. if name.is_empty() {
  106. return Err(SqlParseError::InvalidPlaceholder);
  107. }
  108. parts.push(SqlPart::Placeholder(name.into(), None));
  109. }
  110. _ => {
  111. current.push(chars.next().unwrap());
  112. }
  113. }
  114. }
  115. if !current.is_empty() {
  116. parts.push(SqlPart::Raw(current.into()));
  117. }
  118. Ok(parts)
  119. }
  120. /// Sql message
  121. #[derive(Debug, Default)]
  122. pub struct Statement {
  123. /// The SQL statement
  124. pub parts: Vec<SqlPart>,
  125. /// The expected response type
  126. pub expected_response: ExpectedSqlResponse,
  127. }
  128. impl Statement {
  129. /// Creates a new statement
  130. pub fn new(sql: &str) -> Result<Self, SqlParseError> {
  131. Ok(Self {
  132. parts: split_sql_parts(sql)?,
  133. ..Default::default()
  134. })
  135. }
  136. /// Convert Statement into a SQL statement and the list of placeholders
  137. ///
  138. /// By default it converts the statement into placeholder using $1..$n placeholders which seems
  139. /// to be more widely supported, although it can be reimplemented with other formats since part
  140. /// is public
  141. pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
  142. let mut placeholder_values = Vec::new();
  143. let sql = self
  144. .parts
  145. .into_iter()
  146. .map(|x| match x {
  147. SqlPart::Placeholder(name, value) => {
  148. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  149. PlaceholderValue::Value(value) => {
  150. placeholder_values.push(value);
  151. Ok::<_, Error>(format!("${}", placeholder_values.len()))
  152. }
  153. PlaceholderValue::Set(mut values) => {
  154. let start_size = placeholder_values.len();
  155. placeholder_values.append(&mut values);
  156. let placeholders = (start_size + 1..=placeholder_values.len())
  157. .map(|i| format!("${i}"))
  158. .collect::<Vec<_>>()
  159. .join(", ");
  160. Ok(placeholders)
  161. }
  162. }
  163. }
  164. SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
  165. })
  166. .collect::<Result<Vec<String>, _>>()?
  167. .join(" ");
  168. Ok((sql, placeholder_values))
  169. }
  170. /// Binds a given placeholder to a value.
  171. #[inline]
  172. pub fn bind<C, V>(mut self, name: C, value: V) -> Self
  173. where
  174. C: ToString,
  175. V: Into<Value>,
  176. {
  177. let name = name.to_string();
  178. let value = value.into();
  179. let value: PlaceholderValue = value.into();
  180. for part in self.parts.iter_mut() {
  181. if let SqlPart::Placeholder(part_name, part_value) = part {
  182. if **part_name == *name.as_str() {
  183. *part_value = Some(value.clone());
  184. }
  185. }
  186. }
  187. self
  188. }
  189. /// Binds a single variable with a vector.
  190. ///
  191. /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
  192. /// :foo2` and binds each value from the value vector accordingly.
  193. #[inline]
  194. pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
  195. where
  196. C: ToString,
  197. V: Into<Value>,
  198. {
  199. let name = name.to_string();
  200. let value: PlaceholderValue = value
  201. .into_iter()
  202. .map(|x| x.into())
  203. .collect::<Vec<Value>>()
  204. .into();
  205. for part in self.parts.iter_mut() {
  206. if let SqlPart::Placeholder(part_name, part_value) = part {
  207. if **part_name == *name.as_str() {
  208. *part_value = Some(value.clone());
  209. }
  210. }
  211. }
  212. self
  213. }
  214. /// Executes a query and returns the affected rows
  215. pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
  216. where
  217. C: DatabaseExecutor,
  218. {
  219. conn.pluck(self).await
  220. }
  221. /// Executes a query and returns the affected rows
  222. pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
  223. where
  224. C: DatabaseExecutor,
  225. {
  226. conn.batch(self).await
  227. }
  228. /// Executes a query and returns the affected rows
  229. pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
  230. where
  231. C: DatabaseExecutor,
  232. {
  233. conn.execute(self).await
  234. }
  235. /// Runs the query and returns the first row or None
  236. pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
  237. where
  238. C: DatabaseExecutor,
  239. {
  240. conn.fetch_one(self).await
  241. }
  242. /// Runs the query and returns the first row or None
  243. pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
  244. where
  245. C: DatabaseExecutor,
  246. {
  247. conn.fetch_all(self).await
  248. }
  249. }
  250. /// Creates a new query statement
  251. #[inline(always)]
  252. pub fn query(sql: &str) -> Result<Statement, Error> {
  253. Statement::new(sql).map_err(|e| Error::Database(Box::new(e)))
  254. }