Преглед на файлове

Guard inner value inside an Entry with a RwLock

The main goal is to avoid taking a write lock on the entire slot just to
 update a single value. A slot should be write-locked only to add or remove
 keys, not for individual value updates; locking the whole slot for those is suboptimal.

 To make that possible, the Value inside each Entry must be guarded by its own RwLock.
Cesar Rodas преди 1 година
родител
ревизия
5e90611e30
променени са 5 файла, в които са добавени 181 реда и са изтрити 158 реда
  1. 19 11
      src/cmd/string.rs
  2. 1 1
      src/cmd/transaction.rs
  3. 1 1
      src/connection/mod.rs
  4. 33 19
      src/db/entry.rs
  5. 127 126
      src/db/mod.rs

+ 19 - 11
src/cmd/string.rs

@@ -77,7 +77,7 @@ pub async fn decr_by(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value,
 /// Get the value of key. If the key does not exist the special value nil is returned. An error is
 /// returned if the value stored at key is not a string, because GET only handles string values.
 pub async fn get(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Error> {
-    Ok(conn.db().get(&args[0]).inner())
+    Ok(conn.db().get(&args[0]).into_inner())
 }
 
 /// Get the value of key and optionally set its expiration. GETEX is similar to
@@ -124,11 +124,15 @@ pub async fn getex(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Er
 /// Get the value of key. If the key does not exist the special value nil is returned. An error is
 /// returned if the value stored at key is not a string, because GET only handles string values.
 pub async fn getrange(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Error> {
-    let bytes = match conn.db().get(&args[0]).deref() {
-        Value::Blob(binary) => binary.clone(),
-        Value::BlobRw(binary) => binary.clone().freeze(),
-        Value::Null => return Ok("".into()),
-        _ => return Err(Error::WrongType),
+    let bytes = if let Some(value) = conn.db().get(&args[0]).inner() {
+        match value.deref() {
+            Value::Blob(binary) => binary.clone(),
+            Value::BlobRw(binary) => binary.clone().freeze(),
+            Value::Null => return Ok("".into()),
+            _ => return Err(Error::WrongType),
+        }
+    } else {
+        return Ok("".into());
     };
 
     let start = bytes_to_number::<i64>(&args[1])?;
@@ -350,11 +354,15 @@ pub async fn setnx(conn: &Connection, mut args: VecDeque<Bytes>) -> Result<Value
 /// Returns the length of the string value stored at key. An error is returned when key holds a
 /// non-string value.
 pub async fn strlen(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Error> {
-    match conn.db().get(&args[0]).deref() {
-        Value::Blob(x) => Ok(x.len().into()),
-        Value::String(x) => Ok(x.len().into()),
-        Value::Null => Ok(0.into()),
-        _ => Ok(Error::WrongType.into()),
+    if let Some(value) = conn.db().get(&args[0]).inner() {
+        match value.deref() {
+            Value::Blob(x) => Ok(x.len().into()),
+            Value::String(x) => Ok(x.len().into()),
+            Value::Null => Ok(0.into()),
+            _ => Ok(Error::WrongType.into()),
+        }
+    } else {
+        Ok(0.into())
     }
 }
 

+ 1 - 1
src/cmd/transaction.rs

@@ -74,7 +74,7 @@ pub async fn watch(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Er
     conn.watch_key(
         args.into_iter()
             .map(|key| {
-                let v = conn.db().get_version(&key);
+                let v = conn.db().get(&key).version();
                 (key, v)
             })
             .collect::<Vec<(Bytes, usize)>>(),

+ 1 - 1
src/connection/mod.rs

@@ -281,7 +281,7 @@ impl Connection {
         let watch_keys = &self.info.read().watch_keys;
 
         for key in watch_keys.iter() {
-            if self.info.read().db.get_version(&key.0) != key.1 {
+            if self.info.read().db.get(&key.0).version() != key.1 {
                 return true;
             }
         }

+ 33 - 19
src/db/entry.rs

@@ -1,12 +1,13 @@
 use crate::{error::Error, value::Value};
-use parking_lot::Mutex;
+use bytes::BytesMut;
+use parking_lot::{Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard};
 use std::sync::atomic::{AtomicUsize, Ordering};
 use tokio::time::Instant;
 
 #[derive(Debug)]
 pub struct Entry {
-    pub value: Value,
-    pub version: AtomicUsize,
+    value: RwLock<Value>,
+    version: AtomicUsize,
     expires_at: Mutex<Option<Instant>>,
 }
 
@@ -27,12 +28,20 @@ pub fn unique_id() -> usize {
 impl Entry {
     pub fn new(value: Value, expires_at: Option<Instant>) -> Self {
         Self {
-            value,
+            value: RwLock::new(value),
             expires_at: Mutex::new(expires_at),
             version: AtomicUsize::new(LAST_VERSION.fetch_add(1, Ordering::Relaxed)),
         }
     }
 
+    pub fn take_value(self) -> Value {
+        self.value.into_inner()
+    }
+
+    pub fn digest(&self) -> Vec<u8> {
+        self.value.read().digest()
+    }
+
     #[inline(always)]
     pub fn bump_version(&self) {
         self.version.store(
@@ -46,7 +55,7 @@ impl Entry {
     }
 
     pub fn clone(&self) -> Self {
-        Self::new(self.value.clone(), *self.expires_at.lock())
+        Self::new(self.value.read().clone(), *self.expires_at.lock())
     }
 
     pub fn get_ttl(&self) -> Option<Instant> {
@@ -66,21 +75,26 @@ impl Entry {
         self.version.load(Ordering::Relaxed)
     }
 
-    /// Changes the value that is wrapped in this entry, the TTL (expired_at) is
-    /// not affected.
-    pub fn change_value(&mut self, value: Value) {
-        self.value = value;
-        self.bump_version()
+    pub fn get(&self) -> RwLockReadGuard<'_, Value> {
+        self.value.read()
     }
 
-    #[allow(dead_code)]
-    pub fn get_mut(&mut self) -> &mut Value {
+    pub fn get_mut(&self) -> RwLockWriteGuard<'_, Value> {
         self.bump_version();
-        &mut self.value
-    }
-
-    pub fn get(&self) -> &Value {
-        &self.value
+        self.value.write()
+    }
+
+    pub fn ensure_blob_is_mutable(&self) -> Result<(), Error> {
+        let mut val = self.get_mut();
+        match *val {
+            Value::Blob(ref mut data) => {
+                let rw_data = BytesMut::from(&data[..]);
+                *val = Value::BlobRw(rw_data);
+                Ok(())
+            }
+            Value::BlobRw(_) => Ok(()),
+            _ => Err(Error::WrongType),
+        }
     }
 
     /// If the Entry should be taken as valid, if this function returns FALSE
@@ -94,7 +108,7 @@ impl Entry {
     /// Whether or not the value is scalar
     pub fn is_scalar(&self) -> bool {
         matches!(
-            &self.value,
+            *self.value.read(),
             Value::Boolean(_)
                 | Value::Blob(_)
                 | Value::BlobRw(_)
@@ -111,7 +125,7 @@ impl Entry {
     /// returned instead
     pub fn clone_value(&self) -> Value {
         if self.is_scalar() {
-            self.value.clone()
+            self.value.read().clone()
         } else {
             Error::WrongType.into()
         }

+ 127 - 126
src/db/mod.rs

@@ -19,7 +19,7 @@ use seahash::hash;
 use std::{
     collections::{HashMap, VecDeque},
     convert::{TryFrom, TryInto},
-    ops::Deref,
+    ops::DerefMut,
     str::FromStr,
     sync::Arc,
     thread,
@@ -44,7 +44,7 @@ pub struct RefValue<'a> {
 impl<'a> RefValue<'a> {
     /// test
     #[inline(always)]
-    pub fn inner(self) -> Value {
+    pub fn into_inner(self) -> Value {
         self.slot
             .get(self.key)
             .filter(|x| x.is_valid())
@@ -58,26 +58,22 @@ impl<'a> RefValue<'a> {
             .unwrap_or_default()
     }
 
-    /// Returns the version of a given key
-    #[inline(always)]
-    pub fn version(&self) -> usize {
+    /// test
+    pub fn inner(&self) -> Option<RwLockReadGuard<'_, Value>> {
         self.slot
             .get(self.key)
             .filter(|x| x.is_valid())
-            .map(|x| x.version())
-            .unwrap_or_default()
+            .map(|x| x.get())
     }
-}
-
-impl Deref for RefValue<'_> {
-    type Target = Value;
 
-    fn deref(&self) -> &Self::Target {
+    /// Returns the version of a given key
+    #[inline(always)]
+    pub fn version(&self) -> usize {
         self.slot
             .get(self.key)
             .filter(|x| x.is_valid())
-            .map(|x| x.get())
-            .unwrap_or(&Value::Null)
+            .map(|x| x.version())
+            .unwrap_or_default()
     }
 }
 
@@ -256,12 +252,10 @@ impl Db {
     /// Return debug info for a key
     pub fn debug(&self, key: &Bytes) -> Result<VDebug, Error> {
         let slot = self.slots[self.get_slot(key)].read();
-        Ok(slot
-            .get(key)
+        slot.get(key)
             .filter(|x| x.is_valid())
-            .ok_or(Error::NotFound)?
-            .value
-            .debug())
+            .map(|x| x.get().debug())
+            .ok_or(Error::NotFound)
     }
 
     /// Return the digest for each key. This used for testing only
@@ -273,7 +267,7 @@ impl Db {
                 Value::new(
                     slot.get(key)
                         .filter(|v| v.is_valid())
-                        .map(|v| hex::encode(v.value.digest()))
+                        .map(|v| hex::encode(v.digest()))
                         .unwrap_or("00000".into())
                         .as_bytes(),
                 )
@@ -360,11 +354,14 @@ impl Db {
             + Into<Value>
             + Copy,
     {
-        let mut slot = self.slots[self.get_slot(key)].write();
+        let slot_id = self.get_slot(key);
+        let slot = self.slots[slot_id].read();
         let mut incr_by: T =
             bytes_to_number(incr_by).map_err(|_| Error::NotANumberType(typ.to_owned()))?;
-        match slot.get_mut(key).filter(|x| x.is_valid()).map(|x| x.get()) {
-            Some(Value::Hash(h)) => {
+
+        if let Some(entry) = slot.get(key).filter(|x| x.is_valid()) {
+            let mut value = entry.get_mut();
+            if let Value::Hash(h) = value.deref_mut() {
                 let mut h = h.write();
                 if let Some(n) = h.get(sub_key) {
                     incr_by = incr_by
@@ -378,16 +375,19 @@ impl Db {
                 h.insert(sub_key.clone(), incr_by_bytes.clone());
 
                 Self::number_to_value(&incr_by_bytes)
+            } else {
+                Err(Error::WrongType)
             }
-            None => {
-                #[allow(clippy::mutable_key_type)]
-                let mut h = HashMap::new();
-                let incr_by_bytes = Self::round_numbers(incr_by);
-                h.insert(sub_key.clone(), incr_by_bytes.clone());
-                let _ = slot.insert(key.clone(), Entry::new(h.into(), None));
-                Self::number_to_value(&incr_by_bytes)
-            }
-            _ => Err(Error::WrongType),
+        } else {
+            drop(slot);
+            #[allow(clippy::mutable_key_type)]
+            let mut h = HashMap::new();
+            let incr_by_bytes = Self::round_numbers(incr_by);
+            h.insert(sub_key.clone(), incr_by_bytes.clone());
+            let _ = self.slots[slot_id]
+                .write()
+                .insert(key.clone(), Entry::new(h.into(), None));
+            Self::number_to_value(&incr_by_bytes)
         }
     }
 
@@ -399,28 +399,26 @@ impl Db {
     where
         T: ToString + CheckedAdd + for<'a> TryFrom<&'a Value, Error = Error> + Into<Value> + Copy,
     {
-        let mut slot = self.slots[self.get_slot(key)].write();
-        match slot.get_mut(key).filter(|x| x.is_valid()) {
-            Some(x) => {
-                if !x.is_scalar() {
-                    return Err(Error::WrongType);
-                }
-                let value = x.get();
-                let mut number: T = value.try_into()?;
-
-                number = incr_by.checked_add(&number).ok_or(Error::Overflow)?;
-
-                x.change_value(Value::Blob(Self::round_numbers(number)));
+        let slot_id = self.get_slot(key);
+        let slot = self.slots[slot_id].read();
 
-                Ok(number)
-            }
-            None => {
-                slot.insert(
-                    key.clone(),
-                    Entry::new(Value::Blob(Self::round_numbers(incr_by)), None),
-                );
-                Ok(incr_by)
+        if let Some(entry) = slot.get(key).filter(|x| x.is_valid()) {
+            if !entry.is_scalar() {
+                return Err(Error::WrongType);
             }
+            let mut value = entry.get_mut();
+            let mut number: T = (&*value).try_into()?;
+
+            number = incr_by.checked_add(&number).ok_or(Error::Overflow)?;
+            *value = Value::Blob(Self::round_numbers(number));
+            Ok(number)
+        } else {
+            drop(slot);
+            self.slots[slot_id].write().insert(
+                key.clone(),
+                Entry::new(Value::Blob(Self::round_numbers(incr_by)), None),
+            );
+            Ok(incr_by)
         }
     }
 
@@ -502,22 +500,20 @@ impl Db {
     /// command will make sure it holds a string large enough to be able to set
     /// value at offset.
     pub fn set_range(&self, key: &Bytes, offset: i128, data: &[u8]) -> Result<Value, Error> {
-        let mut slot = self.slots[self.get_slot(key)].write();
+        let slot_id = self.get_slot(key);
+        let slot = self.slots[slot_id].read();
 
-        if let Some(entry) = slot.get_mut(key).filter(|x| x.is_valid()) {
-            if let Value::Blob(data) = entry.get() {
-                let rw_data = BytesMut::from(&data[..]);
-                entry.change_value(Value::BlobRw(rw_data));
-            }
-        }
-
-        let value = slot.get_mut(key).map(|value| {
-            if !value.is_valid() {
-                self.expirations.lock().remove(key);
-                value.persist();
-            }
-            value.get_mut()
-        });
+        let mut value = slot
+            .get(key)
+            .map(|value| {
+                value.ensure_blob_is_mutable()?;
+                if !value.is_valid() {
+                    self.expirations.lock().remove(key);
+                    value.persist();
+                }
+                Ok::<_, Error>(value.get_mut())
+            })
+            .transpose()?;
 
         if offset < 0 {
             return Err(Error::OutOfRange);
@@ -528,27 +524,32 @@ impl Db {
         }
 
         let length = offset as usize + data.len();
-        match value {
-            Some(Value::BlobRw(bytes)) => {
-                if bytes.capacity() < length {
-                    bytes.resize(length, 0);
+        if let Some(value) = value.as_mut() {
+            match value.deref_mut() {
+                Value::BlobRw(ref mut bytes) => {
+                    if bytes.capacity() < length {
+                        bytes.resize(length, 0);
+                    }
+                    let writer = &mut bytes[offset as usize..length];
+                    writer.copy_from_slice(data);
+                    Ok(bytes.len().into())
                 }
-                let writer = &mut bytes[offset as usize..length];
-                writer.copy_from_slice(data);
-                Ok(bytes.len().into())
+                _ => Err(Error::WrongType),
             }
-            None => {
-                if data.is_empty() {
-                    return Ok(0.into());
-                }
-                let mut bytes = BytesMut::new();
-                bytes.resize(length, 0);
-                let writer = &mut bytes[offset as usize..];
-                writer.copy_from_slice(data);
-                slot.insert(key.clone(), Entry::new(Value::new(&bytes), None));
-                Ok(bytes.len().into())
+        } else {
+            drop(value);
+            drop(slot);
+            if data.is_empty() {
+                return Ok(0.into());
             }
-            _ => Err(Error::WrongType),
+            let mut bytes = BytesMut::new();
+            bytes.resize(length, 0);
+            let writer = &mut bytes[offset as usize..];
+            writer.copy_from_slice(data);
+            self.slots[slot_id]
+                .write()
+                .insert(key.clone(), Entry::new(Value::new(&bytes), None));
+            Ok(bytes.len().into())
         }
     }
 
@@ -575,14 +576,9 @@ impl Db {
             if replace == Override::No && db.exists(&[target.clone()]) > 0 {
                 return Ok(false);
             }
-            let _ = db.set_advanced(
-                target,
-                value.value.clone(),
-                value.get_ttl().map(|v| v - Instant::now()),
-                replace,
-                false,
-                false,
-            );
+
+            let ttl = value.get_ttl().map(|v| v - Instant::now());
+            let _ = db.set_advanced(target, value.take_value(), ttl, replace, false, false);
             Ok(true)
         } else {
             if source == target {
@@ -608,7 +604,7 @@ impl Db {
         let (expires_in, value) = if let Some(value) = slot.get(&source).filter(|v| v.is_valid()) {
             (
                 value.get_ttl().map(|t| t - Instant::now()),
-                value.value.clone(),
+                value.get().clone(),
             )
         } else {
             return Ok(false);
@@ -766,10 +762,11 @@ impl Db {
         let slot = self.slots[self.get_slot(key)].read();
         let entry = slot.get(key).filter(|x| x.is_valid()).map(|e| e.get());
 
-        if let Some(entry) = entry {
+        if let Some(entry) = entry.as_ref() {
             found(Some(entry))
         } else {
             // drop lock
+            drop(entry);
             drop(slot);
             found(None)
         }
@@ -818,23 +815,13 @@ impl Db {
             .collect()
     }
 
-    /// Returns the version of a given key
-    #[inline]
-    pub fn get_version(&self, key: &Bytes) -> usize {
-        let slot = self.slots[self.get_slot(key)].read();
-        slot.get(key)
-            .filter(|x| x.is_valid())
-            .map(|entry| entry.version())
-            .unwrap_or_default()
-    }
-
     /// Returns the name of the value type
     pub fn get_data_type(&self, key: &Bytes) -> String {
         let slot = self.slots[self.get_slot(key)].read();
         slot.get(key)
             .filter(|x| x.is_valid())
             .map_or("none".to_owned(), |x| {
-                Typ::get_type(x.get()).to_string().to_lowercase()
+                Typ::get_type(&x.get()).to_string().to_lowercase()
             })
     }
 
@@ -900,21 +887,20 @@ impl Db {
 
     /// Set a key, value with an optional expiration time
     pub fn append(&self, key: &Bytes, value_to_append: &Bytes) -> Result<Value, Error> {
-        let mut slot = self.slots[self.get_slot(key)].write();
+        let slot = self.slots[self.get_slot(key)].read();
 
-        if let Some(entry) = slot.get_mut(key).filter(|x| x.is_valid()) {
-            if let Value::Blob(data) = entry.get() {
-                let rw_data = BytesMut::from(&data[..]);
-                entry.change_value(Value::BlobRw(rw_data));
-            }
-            match entry.get_mut() {
-                Value::BlobRw(value) => {
+        if let Some(entry) = slot.get(key).filter(|x| x.is_valid()) {
+            entry.ensure_blob_is_mutable()?;
+            match *entry.get_mut() {
+                Value::BlobRw(ref mut value) => {
                     value.put(value_to_append.as_ref());
                     Ok(value.len().into())
                 }
                 _ => Err(Error::WrongType),
             }
         } else {
+            drop(slot);
+            let mut slot = self.slots[self.get_slot(key)].write();
             slot.insert(key.clone(), Entry::new(Value::new(value_to_append), None));
             Ok(value_to_append.len().into())
         }
@@ -1150,7 +1136,7 @@ impl scan::Scan for Db {
                     }
                 }
                 if let Some(typ) = &typ {
-                    if !typ.is_value_type(value.get()) {
+                    if !typ.is_value_type(&value.get()) {
                         last_pos += 1;
                         continue;
                     }
@@ -1194,7 +1180,7 @@ mod test {
         assert_eq!(Error::NotANumber, r.expect_err("should fail"));
         assert_eq!(
             Value::Blob(bytes!("some string")),
-            db.get(&bytes!("num")).inner()
+            db.get(&bytes!("num")).into_inner()
         );
     }
 
@@ -1204,7 +1190,10 @@ mod test {
         db.set(bytes!(b"num"), Value::Blob(bytes!("1.1")), None);
 
         assert_eq!(Ok(2.2.into()), db.incr::<Float>(&bytes!("num"), 1.1.into()));
-        assert_eq!(Value::Blob(bytes!("2.2")), db.get(&bytes!("num")).inner());
+        assert_eq!(
+            Value::Blob(bytes!("2.2")),
+            db.get(&bytes!("num")).into_inner()
+        );
     }
 
     #[test]
@@ -1213,7 +1202,10 @@ mod test {
         db.set(bytes!(b"num"), Value::Blob(bytes!("1")), None);
 
         assert_eq!(Ok(2.1.into()), db.incr::<Float>(&bytes!("num"), 1.1.into()));
-        assert_eq!(Value::Blob(bytes!("2.1")), db.get(&bytes!("num")).inner());
+        assert_eq!(
+            Value::Blob(bytes!("2.1")),
+            db.get(&bytes!("num")).into_inner()
+        );
     }
 
     #[test]
@@ -1222,21 +1214,30 @@ mod test {
         db.set(bytes!(b"num"), Value::Blob(bytes!("1")), None);
 
         assert_eq!(Ok(2), db.incr(&bytes!("num"), 1));
-        assert_eq!(Value::Blob(bytes!("2")), db.get(&bytes!("num")).inner());
+        assert_eq!(
+            Value::Blob(bytes!("2")),
+            db.get(&bytes!("num")).into_inner()
+        );
     }
 
     #[test]
     fn incr_blob_int_set() {
         let db = Db::new(100);
         assert_eq!(Ok(1), db.incr(&bytes!("num"), 1));
-        assert_eq!(Value::Blob(bytes!("1")), db.get(&bytes!("num")).inner());
+        assert_eq!(
+            Value::Blob(bytes!("1")),
+            db.get(&bytes!("num")).into_inner()
+        );
     }
 
     #[test]
     fn incr_blob_float_set() {
         let db = Db::new(100);
         assert_eq!(Ok(1.1.into()), db.incr::<Float>(&bytes!("num"), 1.1.into()));
-        assert_eq!(Value::Blob(bytes!("1.1")), db.get(&bytes!("num")).inner());
+        assert_eq!(
+            Value::Blob(bytes!("1.1")),
+            db.get(&bytes!("num")).into_inner()
+        );
     }
 
     #[test]
@@ -1274,7 +1275,7 @@ mod test {
     fn persist_bug() {
         let db = Db::new(100);
         db.set(bytes!(b"one"), Value::Ok, Some(Duration::from_secs(1)));
-        assert_eq!(Value::Ok, db.get(&bytes!(b"one")).inner());
+        assert_eq!(Value::Ok, db.get(&bytes!(b"one")).into_inner());
         assert!(db.is_key_in_expiration_list(&bytes!(b"one")));
         db.persist(&bytes!(b"one"));
         assert!(!db.is_key_in_expiration_list(&bytes!(b"one")));
@@ -1286,13 +1287,13 @@ mod test {
         db.set(bytes!(b"one"), Value::Ok, Some(Duration::from_secs(0)));
         // Expired keys should not be returned, even if they are not yet
         // removed by the purge process.
-        assert_eq!(Value::Null, db.get(&bytes!(b"one")).inner());
+        assert_eq!(Value::Null, db.get(&bytes!(b"one")).into_inner());
 
         // Purge twice
         assert_eq!(1, db.purge());
         assert_eq!(0, db.purge());
 
-        assert_eq!(Value::Null, db.get(&bytes!(b"one")).inner());
+        assert_eq!(Value::Null, db.get(&bytes!(b"one")).into_inner());
     }
 
     #[test]
@@ -1301,10 +1302,10 @@ mod test {
         db.set(bytes!(b"one"), Value::Ok, Some(Duration::from_secs(0)));
         // Expired keys should not be returned, even if they are not yet
         // removed by the purge process.
-        assert_eq!(Value::Null, db.get(&bytes!(b"one")).inner());
+        assert_eq!(Value::Null, db.get(&bytes!(b"one")).into_inner());
 
         db.set(bytes!(b"one"), Value::Ok, Some(Duration::from_secs(5)));
-        assert_eq!(Value::Ok, db.get(&bytes!(b"one")).inner());
+        assert_eq!(Value::Ok, db.get(&bytes!(b"one")).into_inner());
 
         // Purge should return 0 as the expired key has been removed already
         assert_eq!(0, db.purge());