mod.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566
  1. //! # In-Memory database
  2. //!
  3. //! This database module is the core of the miniredis project. All other modules are built
  4. //! around this database module.
  5. mod entry;
  6. mod expiration;
  7. use crate::{error::Error, value::Value};
  8. use bytes::Bytes;
  9. use entry::{new_version, Entry};
  10. use expiration::ExpirationDb;
  11. use log::trace;
  12. use parking_lot::{Mutex, RwLock};
  13. use seahash::hash;
  14. use std::{
  15. collections::HashMap,
  16. convert::{TryFrom, TryInto},
  17. ops::AddAssign,
  18. sync::Arc,
  19. thread,
  20. };
  21. use tokio::time::{Duration, Instant};
  22. /// Database structure
  23. ///
  24. /// Each connection has its own clone of the database and the conn_id is stored in each instance.
  25. /// The entries property is shared by all connections.
  26. ///
  27. /// To avoid lock contention this database is *not* a single HashMap, instead it is a vector of
  28. /// HashMaps. Each key is presharded and a bucket is selected. By doing this pre-step instead of
  29. /// locking the entire database, only a small portion is locked (shared or exclusively) at a time,
  30. /// making this database implementation thread-friendly. The number of slots available cannot be
  31. /// changed at runtime.
  32. ///
  33. /// The database is also aware of other connections locking other keys exclusively (for
  34. /// transactions).
  35. ///
  36. /// Each entry is wrapped with an entry::Entry struct, which is aware of expirations and data
  37. /// versioning (in practice the nanosecond of last modification).
  38. #[derive(Debug)]
  39. pub struct Db {
  40. /// A vector of hashmaps.
  41. ///
  42. /// Instead of having a single HashMap, and having all threads fighting to
  43. /// lock that single HashMap, we have a vector of N HashMaps
  44. /// (configurable), which in theory allows faster reads and writes.
  45. ///
  46. /// Because all operations are always key specific, the key is hashed to
  47. /// select the HashMap in which the data is stored.
  48. entries: Arc<Vec<RwLock<HashMap<Bytes, Entry>>>>,
  49. /// Data structure to store all expiring keys
  50. expirations: Arc<Mutex<ExpirationDb>>,
  51. /// Number of HashMaps that are available.
  52. slots: usize,
  53. /// A Database is attached to a conn_id. The entries and expiration data
  54. /// structures are shared between all connections.
  55. ///
  56. /// This particular database instance is attached to a conn_id, used to block
  57. /// all keys in case of a transaction.
  58. conn_id: u128,
  59. /// HashMap of all keys blocked by connections, mapped to the conn_id of the
  60. /// owner. If a key appears in here and it is not being held by the current
  61. /// connection, the current connection must wait.
  62. tx_key_locks: Arc<RwLock<HashMap<Bytes, u128>>>,
  63. }
  64. impl Db {
  65. /// Creates a new database instance
  66. pub fn new(slots: usize) -> Self {
  67. let mut entries = vec![];
  68. for _i in 0..slots {
  69. entries.push(RwLock::new(HashMap::new()));
  70. }
  71. Self {
  72. entries: Arc::new(entries),
  73. expirations: Arc::new(Mutex::new(ExpirationDb::new())),
  74. conn_id: 0,
  75. tx_key_locks: Arc::new(RwLock::new(HashMap::new())),
  76. slots,
  77. }
  78. }
  79. /// Creates a new Database instance bound to a connection.
  80. ///
  81. /// This is particular useful when locking keys exclusively.
  82. ///
  83. /// All the internal data are shjared through an Arc.
  84. pub fn new_db_instance(self: Arc<Db>, conn_id: u128) -> Db {
  85. Self {
  86. entries: self.entries.clone(),
  87. tx_key_locks: self.tx_key_locks.clone(),
  88. expirations: self.expirations.clone(),
  89. conn_id,
  90. slots: self.slots,
  91. }
  92. }
  93. #[inline]
  94. /// Returns a slot where a key may be hosted.
  95. ///
  96. /// In order to avoid too much locks, instead of having a single hash a
  97. /// database instance is a set of hashes. Each key is pre-shared with a
  98. /// quick hashing algorithm to select a 'slot' or HashMap where it may be
  99. /// hosted.
  100. fn get_slot(&self, key: &Bytes) -> usize {
  101. let id = (hash(key) as usize) % self.entries.len();
  102. trace!("selected slot {} for key {:?}", id, key);
  103. let waiting = Duration::from_nanos(100);
  104. while let Some(blocker) = self.tx_key_locks.read().get(key) {
  105. // Loop while the key we are trying to access is being blocked by a
  106. // connection in a transaction
  107. if *blocker == self.conn_id {
  108. // the key is being blocked by ourself, it is safe to break the
  109. // waiting loop
  110. break;
  111. }
  112. thread::sleep(waiting);
  113. }
  114. id
  115. }
  116. /// Locks keys exclusively
  117. ///
  118. /// The locked keys are only accesible (read or write) by the connection
  119. /// that locked them, any other connection must wait until the locking
  120. /// connection releases them.
  121. ///
  122. /// This is used to simulate redis transactions. Transaction in Redis are
  123. /// atomic but pausing a multi threaded Redis just to keep the same promises
  124. /// was a bit extreme, that's the reason why a transaction will lock
  125. /// exclusively all keys involved.
  126. pub fn lock_keys(&self, keys: &[Bytes]) {
  127. let waiting = Duration::from_nanos(100);
  128. loop {
  129. let mut lock = self.tx_key_locks.write();
  130. let mut i = 0;
  131. for key in keys.iter() {
  132. if let Some(blocker) = lock.get(key) {
  133. if *blocker == self.conn_id {
  134. // It is blocked by us already.
  135. continue;
  136. }
  137. // It is blocked by another tx, we need to break
  138. // and retry to gain the lock over this key
  139. break;
  140. }
  141. lock.insert(key.clone(), self.conn_id);
  142. i += 1;
  143. }
  144. if i == keys.len() {
  145. // All the involved keys are successfully being blocked
  146. // exclusively.
  147. break;
  148. }
  149. // We need to sleep a bit and retry.
  150. drop(lock);
  151. thread::sleep(waiting);
  152. }
  153. }
  154. /// Releases the lock on keys
  155. pub fn unlock_keys(&self, keys: &[Bytes]) {
  156. let mut lock = self.tx_key_locks.write();
  157. for key in keys.iter() {
  158. lock.remove(key);
  159. }
  160. }
  161. /// Increments a key's value by a given number
  162. ///
  163. /// If the stored value cannot be converted into a number an error will be
  164. /// thrown.
  165. pub fn incr<
  166. T: ToString + AddAssign + for<'a> TryFrom<&'a Value, Error = Error> + Into<Value> + Copy,
  167. >(
  168. &self,
  169. key: &Bytes,
  170. incr_by: T,
  171. ) -> Result<Value, Error> {
  172. let mut entries = self.entries[self.get_slot(key)].write();
  173. match entries.get_mut(key) {
  174. Some(x) => {
  175. let value = x.get();
  176. let mut number: T = value.try_into()?;
  177. number += incr_by;
  178. x.change_value(number.to_string().as_str().into());
  179. Ok(number.into())
  180. }
  181. None => {
  182. entries.insert(
  183. key.clone(),
  184. Entry::new(incr_by.to_string().as_str().into(), None),
  185. );
  186. Ok((incr_by as T).into())
  187. }
  188. }
  189. }
  190. /// Removes any expiration associated with a given key
  191. pub fn persist(&self, key: &Bytes) -> Value {
  192. let mut entries = self.entries[self.get_slot(key)].write();
  193. entries
  194. .get_mut(key)
  195. .filter(|x| x.is_valid())
  196. .map_or(0.into(), |x| {
  197. if x.has_ttl() {
  198. self.expirations.lock().remove(key);
  199. x.persist();
  200. 1.into()
  201. } else {
  202. 0.into()
  203. }
  204. })
  205. }
  206. /// Set time to live for a given key
  207. pub fn set_ttl(&self, key: &Bytes, expires_in: Duration) -> Value {
  208. let mut entries = self.entries[self.get_slot(key)].write();
  209. let expires_at = Instant::now() + expires_in;
  210. entries
  211. .get_mut(key)
  212. .filter(|x| x.is_valid())
  213. .map_or(0.into(), |x| {
  214. self.expirations.lock().add(key, expires_at);
  215. x.set_ttl(expires_at);
  216. 1.into()
  217. })
  218. }
  219. /// Removes keys from the database
  220. pub fn del(&self, keys: &[Bytes]) -> Value {
  221. let mut expirations = self.expirations.lock();
  222. keys.iter()
  223. .filter_map(|key| {
  224. expirations.remove(key);
  225. self.entries[self.get_slot(key)].write().remove(key)
  226. })
  227. .filter(|key| key.is_valid())
  228. .count()
  229. .into()
  230. }
  231. /// Check if keys exists in the database
  232. pub fn exists(&self, keys: &[Bytes]) -> Value {
  233. let mut matches = 0;
  234. keys.iter()
  235. .map(|key| {
  236. let entries = self.entries[self.get_slot(key)].read();
  237. if entries.get(key).is_some() {
  238. matches += 1;
  239. }
  240. })
  241. .for_each(drop);
  242. matches.into()
  243. }
  244. /// get_map_or
  245. ///
  246. /// Instead of returning an entry of the database, to avoid clonning, this function will
  247. /// execute a callback function with the entry as a parameter. If no record is found another
  248. /// callback function is going to be executed, dropping the lock before doing so.
  249. ///
  250. /// If an entry is found, the lock is not dropped before doing the callback. Avoid inserting
  251. /// new entries. In this case the value is passed by reference, so it is possible to modify the
  252. /// entry itself.
  253. ///
  254. /// This function is useful to read non-scalar values from the database. Non-scalar values are
  255. /// forbidden to clone, attempting cloning will endup in an error (Error::WrongType)
  256. pub fn get_map_or<F1, F2>(&self, key: &Bytes, found: F1, not_found: F2) -> Result<Value, Error>
  257. where
  258. F1: FnOnce(&Value) -> Result<Value, Error>,
  259. F2: FnOnce() -> Result<Value, Error>,
  260. {
  261. let entries = self.entries[self.get_slot(key)].read();
  262. let entry = entries.get(key).filter(|x| x.is_valid()).map(|e| e.get());
  263. if let Some(entry) = entry {
  264. found(entry)
  265. } else {
  266. // drop lock
  267. drop(entries);
  268. not_found()
  269. }
  270. }
  271. /// Updates the entry version of a given key
  272. pub fn bump_version(&self, key: &Bytes) -> bool {
  273. let mut entries = self.entries[self.get_slot(key)].write();
  274. entries
  275. .get_mut(key)
  276. .filter(|x| x.is_valid())
  277. .map(|entry| {
  278. entry.bump_version();
  279. })
  280. .is_some()
  281. }
  282. /// Returns the version of a given key
  283. pub fn get_version(&self, key: &Bytes) -> u128 {
  284. let entries = self.entries[self.get_slot(key)].read();
  285. entries
  286. .get(key)
  287. .filter(|x| x.is_valid())
  288. .map(|entry| entry.version())
  289. .unwrap_or_else(new_version)
  290. }
  291. /// Get a copy of an entry
  292. pub fn get(&self, key: &Bytes) -> Value {
  293. let entries = self.entries[self.get_slot(key)].read();
  294. entries
  295. .get(key)
  296. .filter(|x| x.is_valid())
  297. .map_or(Value::Null, |x| x.clone_value())
  298. }
  299. /// Get multiple copies of entries
  300. pub fn get_multi(&self, keys: &[Bytes]) -> Value {
  301. keys.iter()
  302. .map(|key| {
  303. let entries = self.entries[self.get_slot(key)].read();
  304. entries
  305. .get(key)
  306. .filter(|x| x.is_valid() && x.is_clonable())
  307. .map_or(Value::Null, |x| x.clone_value())
  308. })
  309. .collect::<Vec<Value>>()
  310. .into()
  311. }
  312. /// Get a key or set a new value for the given key.
  313. pub fn getset(&self, key: &Bytes, value: &Value) -> Value {
  314. let mut entries = self.entries[self.get_slot(key)].write();
  315. self.expirations.lock().remove(key);
  316. entries
  317. .insert(key.clone(), Entry::new(value.clone(), None))
  318. .filter(|x| x.is_valid())
  319. .map_or(Value::Null, |x| x.clone_value())
  320. }
  321. /// Takes an entry from the database.
  322. pub fn getdel(&self, key: &Bytes) -> Value {
  323. let mut entries = self.entries[self.get_slot(key)].write();
  324. entries.remove(key).map_or(Value::Null, |x| {
  325. self.expirations.lock().remove(key);
  326. x.clone_value()
  327. })
  328. }
  329. /// Set a key, value with an optional expiration time
  330. pub fn set(&self, key: &Bytes, value: Value, expires_in: Option<Duration>) -> Value {
  331. let mut entries = self.entries[self.get_slot(key)].write();
  332. let expires_at = expires_in.map(|duration| Instant::now() + duration);
  333. if let Some(expires_at) = expires_at {
  334. self.expirations.lock().add(key, expires_at);
  335. }
  336. entries.insert(key.clone(), Entry::new(value, expires_at));
  337. Value::Ok
  338. }
  339. /// Returns the TTL of a given key
  340. pub fn ttl(&self, key: &Bytes) -> Option<Option<Instant>> {
  341. let entries = self.entries[self.get_slot(key)].read();
  342. entries
  343. .get(key)
  344. .filter(|x| x.is_valid())
  345. .map(|x| x.get_ttl())
  346. }
  347. /// Check whether a given key is in the list of keys to be purged or not.
  348. /// This function is mainly used for unit testing
  349. pub fn is_key_in_expiration_list(&self, key: &Bytes) -> bool {
  350. self.expirations.lock().has(key)
  351. }
  352. /// Remove expired entries from the database.
  353. ///
  354. /// This function should be called from a background thread every few seconds. Calling it more
  355. /// often is a waste of resources.
  356. ///
  357. /// Expired keys are automatically hidden by the database, this process is just claiming back
  358. /// the memory from those expired keys.
  359. pub fn purge(&self) -> u64 {
  360. let mut expirations = self.expirations.lock();
  361. let mut removed = 0;
  362. trace!("Watching {} keys for expirations", expirations.len());
  363. let keys = expirations.get_expired_keys(None);
  364. drop(expirations);
  365. keys.iter()
  366. .map(|key| {
  367. let mut entries = self.entries[self.get_slot(key)].write();
  368. if entries.remove(key).is_some() {
  369. trace!("Removed key {:?} due timeout", key);
  370. removed += 1;
  371. }
  372. })
  373. .for_each(drop);
  374. removed
  375. }
  376. }
  377. #[cfg(test)]
  378. mod test {
  379. use super::*;
  380. use crate::bytes;
  381. #[test]
  382. fn incr_wrong_type() {
  383. let db = Db::new(100);
  384. db.set(&bytes!(b"num"), Value::Blob(bytes!("some string")), None);
  385. let r = db.incr(&bytes!("num"), 1);
  386. assert!(r.is_err());
  387. assert_eq!(Error::NotANumber, r.expect_err("should fail"));
  388. assert_eq!(Value::Blob(bytes!("some string")), db.get(&bytes!("num")));
  389. }
  390. #[test]
  391. fn incr_blob_float() {
  392. let db = Db::new(100);
  393. db.set(&bytes!(b"num"), Value::Blob(bytes!("1.1")), None);
  394. assert_eq!(Ok(Value::Float(2.2)), db.incr(&bytes!("num"), 1.1));
  395. assert_eq!(Value::Blob(bytes!("2.2")), db.get(&bytes!("num")));
  396. }
  397. #[test]
  398. fn incr_blob_int_float() {
  399. let db = Db::new(100);
  400. db.set(&bytes!(b"num"), Value::Blob(bytes!("1")), None);
  401. assert_eq!(Ok(Value::Float(2.1)), db.incr(&bytes!("num"), 1.1));
  402. assert_eq!(Value::Blob(bytes!("2.1")), db.get(&bytes!("num")));
  403. }
  404. #[test]
  405. fn incr_blob_int() {
  406. let db = Db::new(100);
  407. db.set(&bytes!(b"num"), Value::Blob(bytes!("1")), None);
  408. assert_eq!(Ok(Value::Integer(2)), db.incr(&bytes!("num"), 1));
  409. assert_eq!(Value::Blob(bytes!("2")), db.get(&bytes!("num")));
  410. }
  411. #[test]
  412. fn incr_blob_int_set() {
  413. let db = Db::new(100);
  414. assert_eq!(Ok(Value::Integer(1)), db.incr(&bytes!("num"), 1));
  415. assert_eq!(Value::Blob(bytes!("1")), db.get(&bytes!("num")));
  416. }
  417. #[test]
  418. fn incr_blob_float_set() {
  419. let db = Db::new(100);
  420. assert_eq!(Ok(Value::Float(1.1)), db.incr(&bytes!("num"), 1.1));
  421. assert_eq!(Value::Blob(bytes!("1.1")), db.get(&bytes!("num")));
  422. }
  423. #[test]
  424. fn del() {
  425. let db = Db::new(100);
  426. db.set(&bytes!(b"expired"), Value::Ok, Some(Duration::from_secs(0)));
  427. db.set(&bytes!(b"valid"), Value::Ok, None);
  428. db.set(
  429. &bytes!(b"expiring"),
  430. Value::Ok,
  431. Some(Duration::from_secs(5)),
  432. );
  433. assert_eq!(
  434. Value::Integer(2),
  435. db.del(&[
  436. bytes!(b"expired"),
  437. bytes!(b"valid"),
  438. bytes!(b"expiring"),
  439. bytes!(b"not_existing_key")
  440. ])
  441. );
  442. }
  443. #[test]
  444. fn ttl() {
  445. let db = Db::new(100);
  446. db.set(&bytes!(b"expired"), Value::Ok, Some(Duration::from_secs(0)));
  447. db.set(&bytes!(b"valid"), Value::Ok, None);
  448. db.set(
  449. &bytes!(b"expiring"),
  450. Value::Ok,
  451. Some(Duration::from_secs(5)),
  452. );
  453. assert_eq!(None, db.ttl(&bytes!(b"expired")));
  454. assert_eq!(None, db.ttl(&bytes!(b"not_existing_key")));
  455. assert_eq!(Some(None), db.ttl(&bytes!(b"valid")));
  456. assert!(match db.ttl(&bytes!(b"expiring")) {
  457. Some(Some(_)) => true,
  458. _ => false,
  459. });
  460. }
  461. #[test]
  462. fn persist_bug() {
  463. let db = Db::new(100);
  464. db.set(&bytes!(b"one"), Value::Ok, Some(Duration::from_secs(1)));
  465. assert_eq!(Value::Ok, db.get(&bytes!(b"one")));
  466. assert!(db.is_key_in_expiration_list(&bytes!(b"one")));
  467. db.persist(&bytes!(b"one"));
  468. assert!(!db.is_key_in_expiration_list(&bytes!(b"one")));
  469. }
  470. #[test]
  471. fn purge_keys() {
  472. let db = Db::new(100);
  473. db.set(&bytes!(b"one"), Value::Ok, Some(Duration::from_secs(0)));
  474. // Expired keys should not be returned, even if they are not yet
  475. // removed by the purge process.
  476. assert_eq!(Value::Null, db.get(&bytes!(b"one")));
  477. // Purge twice
  478. assert_eq!(1, db.purge());
  479. assert_eq!(0, db.purge());
  480. assert_eq!(Value::Null, db.get(&bytes!(b"one")));
  481. }
  482. #[test]
  483. fn replace_purge_keys() {
  484. let db = Db::new(100);
  485. db.set(&bytes!(b"one"), Value::Ok, Some(Duration::from_secs(0)));
  486. // Expired keys should not be returned, even if they are not yet
  487. // removed by the purge process.
  488. assert_eq!(Value::Null, db.get(&bytes!(b"one")));
  489. db.set(&bytes!(b"one"), Value::Ok, Some(Duration::from_secs(5)));
  490. assert_eq!(Value::Ok, db.get(&bytes!(b"one")));
  491. // Purge should return 0 as the expired key has been removed already
  492. assert_eq!(0, db.purge());
  493. }
  494. }