Selaa lähdekoodia

Remove all unwrap() from the codebase (#13)

* Removed all `unwrap()`, from tests and macros.
* Replaced RwLock and Mutex from the native library to use
  `parking_lot`, it is supposed to be faster and no `unwrap` needed to
  use it.
César D. Rodas 3 vuotta sitten
vanhempi
säilyke
ce42fba85f
8 muutettua tiedostoa jossa 66 lisäystä ja 74 poistoa
  1. 1 0
      Cargo.toml
  2. 1 5
      src/cmd/list.rs
  3. 5 6
      src/cmd/transaction.rs
  4. 20 20
      src/connection.rs
  5. 31 33
      src/db/mod.rs
  6. 2 2
      src/macros.rs
  7. 5 5
      src/value/locked.rs
  8. 1 3
      src/value/mod.rs

+ 1 - 0
Cargo.toml

@@ -9,6 +9,7 @@ edition = "2018"
 [dependencies]
 redis-zero-protocol-parser = {path = "redis-protocol-parser"}
 tokio={version="1", features = ["full", "tracing"] }
+parking_lot="^0.11"
 tokio-util={version="^0.6", features = ["full"] }
 async-trait = "0.1.50"
 crc32fast="^1.2"

+ 1 - 5
src/cmd/list.rs

@@ -35,11 +35,7 @@ fn remove_element(
                     }
                 }
 
