expiration.rs 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. use bytes::Bytes;
  2. use std::collections::{BTreeMap, HashMap};
  3. use tokio::time::Instant;
  /// Unique, totally ordered key for one entry of the expiration B-Tree.
  ///
  /// The tuple holds the expiration instant followed by a counter value.
  /// Several keys may expire at exactly the same instant; the counter keeps
  /// their IDs distinct. Because the derived ordering compares the tuple
  /// fields left to right, IDs sort first by expiration time and then by
  /// counter value.
  #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
  pub struct ExpirationId(pub (Instant, u64));
  14. #[derive(Debug)]
  15. pub struct ExpirationDb {
  16. /// B-Tree Map of expiring keys
  17. expiring_keys: BTreeMap<ExpirationId, Bytes>,
  18. /// Hash which contains the keys and their ExpirationId.
  19. keys: HashMap<Bytes, ExpirationId>,
  20. next_id: u64,
  21. }
  22. impl ExpirationDb {
  23. pub fn new() -> Self {
  24. Self {
  25. expiring_keys: BTreeMap::new(),
  26. keys: HashMap::new(),
  27. next_id: 0,
  28. }
  29. }
  30. pub fn add(&mut self, key: &Bytes, expires_at: Instant) {
  31. let entry_id = ExpirationId((expires_at, self.next_id));
  32. if let Some(prev) = self.keys.remove(key) {
  33. // Another key with expiration is already known, it has
  34. // to be removed before adding a new one
  35. self.expiring_keys.remove(&prev);
  36. }
  37. self.expiring_keys.insert(entry_id, key.clone());
  38. self.keys.insert(key.clone(), entry_id);
  39. self.next_id += 1;
  40. }
  41. pub fn has(&self, key: &Bytes) -> bool {
  42. self.keys.get(key).is_some()
  43. }
  44. pub fn remove(&mut self, key: &Bytes) -> bool {
  45. if let Some(prev) = self.keys.remove(key) {
  46. // Another key with expiration is already known, it has
  47. // to be removed before adding a new one
  48. self.expiring_keys.remove(&prev);
  49. true
  50. } else {
  51. false
  52. }
  53. }
  54. pub fn len(&self) -> usize {
  55. self.expiring_keys.len()
  56. }
  57. /// Returns a list of expired keys, these keys are removed from the internal
  58. /// data structure which is keeping track of expiring keys.
  59. pub fn get_expired_keys(&mut self, now: Option<Instant>) -> Vec<Bytes> {
  60. let now = now.unwrap_or_else(Instant::now);
  61. let mut expiring_keys = vec![];
  62. for (key, value) in self.expiring_keys.iter_mut() {
  63. if key.0 .0 > now {
  64. break;
  65. }
  66. expiring_keys.push((*key, value.clone()));
  67. self.keys.remove(value);
  68. }
  69. expiring_keys
  70. .iter()
  71. .map(|(k, v)| {
  72. self.expiring_keys.remove(k);
  73. v.to_owned()
  74. })
  75. .collect()
  76. }
  77. }
  #[cfg(test)]
  mod test {
      use super::*;
      use crate::bytes;
      use tokio::time::{Duration, Instant};

      #[test]
      fn two_entries_same_expiration() {
          let mut db = ExpirationDb::new();
          let key1 = bytes!(b"key");
          let key2 = bytes!(b"bar");
          let key3 = bytes!(b"xxx");
          let expiration = Instant::now() + Duration::from_secs(5);
          db.add(&key1, expiration);
          db.add(&key2, expiration);
          db.add(&key3, expiration);
          assert_eq!(3, db.len());
      }

      #[test]
      fn remove_prev_expiration() {
          let mut db = ExpirationDb::new();
          let key1 = bytes!(b"key");
          let key2 = bytes!(b"bar");
          let expiration = Instant::now() + Duration::from_secs(5);
          db.add(&key1, expiration);
          db.add(&key2, expiration);
          db.add(&key1, expiration);
          assert_eq!(2, db.len());
      }

      #[test]
      fn re_adding_a_key_replaces_its_expiration() {
          let mut db = ExpirationDb::new();
          let key = bytes!(b"key");
          let now = Instant::now();
          db.add(&key, now + Duration::from_secs(1));
          db.add(&key, now + Duration::from_secs(10));
          assert_eq!(1, db.len());
          // The superseded one-second deadline must no longer fire.
          assert!(db
              .get_expired_keys(Some(now + Duration::from_secs(2)))
              .is_empty());
          assert!(db.has(&key));
          // A deadline exactly equal to `now` counts as expired.
          assert_eq!(
              vec![key.clone()],
              db.get_expired_keys(Some(now + Duration::from_secs(10)))
          );
          assert!(!db.has(&key));
          assert_eq!(0, db.len());
      }

      #[test]
      fn get_expiration() {
          let mut db = ExpirationDb::new();
          // Every deadline is derived from one base instant so the assertions do
          // not depend on how much wall-clock time passes between statements.
          let now = Instant::now();
          let keys = vec![
              (bytes!(b"hix"), now + Duration::from_secs(15)),
              (bytes!(b"key"), now + Duration::from_secs(2)),
              (bytes!(b"bar"), now + Duration::from_secs(3)),
              (bytes!(b"hi"), now + Duration::from_secs(3)),
          ];
          for (key, expires_at) in &keys {
              db.add(key, *expires_at);
          }
          assert_eq!(keys.len(), db.len());

          assert!(db.get_expired_keys(Some(now)).is_empty());
          assert_eq!(keys.len(), db.len());

          assert_eq!(
              vec![keys[1].0.clone()],
              db.get_expired_keys(Some(now + Duration::from_secs(2)))
          );
          assert_eq!(3, db.len());

          // Equal deadlines come back in insertion order.
          assert_eq!(
              vec![keys[2].0.clone(), keys[3].0.clone()],
              db.get_expired_keys(Some(now + Duration::from_secs(4)))
          );
          assert_eq!(1, db.len());
      }

      #[test]
      fn remove() {
          let mut db = ExpirationDb::new();
          let now = Instant::now();
          let keys = vec![
              (bytes!(b"hix"), now + Duration::from_secs(15)),
              (bytes!(b"key"), now + Duration::from_secs(2)),
              (bytes!(b"bar"), now + Duration::from_secs(3)),
              (bytes!(b"hi"), now + Duration::from_secs(3)),
          ];
          for (key, expires_at) in &keys {
              db.add(key, *expires_at);
          }
          assert_eq!(keys.len(), db.len());
          assert!(db.has(&keys[0].0));
          assert!(db.remove(&keys[0].0));
          assert!(!db.has(&keys[0].0));
          assert!(!db.remove(&keys[0].0));
          assert_eq!(keys.len() - 1, db.len());
      }
  }