stmt.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
//! SQL statements: parsing, placeholder binding and execution helpers.
  2. use std::collections::HashMap;
  3. use std::sync::{Arc, RwLock};
  4. use cdk_common::database::Error;
  5. use once_cell::sync::Lazy;
  6. use crate::database::DatabaseExecutor;
  7. use crate::value::Value;
/// A single column value of a fetched row; an alias for [`Value`].
pub type Column = Value;
  10. /// Expected response type for a given SQL statement
  11. #[derive(Debug, Clone, Copy, Default)]
  12. pub enum ExpectedSqlResponse {
  13. /// A single row
  14. SingleRow,
  15. /// All the rows that matches a query
  16. #[default]
  17. ManyRows,
  18. /// How many rows were affected by the query
  19. AffectedRows,
  20. /// Return the first column of the first row
  21. Pluck,
  22. /// Batch
  23. Batch,
  24. }
  25. /// Part value
  26. #[derive(Debug, Clone)]
  27. pub enum PlaceholderValue {
  28. /// Value
  29. Value(Value),
  30. /// Set
  31. Set(Vec<Value>),
  32. }
  33. impl From<Value> for PlaceholderValue {
  34. fn from(value: Value) -> Self {
  35. PlaceholderValue::Value(value)
  36. }
  37. }
  38. impl From<Vec<Value>> for PlaceholderValue {
  39. fn from(value: Vec<Value>) -> Self {
  40. PlaceholderValue::Set(value)
  41. }
  42. }
/// A fragment of a parsed SQL statement.
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// Raw SQL text, copied verbatim from the input (string literals and comments included)
    Raw(Arc<str>),
    /// Named placeholder (the name without its leading `:`) and the value bound to it,
    /// or `None` while it is still unbound
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
/// Error returned by [`split_sql_parts`] when the input cannot be split into parts.
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A `'` or `"` quoted literal was opened but never closed
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A placeholder `:` was followed by an empty name (no alphanumeric or `_` characters)
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
  61. /// Rudimentary SQL parser.
  62. ///
  63. /// This function does not validate the SQL statement, it only extracts the placeholder to be
  64. /// database agnostic.
  65. pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
  66. let mut parts = Vec::new();
  67. let mut current = String::new();
  68. let mut chars = input.chars().peekable();
  69. while let Some(&c) = chars.peek() {
  70. match c {
  71. '\'' | '"' => {
  72. // Start of string literal
  73. let quote = c;
  74. current.push(chars.next().unwrap());
  75. let mut closed = false;
  76. while let Some(&next) = chars.peek() {
  77. current.push(chars.next().unwrap());
  78. if next == quote {
  79. if chars.peek() == Some(&quote) {
  80. // Escaped quote (e.g. '' inside strings)
  81. current.push(chars.next().unwrap());
  82. } else {
  83. closed = true;
  84. break;
  85. }
  86. }
  87. }
  88. if !closed {
  89. return Err(SqlParseError::UnterminatedStringLiteral);
  90. }
  91. }
  92. '-' => {
  93. current.push(chars.next().unwrap());
  94. if chars.peek() == Some(&'-') {
  95. while let Some(&next) = chars.peek() {
  96. current.push(chars.next().unwrap());
  97. if next == '\n' {
  98. break;
  99. }
  100. }
  101. }
  102. }
  103. ':' => {
  104. // Flush current raw SQL
  105. if !current.is_empty() {
  106. parts.push(SqlPart::Raw(current.clone().into()));
  107. current.clear();
  108. }
  109. chars.next(); // consume ':'
  110. let mut name = String::new();
  111. while let Some(&next) = chars.peek() {
  112. if next.is_alphanumeric() || next == '_' {
  113. name.push(chars.next().unwrap());
  114. } else {
  115. break;
  116. }
  117. }
  118. if name.is_empty() {
  119. return Err(SqlParseError::InvalidPlaceholder);
  120. }
  121. parts.push(SqlPart::Placeholder(name.into(), None));
  122. }
  123. _ => {
  124. current.push(chars.next().unwrap());
  125. }
  126. }
  127. }
  128. if !current.is_empty() {
  129. parts.push(SqlPart::Raw(current.into()));
  130. }
  131. Ok(parts)
  132. }