-                let ret: Vec<Value> = ret
-                    .iter()
-                    .filter(|v| v.is_some())
-                    .map(|x| x.as_ref().unwrap().clone_value())
-                    .collect();
+                let ret: Vec<Value> = ret.iter().flatten().map(|m| m.clone_value()).collect();
 
                 Ok(if ret.is_empty() {
                     Value::Null

+ 5 - 6
src/cmd/transaction.rs

@@ -192,11 +192,10 @@ mod test {
 
     fn get_keys(args: &[&str]) -> Vec<Bytes> {
         let args: Vec<Bytes> = args.iter().map(|s| Bytes::from(s.to_string())).collect();
-        Dispatcher::new(&args)
-            .unwrap()
-            .get_keys(&args)
-            .iter()
-            .map(|k| (*k).clone())
-            .collect()
+        if let Ok(cmd) = Dispatcher::new(&args) {
+            cmd.get_keys(&args).iter().map(|k| (*k).clone()).collect()
+        } else {
+            vec![]
+        }
     }
 }

+ 20 - 20
src/connection.rs

@@ -1,9 +1,10 @@
 use crate::{db::Db, error::Error, value::Value};
 use bytes::Bytes;
+use parking_lot::RwLock;
 use std::{
     collections::{BTreeMap, HashSet},
     net::SocketAddr,
-    sync::{Arc, RwLock},
+    sync::Arc,
 };
 
 #[derive(Debug)]
@@ -42,7 +43,7 @@ impl Connections {
 
     pub fn remove(self: Arc<Connections>, conn: Arc<Connection>) {
         let id = conn.id();
-        self.connections.write().unwrap().remove(&id);
+        self.connections.write().remove(&id);
     }
 
     pub fn new_connection(
@@ -50,7 +51,7 @@ impl Connections {
         db: Arc<Db>,
         addr: SocketAddr,
     ) -> Arc<Connection> {
-        let mut id = self.counter.write().unwrap();
+        let mut id = self.counter.write();
         *id += 1;
 
         let conn = Arc::new(Connection {
@@ -62,12 +63,12 @@ impl Connections {
             info: RwLock::new(ConnectionInfo::new()),
         });
 
-        self.connections.write().unwrap().insert(*id, conn.clone());
+        self.connections.write().insert(*id, conn.clone());
         conn
     }
 
     pub fn iter(&self, f: &mut dyn FnMut(Arc<Connection>)) {
-        for (_, value) in self.connections.read().unwrap().iter() {
+        for (_, value) in self.connections.read().iter() {
             f(value.clone())
         }
     }
@@ -96,7 +97,7 @@ impl Connection {
     }
 
     pub fn stop_transaction(&self) -> Result<Value, Error> {
-        let info = &mut self.info.write().unwrap();
+        let info = &mut self.info.write();
         if info.in_transaction {
             info.commands = None;
             info.watch_keys.clear();
@@ -111,7 +112,7 @@ impl Connection {
     }
 
     pub fn start_transaction(&self) -> Result<Value, Error> {
-        let mut info = self.info.write().unwrap();
+        let mut info = self.info.write();
         if !info.in_transaction {
             info.in_transaction = true;
             Ok(Value::Ok)
@@ -123,24 +124,24 @@ impl Connection {
     /// We are inside a MULTI, most transactions are rather queued for later
     /// execution instead of being executed right away.
     pub fn in_transaction(&self) -> bool {
-        self.info.read().unwrap().in_transaction
+        self.info.read().in_transaction
     }
 
     /// The commands are being executed inside a transaction (by EXEC). It is
     /// important to keep track of this because some commands change their
     /// behaviour.
     pub fn is_executing_transaction(&self) -> bool {
-        self.info.read().unwrap().in_executing_transaction
+        self.info.read().in_executing_transaction
     }
 
     /// EXEC has been called and we need to keep track
     pub fn start_executing_transaction(&self) {
-        let info = &mut self.info.write().unwrap();
+        let info = &mut self.info.write();
         info.in_executing_transaction = true;
     }
 
     pub fn watch_key(&self, keys: &[(&Bytes, u128)]) {
-        let watch_keys = &mut self.info.write().unwrap().watch_keys;
+        let watch_keys = &mut self.info.write().watch_keys;
         keys.iter()
             .map(|(bytes, version)| {
                 watch_keys.push(((*bytes).clone(), *version));
@@ -149,7 +150,7 @@ impl Connection {
     }
 
     pub fn did_keys_change(&self) -> bool {
-        let watch_keys = &self.info.read().unwrap().watch_keys;
+        let watch_keys = &self.info.read().watch_keys;
 
         for key in watch_keys.iter() {
             if self.db.get_version(&key.0) != key.1 {
@@ -161,14 +162,13 @@ impl Connection {
     }
 
     pub fn discard_watched_keys(&self) {
-        let watch_keys = &mut self.info.write().unwrap().watch_keys;
+        let watch_keys = &mut self.info.write().watch_keys;
         watch_keys.clear();
     }
 
     pub fn get_tx_keys(&self) -> Vec<Bytes> {
         self.info
             .read()
-            .unwrap()
             .tx_keys
             .iter()
             .cloned()
@@ -176,13 +176,13 @@ impl Connection {
     }
 
     pub fn queue_command(&self, args: &[Bytes]) {
-        let info = &mut self.info.write().unwrap();
+        let info = &mut self.info.write();
         let commands = info.commands.get_or_insert(vec![]);
         commands.push(args.iter().map(|m| (*m).clone()).collect());
     }
 
     pub fn get_queue_commands(&self) -> Option<Vec<Vec<Bytes>>> {
-        let info = &mut self.info.write().unwrap();
+        let info = &mut self.info.write();
         info.watch_keys = vec![];
         info.in_transaction = false;
         info.commands.take()
@@ -190,7 +190,7 @@ impl Connection {
 
     pub fn tx_keys(&self, keys: Vec<&Bytes>) {
         #[allow(clippy::mutable_key_type)]
-        let tx_keys = &mut self.info.write().unwrap().tx_keys;
+        let tx_keys = &mut self.info.write().tx_keys;
         keys.iter()
             .map(|k| {
                 tx_keys.insert((*k).clone());
@@ -207,11 +207,11 @@ impl Connection {
     }
 
     pub fn name(&self) -> Option<String> {
-        self.info.read().unwrap().name.clone()
+        self.info.read().name.clone()
     }
 
     pub fn set_name(&self, name: String) {
-        let mut r = self.info.write().unwrap();
+        let mut r = self.info.write();
         r.name = Some(name);
     }
 
@@ -225,7 +225,7 @@ impl Connection {
             "id={} addr={} name={:?} db={}\r\n",
             self.id,
             self.addr,
-            self.info.read().unwrap().name,
+            self.info.read().name,
             self.current_db
         )
     }

+ 31 - 33
src/db/mod.rs

@@ -6,12 +6,13 @@ use bytes::Bytes;
 use entry::{new_version, Entry};
 use expiration::ExpirationDb;
 use log::trace;
+use parking_lot::{Mutex, RwLock};
 use seahash::hash;
 use std::{
     collections::HashMap,
     convert::{TryFrom, TryInto},
     ops::AddAssign,
-    sync::{Arc, Mutex, RwLock},
+    sync::Arc,
     thread,
 };
 use tokio::time::{Duration, Instant};
@@ -81,7 +82,7 @@ impl Db {
 
         let waiting = Duration::from_nanos(100);
 
-        while let Some(blocker) = self.tx_key_locks.read().unwrap().get(key) {
+        while let Some(blocker) = self.tx_key_locks.read().get(key) {
             // Loop while the key we are trying to access is being blocked by a
             // connection in a transaction
             if *blocker == self.conn_id {
@@ -99,7 +100,7 @@ impl Db {
     pub fn lock_keys(&self, keys: &[Bytes]) {
         let waiting = Duration::from_nanos(100);
         loop {
-            let mut lock = self.tx_key_locks.write().unwrap();
+            let mut lock = self.tx_key_locks.write();
             let mut i = 0;
 
             for key in keys.iter() {
@@ -129,7 +130,7 @@ impl Db {
     }
 
     pub fn unlock_keys(&self, keys: &[Bytes]) {
-        let mut lock = self.tx_key_locks.write().unwrap();
+        let mut lock = self.tx_key_locks.write();
         for key in keys.iter() {
             lock.remove(key);
         }
@@ -142,7 +143,7 @@ impl Db {
         key: &Bytes,
         incr_by: T,
     ) -> Result<Value, Error> {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let mut entries = self.entries[self.get_slot(key)].write();
         match entries.get_mut(key) {
             Some(x) => {
                 let value = x.get();
@@ -165,7 +166,7 @@ impl Db {
     }
 
     pub fn persist(&self, key: &Bytes) -> Value {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let mut entries = self.entries[self.get_slot(key)].write();
         entries
             .get_mut(key)
             .filter(|x| x.is_valid())
@@ -180,29 +181,26 @@ impl Db {
     }
 
     pub fn set_ttl(&self, key: &Bytes, expires_in: Duration) -> Value {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let mut entries = self.entries[self.get_slot(key)].write();
         let expires_at = Instant::now() + expires_in;
 
         entries
             .get_mut(key)
             .filter(|x| x.is_valid())
             .map_or(0.into(), |x| {
-                self.expirations.lock().unwrap().add(key, expires_at);
+                self.expirations.lock().add(key, expires_at);
                 x.set_ttl(expires_at);
                 1.into()
             })
     }
 
     pub fn del(&self, keys: &[Bytes]) -> Value {
-        let mut expirations = self.expirations.lock().unwrap();
+        let mut expirations = self.expirations.lock();
 
         keys.iter()
             .filter_map(|key| {
                 expirations.remove(key);
-                self.entries[self.get_slot(key)]
-                    .write()
-                    .unwrap()
-                    .remove(key)
+                self.entries[self.get_slot(key)].write().remove(key)
             })
             .filter(|key| key.is_valid())
             .count()
@@ -213,7 +211,7 @@ impl Db {
         let mut matches = 0;
         keys.iter()
             .map(|key| {
-                let entries = self.entries[self.get_slot(key)].read().unwrap();
+                let entries = self.entries[self.get_slot(key)].read();
                 if entries.get(key).is_some() {
                     matches += 1;
                 }
@@ -228,7 +226,7 @@ impl Db {
         F1: FnOnce(&Value) -> Result<Value, Error>,
         F2: FnOnce() -> Result<Value, Error>,
     {
-        let entries = self.entries[self.get_slot(key)].read().unwrap();
+        let entries = self.entries[self.get_slot(key)].read();
         let entry = entries.get(key).filter(|x| x.is_valid()).map(|e| e.get());
 
         if let Some(entry) = entry {
@@ -242,7 +240,7 @@ impl Db {
     }
 
     pub fn bump_version(&self, key: &Bytes) -> bool {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let mut entries = self.entries[self.get_slot(key)].write();
         entries
             .get_mut(key)
             .filter(|x| x.is_valid())
@@ -253,7 +251,7 @@ impl Db {
     }
 
     pub fn get_version(&self, key: &Bytes) -> u128 {
-        let entries = self.entries[self.get_slot(key)].read().unwrap();
+        let entries = self.entries[self.get_slot(key)].read();
         entries
             .get(key)
             .filter(|x| x.is_valid())
@@ -262,7 +260,7 @@ impl Db {
     }
 
     pub fn get(&self, key: &Bytes) -> Value {
-        let entries = self.entries[self.get_slot(key)].read().unwrap();
+        let entries = self.entries[self.get_slot(key)].read();
         entries
             .get(key)
             .filter(|x| x.is_valid())
@@ -272,7 +270,7 @@ impl Db {
     pub fn get_multi(&self, keys: &[Bytes]) -> Value {
         keys.iter()
             .map(|key| {
-                let entries = self.entries[self.get_slot(key)].read().unwrap();
+                let entries = self.entries[self.get_slot(key)].read();
                 entries
                     .get(key)
                     .filter(|x| x.is_valid() && x.is_clonable())
@@ -283,8 +281,8 @@ impl Db {
     }
 
     pub fn getset(&self, key: &Bytes, value: &Value) -> Value {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
-        self.expirations.lock().unwrap().remove(key);
+        let mut entries = self.entries[self.get_slot(key)].write();
+        self.expirations.lock().remove(key);
         entries
             .insert(key.clone(), Entry::new(value.clone(), None))
             .filter(|x| x.is_valid())
@@ -292,26 +290,26 @@ impl Db {
     }
 
     pub fn getdel(&self, key: &Bytes) -> Value {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let mut entries = self.entries[self.get_slot(key)].write();
         entries.remove(key).map_or(Value::Null, |x| {
-            self.expirations.lock().unwrap().remove(key);
+            self.expirations.lock().remove(key);
             x.clone_value()
         })
     }
 
     pub fn set(&self, key: &Bytes, value: Value, expires_in: Option<Duration>) -> Value {
-        let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+        let mut entries = self.entries[self.get_slot(key)].write();
         let expires_at = expires_in.map(|duration| Instant::now() + duration);
 
         if let Some(expires_at) = expires_at {
-            self.expirations.lock().unwrap().add(key, expires_at);
+            self.expirations.lock().add(key, expires_at);
         }
         entries.insert(key.clone(), Entry::new(value, expires_at));
         Value::Ok
     }
 
     pub fn ttl(&self, key: &Bytes) -> Option<Option<Instant>> {
-        let entries = self.entries[self.get_slot(key)].read().unwrap();
+        let entries = self.entries[self.get_slot(key)].read();
         entries
             .get(key)
             .filter(|x| x.is_valid())
@@ -319,7 +317,7 @@ impl Db {
     }
 
     pub fn purge(&self) -> u64 {
-        let mut expirations = self.expirations.lock().unwrap();
+        let mut expirations = self.expirations.lock();
         let mut removed = 0;
 
         trace!("Watching {} keys for expirations", expirations.len());
@@ -329,7 +327,7 @@ impl Db {
 
         keys.iter()
             .map(|key| {
-                let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+                let mut entries = self.entries[self.get_slot(key)].write();
                 if entries.remove(key).is_some() {
                     trace!("Removed key {:?} due timeout", key);
                     removed += 1;
@@ -363,7 +361,7 @@ mod test {
         let db = Db::new(100);
         db.set(&bytes!(b"num"), Value::Blob(bytes!("1.1")), None);
 
-        assert_eq!(Value::Float(2.2), db.incr(&bytes!("num"), 1.1).unwrap());
+        assert_eq!(Ok(Value::Float(2.2)), db.incr(&bytes!("num"), 1.1));
         assert_eq!(Value::Blob(bytes!("2.2")), db.get(&bytes!("num")));
     }
 
@@ -372,7 +370,7 @@ mod test {
         let db = Db::new(100);
         db.set(&bytes!(b"num"), Value::Blob(bytes!("1")), None);
 
-        assert_eq!(Value::Float(2.1), db.incr(&bytes!("num"), 1.1).unwrap());
+        assert_eq!(Ok(Value::Float(2.1)), db.incr(&bytes!("num"), 1.1));
         assert_eq!(Value::Blob(bytes!("2.1")), db.get(&bytes!("num")));
     }
 
@@ -381,21 +379,21 @@ mod test {
         let db = Db::new(100);
         db.set(&bytes!(b"num"), Value::Blob(bytes!("1")), None);
 
-        assert_eq!(Value::Integer(2), db.incr(&bytes!("num"), 1).unwrap());
+        assert_eq!(Ok(Value::Integer(2)), db.incr(&bytes!("num"), 1));
         assert_eq!(Value::Blob(bytes!("2")), db.get(&bytes!("num")));
     }
 
     #[test]
     fn incr_blob_int_set() {
         let db = Db::new(100);
-        assert_eq!(Value::Integer(1), db.incr(&bytes!("num"), 1).unwrap());
+        assert_eq!(Ok(Value::Integer(1)), db.incr(&bytes!("num"), 1));
         assert_eq!(Value::Blob(bytes!("1")), db.get(&bytes!("num")));
     }
 
     #[test]
     fn incr_blob_float_set() {
         let db = Db::new(100);
-        assert_eq!(Value::Float(1.1), db.incr(&bytes!("num"), 1.1).unwrap());
+        assert_eq!(Ok(Value::Float(1.1)), db.incr(&bytes!("num"), 1.1));
         assert_eq!(Value::Blob(bytes!("1.1")), db.get(&bytes!("num")));
     }
 

+ 2 - 2
src/macros.rs

@@ -78,9 +78,9 @@ macro_rules! dispatcher {
 
                     fn check_number_args(&self, n: usize) -> bool {
                         if ($min_args >= 0) {
-                            n == ($min_args as i32).try_into().unwrap()
+                            n == ($min_args as i32).try_into().unwrap_or(0)
                         } else {
-                            let s: usize = ($min_args as i32).abs().try_into().unwrap();
+                            let s: usize = ($min_args as i32).abs().try_into().unwrap_or(0);
                             n >= s
                         }
                     }

+ 5 - 5
src/value/locked.rs

@@ -1,17 +1,17 @@
-use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
+use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
 
 #[derive(Debug)]
 pub struct Value<T: Clone + PartialEq>(pub RwLock<T>);
 
 impl<T: Clone + PartialEq> Clone for Value<T> {
     fn clone(&self) -> Self {
-        Self(RwLock::new(self.0.read().unwrap().clone()))
+        Self(RwLock::new(self.0.read().clone()))
     }
 }
 
 impl<T: PartialEq + Clone> PartialEq for Value<T> {
     fn eq(&self, other: &Value<T>) -> bool {
-        self.0.read().unwrap().eq(&other.0.read().unwrap())
+        self.0.read().eq(&other.0.read())
     }
 }
 
@@ -21,11 +21,11 @@ impl<T: PartialEq + Clone> Value<T> {
     }
 
     pub fn write(&self) -> RwLockWriteGuard<'_, T> {
-        self.0.write().unwrap()
+        self.0.write()
     }
 
     pub fn read(&self) -> RwLockReadGuard<'_, T> {
-        self.0.read().unwrap()
+        self.0.read()
     }
 }
 

+ 1 - 3
src/value/mod.rs

@@ -101,9 +101,7 @@ impl<'a> From<&ParsedValue<'a>> for Value {
         match value {
             ParsedValue::String(x) => Self::String((*x).to_string()),
             ParsedValue::Blob(x) => Self::Blob(Bytes::copy_from_slice(*x)),
-            ParsedValue::Array(x) => {
-                Self::Array(x.iter().map(|x| Value::try_from(x).unwrap()).collect())
-            }
+            ParsedValue::Array(x) => Self::Array(x.iter().map(|x| x.into()).collect()),
             ParsedValue::Boolean(x) => Self::Boolean(*x),
             ParsedValue::BigInteger(x) => Self::BigInteger(*x),
             ParsedValue::Integer(x) => Self::Integer(*x),