stmt.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
//! Statements module
  2. use std::collections::HashMap;
  3. use std::sync::{Arc, RwLock};
  4. use cdk_common::database::Error;
  5. use once_cell::sync::Lazy;
  6. use crate::database::DatabaseExecutor;
  7. use crate::value::Value;
/// A single column of a result row.
///
/// This is an alias for [`Value`]; rows returned by [`Statement::fetch_one`] and
/// [`Statement::fetch_all`] are expressed as `Vec<Column>`.
pub type Column = Value;
  10. /// Expected response type for a given SQL statement
  11. #[derive(Debug, Clone, Copy, Default)]
  12. pub enum ExpectedSqlResponse {
  13. /// A single row
  14. SingleRow,
  15. /// All the rows that matches a query
  16. #[default]
  17. ManyRows,
  18. /// How many rows were affected by the query
  19. AffectedRows,
  20. /// Return the first column of the first row
  21. Pluck,
  22. /// Batch
  23. Batch,
  24. }
/// Value bound to a placeholder.
///
/// A placeholder holds either one value or a list of values. When the statement is rendered by
/// [`Statement::to_sql`], a [`PlaceholderValue::Value`] becomes one positional parameter and a
/// [`PlaceholderValue::Set`] becomes a comma separated list of positional parameters, one per
/// element.
#[derive(Debug, Clone)]
pub enum PlaceholderValue {
    /// A single value
    Value(Value),
    /// A list of values bound to the same placeholder
    Set(Vec<Value>),
}
  33. impl From<Value> for PlaceholderValue {
  34. fn from(value: Value) -> Self {
  35. PlaceholderValue::Value(value)
  36. }
  37. }
  38. impl From<Vec<Value>> for PlaceholderValue {
  39. fn from(value: Vec<Value>) -> Self {
  40. PlaceholderValue::Set(value)
  41. }
  42. }
/// One fragment of a parsed SQL statement, as produced by [`split_sql_parts`].
#[derive(Debug, Clone)]
pub enum SqlPart {
    /// Raw SQL text that is passed through as-is (string literals and `--` comments included)
    Raw(Arc<str>),
    /// A named placeholder: the name (without the leading `:`) and the value bound to it, which
    /// is `None` until the placeholder is bound
    Placeholder(Arc<str>, Option<PlaceholderValue>),
}
/// Errors returned by [`split_sql_parts`]
#[derive(Debug, PartialEq, thiserror::Error)]
pub enum SqlParseError {
    /// A quoted literal (`'...'` or `"..."`) was opened but the input ended before its closing
    /// quote
    #[error("Unterminated String literal")]
    UnterminatedStringLiteral,
    /// A `:` was found that is not followed by a placeholder name
    #[error("Invalid placeholder name")]
    InvalidPlaceholder,
}
  61. /// Rudimentary SQL parser.
  62. ///
  63. /// This function does not validate the SQL statement, it only extracts the placeholder to be
  64. /// database agnostic.
  65. pub fn split_sql_parts(input: &str) -> Result<Vec<SqlPart>, SqlParseError> {
  66. let mut parts = Vec::new();
  67. let mut current = String::new();
  68. let mut chars = input.chars().peekable();
  69. while let Some(&c) = chars.peek() {
  70. match c {
  71. '\'' | '"' => {
  72. // Start of string literal
  73. let quote = c;
  74. current.push(
  75. chars
  76. .next()
  77. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  78. );
  79. let mut closed = false;
  80. while let Some(&next) = chars.peek() {
  81. current.push(
  82. chars
  83. .next()
  84. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  85. );
  86. if next == quote {
  87. if chars.peek() == Some(&quote) {
  88. // Escaped quote (e.g. '' inside strings)
  89. current.push(
  90. chars
  91. .next()
  92. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  93. );
  94. } else {
  95. closed = true;
  96. break;
  97. }
  98. }
  99. }
  100. if !closed {
  101. return Err(SqlParseError::UnterminatedStringLiteral);
  102. }
  103. }
  104. '-' => {
  105. current.push(
  106. chars
  107. .next()
  108. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  109. );
  110. if chars.peek() == Some(&'-') {
  111. while let Some(&next) = chars.peek() {
  112. current.push(
  113. chars
  114. .next()
  115. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  116. );
  117. if next == '\n' {
  118. break;
  119. }
  120. }
  121. }
  122. }
  123. ':' => {
  124. // Flush current raw SQL
  125. if !current.is_empty() {
  126. parts.push(SqlPart::Raw(current.clone().into()));
  127. current.clear();
  128. }
  129. chars.next(); // consume ':'
  130. let mut name = String::new();
  131. while let Some(&next) = chars.peek() {
  132. if next.is_alphanumeric() || next == '_' {
  133. name.push(
  134. chars
  135. .next()
  136. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  137. );
  138. } else {
  139. break;
  140. }
  141. }
  142. if name.is_empty() {
  143. return Err(SqlParseError::InvalidPlaceholder);
  144. }
  145. parts.push(SqlPart::Placeholder(name.into(), None));
  146. }
  147. _ => {
  148. current.push(
  149. chars
  150. .next()
  151. .ok_or(SqlParseError::UnterminatedStringLiteral)?,
  152. );
  153. }
  154. }
  155. }
  156. if !current.is_empty() {
  157. parts.push(SqlPart::Raw(current.into()));
  158. }
  159. Ok(parts)
  160. }