/// Shared statement cache, keyed by the original SQL text.
///
/// Each entry holds the parsed parts and, optionally, the SQL string previously rendered by
/// [`Statement::to_sql`] for that text.
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
/// A parsed SQL statement with named placeholders, to be bound and then executed.
#[derive(Debug, Default)]
pub struct Statement {
    /// Shared cache this statement was created from; `to_sql` may store its rendered SQL here
    cache: Arc<RwLock<Cache>>,
    /// Previously rendered SQL for this statement's text, which `to_sql` may reuse instead
    /// of rendering again
    cached_sql: Option<Arc<str>>,
    /// Original SQL text; when present, `to_sql` uses it as the cache key for storing the
    /// rendered SQL
    sql: Option<String>,
    /// The SQL statement, split into raw fragments and placeholders
    pub parts: Vec<SqlPart>,
    /// The expected response type
    pub expected_response: ExpectedSqlResponse,
}
  145. impl Statement {
  146. /// Creates a new statement
  147. fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
  148. let parsed = cache
  149. .read()
  150. .map(|cache| cache.get(sql).cloned())
  151. .ok()
  152. .flatten();
  153. if let Some((parts, cached_sql)) = parsed {
  154. Ok(Self {
  155. parts,
  156. cached_sql,
  157. sql: None,
  158. cache,
  159. ..Default::default()
  160. })
  161. } else {
  162. let parts = split_sql_parts(sql)?;
  163. if let Ok(mut cache) = cache.write() {
  164. cache.insert(sql.to_owned(), (parts.clone(), None));
  165. } else {
  166. tracing::warn!("Failed to acquire write lock for SQL statement cache");
  167. }
  168. Ok(Self {
  169. parts,
  170. sql: Some(sql.to_owned()),
  171. cache,
  172. ..Default::default()
  173. })
  174. }
  175. }
  176. /// Convert Statement into a SQL statement and the list of placeholders
  177. ///
  178. /// By default it converts the statement into placeholder using $1..$n placeholders which seems
  179. /// to be more widely supported, although it can be reimplemented with other formats since part
  180. /// is public
  181. pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
  182. if let Some(cached_sql) = self.cached_sql {
  183. let sql = cached_sql.to_string();
  184. let values = self
  185. .parts
  186. .into_iter()
  187. .map(|x| match x {
  188. SqlPart::Placeholder(name, value) => {
  189. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  190. PlaceholderValue::Value(value) => Ok(vec![value]),
  191. PlaceholderValue::Set(values) => Ok(values),
  192. }
  193. }
  194. SqlPart::Raw(_) => Ok(vec![]),
  195. })
  196. .collect::<Result<Vec<_>, Error>>()?
  197. .into_iter()
  198. .flatten()
  199. .collect::<Vec<_>>();
  200. return Ok((sql, values));
  201. }
  202. let mut placeholder_values = Vec::new();
  203. let mut can_be_cached = true;
  204. let sql = self
  205. .parts
  206. .into_iter()
  207. .map(|x| match x {
  208. SqlPart::Placeholder(name, value) => {
  209. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  210. PlaceholderValue::Value(value) => {
  211. placeholder_values.push(value);
  212. Ok::<_, Error>(format!("${}", placeholder_values.len()))
  213. }
  214. PlaceholderValue::Set(mut values) => {
  215. can_be_cached = false;
  216. let start_size = placeholder_values.len();
  217. placeholder_values.append(&mut values);
  218. let placeholders = (start_size + 1..=placeholder_values.len())
  219. .map(|i| format!("${i}"))
  220. .collect::<Vec<_>>()
  221. .join(", ");
  222. Ok(placeholders)
  223. }
  224. }
  225. }
  226. SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
  227. })
  228. .collect::<Result<Vec<String>, _>>()?
  229. .join(" ");
  230. if can_be_cached {
  231. if let Some(original_sql) = self.sql {
  232. let _ = self.cache.write().map(|mut cache| {
  233. if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
  234. *cached_sql = Some(sql.clone().into());
  235. }
  236. });
  237. }
  238. }
  239. Ok((sql, placeholder_values))
  240. }
  241. /// Binds a given placeholder to a value.
  242. #[inline]
  243. pub fn bind<C, V>(mut self, name: C, value: V) -> Self
  244. where
  245. C: ToString,
  246. V: Into<Value>,
  247. {
  248. let name = name.to_string();
  249. let value = value.into();
  250. let value: PlaceholderValue = value.into();
  251. for part in self.parts.iter_mut() {
  252. if let SqlPart::Placeholder(part_name, part_value) = part {
  253. if **part_name == *name.as_str() {
  254. *part_value = Some(value.clone());
  255. }
  256. }
  257. }
  258. self
  259. }
  260. /// Binds a single variable with a vector.
  261. ///
  262. /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
  263. /// :foo2` and binds each value from the value vector accordingly.
  264. #[inline]
  265. pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
  266. where
  267. C: ToString,
  268. V: Into<Value>,
  269. {
  270. let name = name.to_string();
  271. let value: PlaceholderValue = value
  272. .into_iter()
  273. .map(|x| x.into())
  274. .collect::<Vec<Value>>()
  275. .into();
  276. for part in self.parts.iter_mut() {
  277. if let SqlPart::Placeholder(part_name, part_value) = part {
  278. if **part_name == *name.as_str() {
  279. *part_value = Some(value.clone());
  280. }
  281. }
  282. }
  283. self
  284. }
  285. /// Executes a query and returns the affected rows
  286. pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
  287. where
  288. C: DatabaseExecutor,
  289. {
  290. conn.pluck(self).await
  291. }
  292. /// Executes a query and returns the affected rows
  293. pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
  294. where
  295. C: DatabaseExecutor,
  296. {
  297. conn.batch(self).await
  298. }
  299. /// Executes a query and returns the affected rows
  300. pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
  301. where
  302. C: DatabaseExecutor,
  303. {
  304. conn.execute(self).await
  305. }
  306. /// Runs the query and returns the first row or None
  307. pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
  308. where
  309. C: DatabaseExecutor,
  310. {
  311. conn.fetch_one(self).await
  312. }
  313. /// Runs the query and returns the first row or None
  314. pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
  315. where
  316. C: DatabaseExecutor,
  317. {
  318. conn.fetch_all(self).await
  319. }
  320. }
  321. /// Creates a new query statement
  322. #[inline(always)]
  323. pub fn query(sql: &str) -> Result<Statement, Error> {
  324. static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
  325. Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
  326. }