stmt.rs 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. use std::collections::HashMap;
  2. use rusqlite::{self, CachedStatement};
  3. use crate::common::SqliteConnectionManager;
  4. use crate::pool::PooledResource;
  /// A single SQLite value; alias of `rusqlite::types::Value`
  pub type Value = rusqlite::types::Value;
  /// A single column of a result row; alias of [`Value`]
  pub type Column = Value;
  /// Expected response type for a given SQL statement
  ///
  /// The default is [`ExpectedSqlResponse::ManyRows`]. Variants are plain tags, so the type is
  /// `Copy` and can be compared and hashed.
  #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
  pub enum ExpectedSqlResponse {
      /// A single row
      SingleRow,
      /// All the rows that match a query
      #[default]
      ManyRows,
      /// How many rows were affected by the query
      AffectedRows,
      /// Return the first column of the first row
      Pluck,
  }
  22. /// Sql message
  23. #[derive(Default, Debug)]
  24. pub struct Statement {
  25. /// The SQL statement
  26. pub sql: String,
  27. /// The list of arguments for the placeholders. It only supports named arguments for simplicity
  28. /// sake
  29. pub args: HashMap<String, Value>,
  30. /// The expected response type
  31. pub expected_response: ExpectedSqlResponse,
  32. }
  33. impl Statement {
  34. /// Creates a new statement
  35. pub fn new<T>(sql: T) -> Self
  36. where
  37. T: ToString,
  38. {
  39. Self {
  40. sql: sql.to_string(),
  41. ..Default::default()
  42. }
  43. }
  44. /// Binds a given placeholder to a value.
  45. #[inline]
  46. pub fn bind<C, V>(mut self, name: C, value: V) -> Self
  47. where
  48. C: ToString,
  49. V: Into<Value>,
  50. {
  51. self.args.insert(name.to_string(), value.into());
  52. self
  53. }
  54. /// Binds a single variable with a vector.
  55. ///
  56. /// This will rewrite the function from `:foo` (where value is vec![1, 2, 3]) to `:foo0, :foo1,
  57. /// :foo2` and binds each value from the value vector accordingly.
  58. #[inline]
  59. pub fn bind_vec<C, V>(mut self, name: C, value: Vec<V>) -> Self
  60. where
  61. C: ToString,
  62. V: Into<Value>,
  63. {
  64. let mut new_sql = String::with_capacity(self.sql.len());
  65. let target = name.to_string();
  66. let mut i = 0;
  67. let placeholders = value
  68. .into_iter()
  69. .enumerate()
  70. .map(|(key, value)| {
  71. let key = format!("{target}{key}");
  72. self.args.insert(key.clone(), value.into());
  73. key
  74. })
  75. .collect::<Vec<_>>()
  76. .join(",");
  77. while let Some(pos) = self.sql[i..].find(&target) {
  78. let abs_pos = i + pos;
  79. let after = abs_pos + target.len();
  80. let is_word_boundary = self.sql[after..]
  81. .chars()
  82. .next()
  83. .map_or(true, |c| !c.is_alphanumeric() && c != '_');
  84. if is_word_boundary {
  85. new_sql.push_str(&self.sql[i..abs_pos]);
  86. new_sql.push_str(&placeholders);
  87. i = after;
  88. } else {
  89. new_sql.push_str(&self.sql[i..=abs_pos]);
  90. i = abs_pos + 1;
  91. }
  92. }
  93. new_sql.push_str(&self.sql[i..]);
  94. self.sql = new_sql;
  95. self
  96. }
  97. fn get_stmt(
  98. self,
  99. conn: &PooledResource<SqliteConnectionManager>,
  100. ) -> rusqlite::Result<CachedStatement<'_>> {
  101. let mut stmt = conn.prepare_cached(&self.sql)?;
  102. for (name, value) in self.args {
  103. let index = stmt
  104. .parameter_index(&name)
  105. .map_err(|_| rusqlite::Error::InvalidColumnName(name.clone()))?
  106. .ok_or(rusqlite::Error::InvalidColumnName(name))?;
  107. stmt.raw_bind_parameter(index, value)?;
  108. }
  109. Ok(stmt)
  110. }
  111. /// Executes a query and returns the affected rows
  112. pub fn plunk(
  113. self,
  114. conn: &PooledResource<SqliteConnectionManager>,
  115. ) -> rusqlite::Result<Option<Value>> {
  116. let mut stmt = self.get_stmt(conn)?;
  117. let mut rows = stmt.raw_query();
  118. rows.next()?.map(|row| row.get(0)).transpose()
  119. }
  120. /// Executes a query and returns the affected rows
  121. pub fn execute(
  122. self,
  123. conn: &PooledResource<SqliteConnectionManager>,
  124. ) -> rusqlite::Result<usize> {
  125. self.get_stmt(conn)?.raw_execute()
  126. }
  127. /// Runs the query and returns the first row or None
  128. pub fn fetch_one(
  129. self,
  130. conn: &PooledResource<SqliteConnectionManager>,
  131. ) -> rusqlite::Result<Option<Vec<Column>>> {
  132. let mut stmt = self.get_stmt(conn)?;
  133. let columns = stmt.column_count();
  134. let mut rows = stmt.raw_query();
  135. rows.next()?
  136. .map(|row| {
  137. (0..columns)
  138. .map(|i| row.get(i))
  139. .collect::<Result<Vec<_>, _>>()
  140. })
  141. .transpose()
  142. }
  143. /// Runs the query and returns the first row or None
  144. pub fn fetch_all(
  145. self,
  146. conn: &PooledResource<SqliteConnectionManager>,
  147. ) -> rusqlite::Result<Vec<Vec<Column>>> {
  148. let mut stmt = self.get_stmt(conn)?;
  149. let columns = stmt.column_count();
  150. let mut rows = stmt.raw_query();
  151. let mut results = vec![];
  152. while let Some(row) = rows.next()? {
  153. results.push(
  154. (0..columns)
  155. .map(|i| row.get(i))
  156. .collect::<Result<Vec<_>, _>>()?,
  157. );
  158. }
  159. Ok(results)
  160. }
  161. }