/// Statement cache: maps the original SQL text to its parsed parts and, once it has been
/// rendered, the rendered SQL text.
type Cache = HashMap<String, (Vec<SqlPart>, Option<Arc<str>>)>;
/// A parsed SQL statement with named placeholders that can be bound to values and rendered
/// with [`Statement::to_sql`].
#[derive(Debug, Default)]
pub struct Statement {
    // Shared cache of parsed (and possibly rendered) statements, keyed by the original SQL text.
    cache: Arc<RwLock<Cache>>,
    // Rendered SQL taken from the cache when the statement was created, if there was any.
    cached_sql: Option<Arc<str>>,
    // Original SQL text. Only set when the statement was parsed on creation (cache miss), so
    // `to_sql` knows which cache entry to update with the rendered SQL.
    sql: Option<String>,
    /// The SQL statement
    pub parts: Vec<SqlPart>,
    /// The expected response type
    pub expected_response: ExpectedSqlResponse,
}
  173. impl Statement {
  174. /// Creates a new statement
  175. fn new(sql: &str, cache: Arc<RwLock<Cache>>) -> Result<Self, SqlParseError> {
  176. let parsed = cache
  177. .read()
  178. .map(|cache| cache.get(sql).cloned())
  179. .ok()
  180. .flatten();
  181. if let Some((parts, cached_sql)) = parsed {
  182. Ok(Self {
  183. parts,
  184. cached_sql,
  185. sql: None,
  186. cache,
  187. ..Default::default()
  188. })
  189. } else {
  190. let parts = split_sql_parts(sql)?;
  191. if let Ok(mut cache) = cache.write() {
  192. cache.insert(sql.to_owned(), (parts.clone(), None));
  193. } else {
  194. tracing::warn!("Failed to acquire write lock for SQL statement cache");
  195. }
  196. Ok(Self {
  197. parts,
  198. sql: Some(sql.to_owned()),
  199. cache,
  200. ..Default::default()
  201. })
  202. }
  203. }
  204. /// Convert Statement into a SQL statement and the list of placeholders
  205. ///
  206. /// By default it converts the statement into placeholder using $1..$n placeholders which seems
  207. /// to be more widely supported, although it can be reimplemented with other formats since part
  208. /// is public
  209. pub fn to_sql(self) -> Result<(String, Vec<Value>), Error> {
  210. if let Some(cached_sql) = self.cached_sql {
  211. let sql = cached_sql.to_string();
  212. let values = self
  213. .parts
  214. .into_iter()
  215. .map(|x| match x {
  216. SqlPart::Placeholder(name, value) => {
  217. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  218. PlaceholderValue::Value(value) => Ok(vec![value]),
  219. PlaceholderValue::Set(values) => Ok(values),
  220. }
  221. }
  222. SqlPart::Raw(_) => Ok(vec![]),
  223. })
  224. .collect::<Result<Vec<_>, Error>>()?
  225. .into_iter()
  226. .flatten()
  227. .collect::<Vec<_>>();
  228. return Ok((sql, values));
  229. }
  230. let mut placeholder_values = Vec::new();
  231. let mut can_be_cached = true;
  232. let sql = self
  233. .parts
  234. .into_iter()
  235. .map(|x| match x {
  236. SqlPart::Placeholder(name, value) => {
  237. match value.ok_or(Error::MissingPlaceholder(name.to_string()))? {
  238. PlaceholderValue::Value(value) => {
  239. placeholder_values.push(value);
  240. Ok::<_, Error>(format!("${}", placeholder_values.len()))
  241. }
  242. PlaceholderValue::Set(mut values) => {
  243. can_be_cached = false;
  244. let start_size = placeholder_values.len();
  245. placeholder_values.append(&mut values);
  246. let placeholders = (start_size + 1..=placeholder_values.len())
  247. .map(|i| format!("${i}"))
  248. .collect::<Vec<_>>()
  249. .join(", ");
  250. Ok(placeholders)
  251. }
  252. }
  253. }
  254. SqlPart::Raw(raw) => Ok(raw.trim().to_string()),
  255. })
  256. .collect::<Result<Vec<String>, _>>()?
  257. .join(" ");
  258. if can_be_cached {
  259. if let Some(original_sql) = self.sql {
  260. let _ = self.cache.write().map(|mut cache| {
  261. if let Some((_, cached_sql)) = cache.get_mut(&original_sql) {
  262. *cached_sql = Some(sql.clone().into());
  263. }
  264. });
  265. }
  266. }
  267. Ok((sql, placeholder_values))
  268. }
  269. /// Binds a given placeholder to a value.
  270. #[inline]
  271. pub fn bind<C, V>(mut self, name: C, value: V) -> Self
  272. where
  273. C: ToString,
  274. V: Into<Value>,
  275. {
  276. let name = name.to_string();
  277. let value = value.into();
  278. let value: PlaceholderValue = value.into();
  279. for part in self.parts.iter_mut() {
  280. if let SqlPart::Placeholder(part_name, part_value) = part {
  281. if **part_name == *name.as_str() {
  282. *part_value = Some(value.clone());
  283. }
  284. }
  285. }
  286. self
  287. }
  288. /// Binds a single variable with a vector.
  289. ///
  290. /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
  291. /// :foo2` and binds each value from the value vector accordingly.
  292. #[inline]
  293. pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
  294. where
  295. C: ToString,
  296. V: Into<Value>,
  297. {
  298. let name = name.to_string();
  299. let value: PlaceholderValue = value
  300. .into_iter()
  301. .map(|x| x.into())
  302. .collect::<Vec<Value>>()
  303. .into();
  304. for part in self.parts.iter_mut() {
  305. if let SqlPart::Placeholder(part_name, part_value) = part {
  306. if **part_name == *name.as_str() {
  307. *part_value = Some(value.clone());
  308. }
  309. }
  310. }
  311. self
  312. }
    /// Runs the statement through [`DatabaseExecutor::pluck`] and returns a single optional
    /// value.
    ///
    /// NOTE(review): presumably the first column of the first row, as described by
    /// [`ExpectedSqlResponse::Pluck`]; confirm against the executor implementation.
    pub async fn pluck<C>(self, conn: &C) -> Result<Option<Value>, Error>
    where
        C: DatabaseExecutor,
    {
        conn.pluck(self).await
    }
    /// Runs the statement through [`DatabaseExecutor::batch`]; no rows or counts are returned.
    pub async fn batch<C>(self, conn: &C) -> Result<(), Error>
    where
        C: DatabaseExecutor,
    {
        conn.batch(self).await
    }
    /// Runs the statement through [`DatabaseExecutor::execute`] and returns the `usize` it
    /// reports.
    ///
    /// NOTE(review): presumably the number of affected rows, as described by
    /// [`ExpectedSqlResponse::AffectedRows`]; confirm against the executor implementation.
    pub async fn execute<C>(self, conn: &C) -> Result<usize, Error>
    where
        C: DatabaseExecutor,
    {
        conn.execute(self).await
    }
    /// Runs the query through [`DatabaseExecutor::fetch_one`] and returns a single row (as a
    /// list of columns) or `None`.
    pub async fn fetch_one<C>(self, conn: &C) -> Result<Option<Vec<Column>>, Error>
    where
        C: DatabaseExecutor,
    {
        conn.fetch_one(self).await
    }
    /// Runs the query through [`DatabaseExecutor::fetch_all`] and returns a list of rows, each
    /// row being a list of columns.
    pub async fn fetch_all<C>(self, conn: &C) -> Result<Vec<Vec<Column>>, Error>
    where
        C: DatabaseExecutor,
    {
        conn.fetch_all(self).await
    }
  348. }
  349. /// Creates a new query statement
  350. #[inline(always)]
  351. pub fn query(sql: &str) -> Result<Statement, Error> {
  352. static CACHE: Lazy<Arc<RwLock<Cache>>> = Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
  353. Statement::new(sql, CACHE.clone()).map_err(|e| Error::Database(Box::new(e)))
  354. }