Ver código fonte

Add data structure to keep track of expiring keys

This new data structure will sort expiring keys by their expiration time
and will make sure there are no duplicate entries.

If the same key is added twice, only the latest expiration time will be
tracked. A function returns the expired keys (and removes those keys from
the internal data structure).
Cesar Rodas 3 anos atrás
pai
commit
be44087050
4 arquivos alterados com 269 adições e 26 exclusões
  1. 5 8
      src/db/entry.rs
  2. 205 0
      src/db/expiration.rs
  3. 47 17
      src/db/mod.rs
  4. 12 1
      src/server.rs

+ 5 - 8
src/db/entry.rs

@@ -1,5 +1,5 @@
 use crate::value::Value;
-use tokio::time::{Duration, Instant};
+use tokio::time::Instant;
 
 #[derive(Debug)]
 pub struct Entry {
@@ -15,11 +15,8 @@ pub struct Entry {
 /// this promise we can run the purge process every few seconds instead of doing
 /// so more frequently.
 impl Entry {
-    pub fn new(value: Value, expires_in: Option<Duration>) -> Self {
-        Self {
-            value,
-            expires_at: expires_in.map(|duration| Instant::now() + duration),
-        }
+    pub fn new(value: Value, expires_at: Option<Instant>) -> Self {
+        Self { value, expires_at }
     }
 
     pub fn persist(&mut self) {
@@ -34,8 +31,8 @@ impl Entry {
         self.expires_at.is_some()
     }
 
-    pub fn set_ttl(&mut self, expires_in: Duration) {
-        self.expires_at = Some(Instant::now() + expires_in);
+    pub fn set_ttl(&mut self, expires_at: Instant) {
+        self.expires_at = Some(expires_at);
     }
 
     /// Changes the value that is wrapped in this entry, the TTL (expired_at) is

+ 205 - 0
src/db/expiration.rs

@@ -0,0 +1,205 @@
+use bytes::Bytes;
+use std::collections::{BTreeMap, HashMap};
+use tokio::time::Instant;
+
+/// ExpirationId
+///
+/// The internal data structure is a B-Tree whose key is the expiration time,
+/// so all entries are naturally sorted by expiration time. Because it is
+/// possible that different keys expire at the same instant, an internal
+/// counter is added to the ID to make each ID unique (and sorted by
+/// expiration time + incremental counter).
+#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)]
+pub struct ExpirationId(pub (Instant, u64));
+
+#[derive(Debug)]
+pub struct ExpirationDb {
+    /// B-Tree Map of expiring keys
+    expiring_keys: BTreeMap<ExpirationId, Bytes>,
+    /// Hash which contains the keys and their ExpirationId.
+    keys: HashMap<Bytes, ExpirationId>,
+    next_id: u64,
+}
+
+impl ExpirationDb {
+    pub fn new() -> Self {
+        Self {
+            expiring_keys: BTreeMap::new(),
+            keys: HashMap::new(),
+            next_id: 0,
+        }
+    }
+
+    pub fn add(&mut self, key: &Bytes, expires_at: Instant) {
+        let entry_id = ExpirationId((expires_at, self.next_id));
+
+        if let Some(prev) = self.keys.remove(key) {
+            // This key already has an expiration registered; its old entry
+            // has to be removed before adding the new one
+            self.expiring_keys.remove(&prev);
+        }
+
+        self.expiring_keys.insert(entry_id, key.clone());
+        self.keys.insert(key.clone(), entry_id);
+
+        self.next_id += 1;
+    }
+
+    pub fn remove(&mut self, key: &Bytes) -> bool {
+        if let Some(prev) = self.keys.remove(key) {
+            // The key was being tracked: also drop its entry from the
+            // B-Tree that is sorted by expiration time
+            self.expiring_keys.remove(&prev);
+            true
+        } else {
+            false
+        }
+    }
+
+    pub fn len(&self) -> usize {
+        self.expiring_keys.len()
+    }
+
+    /// Returns a list of expired keys, these keys are removed from the internal
+    /// data structure which is keeping track of expiring keys.
+    pub fn get_expired_keys(&mut self, now: Option<Instant>) -> Vec<Bytes> {
+        let now = now.unwrap_or_else(Instant::now);
+
+        let mut expiring_keys = vec![];
+
+        for (key, value) in self.expiring_keys.iter_mut() {
+            if key.0 .0 > now {
+                break;
+            }
+
+            expiring_keys.push((*key, value.clone()));
+            self.keys.remove(value);
+        }
+
+        expiring_keys
+            .iter()
+            .map(|(k, v)| {
+                self.expiring_keys.remove(k);
+                v.to_owned()
+            })
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use tokio::time::{Duration, Instant};
+
+    #[test]
+    fn two_entires_same_expiration() {
+        let mut db = ExpirationDb::new();
+        let key1 = Bytes::from(&b"key"[..]);
+        let key2 = Bytes::from(&b"bar"[..]);
+        let key3 = Bytes::from(&b"xxx"[..]);
+        let expiration = Instant::now() + Duration::from_secs(5);
+
+        db.add(&key1, expiration);
+        db.add(&key2, expiration);
+        db.add(&key3, expiration);
+
+        assert_eq!(3, db.len());
+    }
+
+    #[test]
+    fn remove_prev_expiration() {
+        let mut db = ExpirationDb::new();
+        let key1 = Bytes::from(&b"key"[..]);
+        let key2 = Bytes::from(&b"bar"[..]);
+        let expiration = Instant::now() + Duration::from_secs(5);
+
+        db.add(&key1, expiration);
+        db.add(&key2, expiration);
+        db.add(&key1, expiration);
+
+        assert_eq!(2, db.len());
+    }
+
+    #[test]
+    fn get_expiration() {
+        let mut db = ExpirationDb::new();
+        let keys = vec![
+            (
+                Bytes::from(&b"hix"[..]),
+                Instant::now() + Duration::from_secs(15),
+            ),
+            (
+                Bytes::from(&b"key"[..]),
+                Instant::now() + Duration::from_secs(2),
+            ),
+            (
+                Bytes::from(&b"bar"[..]),
+                Instant::now() + Duration::from_secs(3),
+            ),
+            (
+                Bytes::from(&b"hi"[..]),
+                Instant::now() + Duration::from_secs(3),
+            ),
+        ];
+
+        keys.iter()
+            .map(|v| {
+                db.add(&v.0, v.1);
+            })
+            .for_each(drop);
+
+        assert_eq!(db.len(), keys.len());
+
+        assert_eq!(0, db.get_expired_keys(Some(Instant::now())).len());
+        assert_eq!(db.len(), keys.len());
+
+        assert_eq!(
+            vec![keys[1].0.clone()],
+            db.get_expired_keys(Some(Instant::now() + Duration::from_secs(2)))
+        );
+        assert_eq!(3, db.len());
+
+        assert_eq!(
+            vec![keys[2].0.clone(), keys[3].0.clone()],
+            db.get_expired_keys(Some(Instant::now() + Duration::from_secs(4)))
+        );
+        assert_eq!(1, db.len());
+    }
+
+    #[test]
+    pub fn remove() {
+        let mut db = ExpirationDb::new();
+        let keys = vec![
+            (
+                Bytes::from(&b"hix"[..]),
+                Instant::now() + Duration::from_secs(15),
+            ),
+            (
+                Bytes::from(&b"key"[..]),
+                Instant::now() + Duration::from_secs(2),
+            ),
+            (
+                Bytes::from(&b"bar"[..]),
+                Instant::now() + Duration::from_secs(3),
+            ),
+            (
+                Bytes::from(&b"hi"[..]),
+                Instant::now() + Duration::from_secs(3),
+            ),
+        ];
+
+        keys.iter()
+            .map(|v| {
+                db.add(&v.0, v.1);
+            })
+            .for_each(drop);
+
+        assert_eq!(keys.len(), db.len());
+
+        assert!(db.remove(&keys[0].0));
+        assert!(!db.remove(&keys[0].0));
+
+        assert_eq!(keys.len() - 1, db.len());
+    }
+}

+ 47 - 17
src/db/mod.rs

@@ -1,15 +1,17 @@
-pub mod entry;
+mod entry;
+mod expiration;
 
 use crate::{error::Error, value::Value};
 use bytes::Bytes;
 use entry::Entry;
+use expiration::ExpirationDb;
 use log::trace;
 use seahash::hash;
 use std::{
-    collections::{BTreeMap, HashMap},
+    collections::HashMap,
     convert::{TryFrom, TryInto},
     ops::AddAssign,
-    sync::RwLock,
+    sync::{Mutex, RwLock},
 };
 use tokio::time::{Duration, Instant};
 
@@ -24,14 +26,10 @@ pub struct Db {
     /// Because all operations are always key specific, the key is used to hash
     /// and select to which HashMap the data might be stored.
     entries: Vec<RwLock<HashMap<Bytes, Entry>>>,
-    /// B-Tree Map of expiring keys
-    ///
-    /// This B-Tree has the name of expiring entries, and they are sorted by the
-    /// Instant where the entries are expiring.
-    ///
-    /// Because it is possible that two entries expire at the same Instant, a
-    /// counter is introduced to avoid collisions on this B-Tree.
-    expirations: RwLock<BTreeMap<(Instant, u64), String>>,
+
+    /// Data structure to store all expiring keys
+    expirations: Mutex<ExpirationDb>,
+
     /// Number of HashMaps that are available.
     slots: usize,
 }
@@ -46,7 +44,7 @@ impl Db {
 
         Self {
             entries,
-            expirations: RwLock::new(BTreeMap::new()),
+            expirations: Mutex::new(ExpirationDb::new()),
             slots,
         }
     }
@@ -102,23 +100,28 @@ impl Db {
             })
     }
 
-    pub fn set_ttl(&self, key: &Bytes, expiration: Duration) -> Value {
+    pub fn set_ttl(&self, key: &Bytes, expires_in: Duration) -> Value {
         let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let expires_at = Instant::now() + expires_in;
+
         entries
             .get_mut(key)
             .filter(|x| x.is_valid())
             .map_or(0_i64.into(), |x| {
-                x.set_ttl(expiration);
+                self.expirations.lock().unwrap().add(key, expires_at);
+                x.set_ttl(expires_at);
                 1_i64.into()
             })
     }
 
     pub fn del(&self, keys: &[Bytes]) -> Value {
         let mut deleted = 0_i64;
+        let mut expirations = self.expirations.lock().unwrap();
         keys.iter()
             .map(|key| {
                 let mut entries = self.entries[self.get_slot(key)].write().unwrap();
                 if entries.remove(key).is_some() {
+                    expirations.remove(&key);
                     deleted += 1;
                 }
             })
@@ -161,6 +164,7 @@ impl Db {
 
     pub fn getset(&self, key: &Bytes, value: &Value) -> Value {
         let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        self.expirations.lock().unwrap().remove(&key);
         entries
             .insert(key.clone(), Entry::new(value.clone(), None))
             .filter(|x| x.is_valid())
@@ -169,12 +173,20 @@ impl Db {
 
     pub fn getdel(&self, key: &Bytes) -> Value {
         let mut entries = self.entries[self.get_slot(key)].write().unwrap();
-        entries.remove(key).map_or(Value::Null, |x| x.get().clone())
+        entries.remove(key).map_or(Value::Null, |x| {
+            self.expirations.lock().unwrap().remove(&key);
+            x.get().clone()
+        })
     }
 
-    pub fn set(&self, key: &Bytes, value: &Value, expires: Option<Duration>) -> Value {
+    pub fn set(&self, key: &Bytes, value: &Value, expires_in: Option<Duration>) -> Value {
         let mut entries = self.entries[self.get_slot(key)].write().unwrap();
-        entries.insert(key.clone(), Entry::new(value.clone(), expires));
+        let expires_at = expires_in.map(|duration| Instant::now() + duration);
+
+        if let Some(expires_at) = expires_at {
+            self.expirations.lock().unwrap().add(&key, expires_at);
+        }
+        entries.insert(key.clone(), Entry::new(value.clone(), expires_at));
         Value::OK
     }
 
@@ -185,4 +197,22 @@ impl Db {
             .filter(|x| x.is_valid())
             .map(|x| x.get_ttl())
     }
+
+    pub fn purge(&self) {
+        let mut expirations = self.expirations.lock().unwrap();
+
+        trace!("Watching {} keys for expirations", expirations.len());
+
+        let keys = expirations.get_expired_keys(None);
+        drop(expirations);
+
+        keys.iter()
+            .map(|key| {
+                let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+                if entries.remove(key).is_some() {
+                    trace!("Removed key {:?} due timeout", key);
+                }
+            })
+            .for_each(drop);
+    }
 }

+ 12 - 1
src/server.rs

@@ -4,7 +4,10 @@ use futures::SinkExt;
 use log::{info, trace, warn};
 use redis_zero_protocol_parser::{parse_server, Error as RedisError};
 use std::{error::Error, io, ops::Deref, sync::Arc};
-use tokio::net::TcpListener;
+use tokio::{
+    net::TcpListener,
+    time::{sleep, Duration},
+};
 use tokio_stream::StreamExt;
 use tokio_util::codec::{Decoder, Encoder, Framed};
 
@@ -50,6 +53,14 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
     let db = Arc::new(Db::new(1000));
     let mut all_connections = Connections::new();
 
+    let db_for_purging = db.clone();
+    tokio::spawn(async move {
+        loop {
+            db_for_purging.purge();
+            sleep(Duration::from_millis(5000)).await;
+        }
+    });
+
     loop {
         match listener.accept().await {
             Ok((socket, addr)) => {