stmt.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  //! SQL statements: placeholder parsing, value binding and rendering.
  2. use std::collections::HashMap;
  3. use std::sync::{Arc, RwLock};
  4. use cdk_common::database::Error;
  5. use once_cell::sync::Lazy;
  6. use crate::database::DatabaseExecutor;
  7. use crate::value::Value;
  8. /// The Column type
  9. pub type Column = Value;
  /// The shape of result a caller expects back after running a SQL statement.
  #[derive(Clone, Copy, Debug, Default)]
  pub enum ExpectedSqlResponse {
      /// A single row.
      SingleRow,
      /// Every row that matches the query. This is the default.
      #[default]
      ManyRows,
      /// The number of rows the statement affected.
      AffectedRows,
      /// The first column of the first row.
      Pluck,
      /// Batch execution; no rows or counts are handed back to the caller.
      Batch,
  }
  25. /// Part value
  26. #[derive(Debug, Clone)]
  27. pub enum PlaceholderValue {
  28. /// Value
  29. Value(Value),
  30. /// Set
  31. Set(Vec<Value>),
  32. }
  33. impl From<Value> for PlaceholderValue {
  34. fn from(value: Value) -> Self {
  35. PlaceholderValue::Value(value)
  36. }
  37. }
  38. impl From<Vec<Value>> for PlaceholderValue {
  39. fn from(value: Vec<Value>) -> Self {
  40. PlaceholderValue::Set(value)
  41. }
  42. }
  43. /// SQL Part
  44. #[derive(Debug, Clone)]
  45. pub enum SqlPart {
  46. /// Raw SQL statement
  47. Raw(Arc<str>),
  48. /// Placeholder
  49. Placeholder(Arc<str>, Option<PlaceholderValue>),
  50. }
  51. /// SQL parser error
  52. #[derive(Debug, PartialEq, thiserror::Error)]
  53. pub enum SqlParseError {
  54. /// Invalid SQL
  55. #[error("Unterminated String literal")]
  56. UnterminatedStringLiteral,
  57. /// Invalid placeholder name
  58. #[error("Invalid placeholder name")]
  59. InvalidPlaceholder,
  60. }
  61. /// Rudimentary SQL parser.
  62. ///
  63. /// This function does not validate the SQL statement, it only extracts the placeholder to be
  64. /// database agnostic.
  65. pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
  66. let mut parts = Vec::new();
  67. let mut current = String::new();
  68. let mut chars = input.chars().peekable();
  69. while let Some(&c) = chars.peek() {
  70. match c {
  71. '\'' | '"' => {
  72. // Start of string literal
  73. let quote = c;
  74. current.push(chars.next().unwrap());
  75. let mut closed = false;
  76. while let Some(&next) = chars.peek() {
  77. current.push(chars.next().unwrap());
  78. if next == quote {
  79. if chars.peek() == Some(&quote) {
  80. // Escaped quote (e.g. '' inside strings)
  81. current.push(chars.next().unwrap());
  82. } else {
  83. closed = true;
  84. break;
  85. }
  86. }
  87. }
  88. if !closed {
  89. return Err(SqlParseError::UnterminatedStringLiteral);
  90. }
  91. }
  92. ':' => {
  93. // Flush current raw SQL
  94. if !current.is_empty() {
  95. parts.push(SqlPart::Raw(current.clone().into()));
  96. current.clear();
  97. }
  98. chars.next(); // consume ':'
  99. let mut name = String::new();
  100. while let Some(&next) = chars.peek() {
  101. if next.is_alphanumeric() || next == '_' {
  102. name.push(chars.next().unwrap());
  103. } else {
  104. break;
  105. }
  106. }
  107. if name.is_empty() {
  108. return Err(SqlParseError::InvalidPlaceholder);
  109. }
  110. parts.push(SqlPart::Placeholder(name.into(), None));
  111. }
  112. _ => {
  113. current.push(chars.next().unwrap());
  114. }
  115. }
  116. }
  117. if !current.is_empty() {
  118. parts.push(SqlPart::Raw(current.into()));
  119. }
  120. Ok(parts)
  121. }
  122. type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
  123. /// Sql message
  124. #[derive(Debug, Default)]
  125. pub struct Statement {
  126. cache: Arc<RwLock<Cache>>,
  127. cached_sql: Option<Arc<str>>,
  128. sql: Option<String>,
  129. /// The SQL statement
  130. pub parts: Vec<SqlPart>,
  131. /// The expected response type
  132. pub expected_response: ExpectedSqlResponse,
  133. }
  134. impl Statement {
  135. /// Creates a new statement
  136. fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
  137. let parsed = cache
  138. .read()
  139. .map(|cache| cache.get(sql).cloned())
  140. .ok()
  141. .flatten();
  142. if let Some((parts, cached_sql)) = parsed {
  143. Ok(Self {
  144. parts,
  145. cached_sql,
  146. sql: None,
  147. cache,
  148. ..Default::default()
  149. })
  150. } else {
  151. let parts = split_sql_parts(sql)?;
  152. if let Ok(mut cache) = cache.write() {
  153. cache.insert(sql.to_owned(), (parts.clone(), None));
  154. } else {
  155. tracing::warn!("Failed to acquire write lock for SQL statement cache");
  156. }
  157. Ok(Self {
  158. parts,
  159. sql: Some(sql.to_owned()),
  160. cache,
  161. ..Default::default()
  162. })
  163. }
  164. }
  165. /// Convert Statement into a SQL statement and the list of placeholders
  166. ///
  167. /// By default it converts the statement into placeholder using $1..$n placeholders which seems
  168. /// to be more widely supported, although it can be reimplemented with other formats since part
  169. /// is public
  170. pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
  171. if let Some(cached_sql) = self.cached_sql {
  172. let sql = cached_sql.to_string();
  173. let values = self
  174. .parts
  175. .into_iter()
  176. .map(|x| match x {
  177. SqlPart::Placeholder(name, value) => {
  178. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  179. PlaceholderValue::Value(value) => Ok(vec![value]),
  180. PlaceholderValue::Set(values) => Ok(values),
  181. }
  182. }
  183. SqlPart::Raw(_) => Ok(vec![]),
  184. })
  185. .collect::<Result<Vec<_>, Error>>()?
  186. .into_iter()
  187. .flatten()
  188. .collect::<Vec<_>>();
  189. return Ok((sql, values));
  190. }
  191. let mut placeholder_values = Vec::new();
  192. let mut can_be_cached = true;
  193. let sql = self
  194. .parts
  195. .into_iter()
  196. .map(|x| match x {
  197. SqlPart::Placeholder(name, value) => {
  198. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  199. PlaceholderValue::Value(value) => {
  200. placeholder_values.push(value);
  201. Ok::<_, Error>(format!("${}", placeholder_values.len()))
  202. }
  203. PlaceholderValue::Set(mut values) => {
  204. can_be_cached = false;
  205. let start_size = placeholder_values.len();
  206. placeholder_values.append(&mut values);
  207. let placeholders = (start_size + 1..=placeholder_values.len())
  208. .map(|i| format!("${i}"))
  209. .collect::<Vec<_>>()
  210. .join(", ");
  211. Ok(placeholders)
  212. }
  213. }
  214. }
  215. SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
  216. })
  217. .collect::<Result<Vec<String>, _>>()?
  218. .join(" ");
  219. if can_be_cached {
  220. if let Some(original_sql) = self.sql {
  221. let _ = self.cache.write().map(|mut cache| {
  222. if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
  223. *cached_sql = Some(sql.clone().into());
  224. }
  225. });
  226. }
  227. }
  228. Ok((sql, placeholder_values))
  229. }
  230. /// Binds a given placeholder to a value.
  231. #[inline]
  232. pub fn bind<C, V>(mut self, name: C, value: V) -> Self
  233. where
  234. C: ToString,
  235. V: Into<Value>,
  236. {
  237. let name = name.to_string();
  238. let value = value.into();
  239. let value: PlaceholderValue = value.into();
  240. for part in self.parts.iter_mut() {
  241. if let SqlPart::Placeholder(part_name, part_value) = part {
  242. if **part_name == *name.as_str() {
  243. *part_value = Some(value.clone());
  244. }
  245. }
  246. }
  247. self
  248. }
  249. /// Binds a single variable with a vector.
  250. ///
  251. /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
  252. /// :foo2` and binds each value from the value vector accordingly.
  253. #[inline]
  254. pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
  255. where
  256. C: ToString,
  257. V: Into<Value>,
  258. {
  259. let name = name.to_string();
  260. let value: PlaceholderValue = value
  261. .into_iter()
  262. .map(|x| x.into())
  263. .collect::<Vec<Value>>()
  264. .into();
  265. for part in self.parts.iter_mut() {
  266. if let SqlPart::Placeholder(part_name, part_value) = part {
  267. if **part_name == *name.as_str() {
  268. *part_value = Some(value.clone());
  269. }
  270. }
  271. }
  272. self
  273. }
  274. /// Executes a query and returns the affected rows
  275. pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
  276. where
  277. C: DatabaseExecutor,
  278. {
  279. conn.pluck(self).await
  280. }
  281. /// Executes a query and returns the affected rows
  282. pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
  283. where
  284. C: DatabaseExecutor,
  285. {
  286. conn.batch(self).await
  287. }
  288. /// Executes a query and returns the affected rows
  289. pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
  290. where
  291. C: DatabaseExecutor,
  292. {
  293. conn.execute(self).await
  294. }
  295. /// Runs the query and returns the first row or None
  296. pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
  297. where
  298. C: DatabaseExecutor,
  299. {
  300. conn.fetch_one(self).await
  301. }
  302. /// Runs the query and returns the first row or None
  303. pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
  304. where
  305. C: DatabaseExecutor,
  306. {
  307. conn.fetch_all(self).await
  308. }
  309. }
  310. /// Creates a new query statement
  311. #[inline(always)]
  312. pub fn query(sql: &str) -> Result<Statement, Error> {
  313. static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
  314. Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
  315. }