Browse Source

Code improvements (#10)

* Get rid of all unsafe code blocks.

Unsafe code blocks were used to transform an stream of bytes into an
UTF-8 string for quick comparison or number conversions.

@richardpringle showed me a better way of doing this.

* More improvements.

* Added conversion from usize to Value::Integer
* Rewrote db::del with a better style (suggestion by @richardpringle)

* Yet another suggestion by @richardpringle

Rewrite block code in a more rust-like style.

* More changes
César D. Rodas 3 years ago
parent
commit
b8d7c1382c
9 changed files with 58 additions and 58 deletions
  1. 2 2
      src/cmd/client.rs
  2. 10 13
      src/cmd/hash.rs
  3. 9 9
      src/cmd/list.rs
  4. 2 2
      src/cmd/set.rs
  5. 2 2
      src/cmd/string.rs
  6. 11 12
      src/cmd/transaction.rs
  7. 10 12
      src/db/mod.rs
  8. 5 5
      src/macros.rs
  9. 7 1
      src/value/mod.rs

+ 2 - 2
src/cmd/client.rs

@@ -3,7 +3,7 @@ use bytes::Bytes;
 use std::sync::Arc;
 
 pub async fn client(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-    let sub = unsafe { std::str::from_utf8_unchecked(&args[1]) }.to_string();
+    let sub = String::from_utf8_lossy(&args[1]);
 
     let expected = match sub.to_lowercase().as_str() {
         "setname" => 3,
@@ -28,7 +28,7 @@ pub async fn client(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
             Ok(v.into())
         }
         "setname" => {
-            let name = unsafe { std::str::from_utf8_unchecked(&args[2]) }.to_string();
+            let name = String::from_utf8_lossy(&args[2]).to_string();
             conn.set_name(name);
             Ok(Value::Ok)
         }

+ 10 - 13
src/cmd/hash.rs

@@ -146,7 +146,7 @@ pub async fn hlen(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     conn.db().get_map_or(
         &args[1],
         |v| match v {
-            Value::Hash(h) => Ok((h.read().len() as i64).into()),
+            Value::Hash(h) => Ok(h.read().len().into()),
             _ => Err(Error::WrongType),
         },
         || Ok(0.into()),
@@ -190,14 +190,11 @@ pub async fn hrandfield(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
         }
         _ => return Err(Error::InvalidArgsCount("hrandfield".to_owned())),
     };
-    let (count, single, repeat) = if let Some(count) = count {
-        if count > 0 {
-            (count, false, 1)
-        } else {
-            (count.abs(), false, count.abs())
-        }
-    } else {
-        (1, true, 1)
+
+    let (count, single, repeat) = match count {
+        Some(count) if count > 0 => (count, false, 1),
+        Some(count) => (count.abs(), false, count.abs()),
+        _ => (1, true, 1),
     };
 
     conn.db().get_map_or(
@@ -270,7 +267,7 @@ pub async fn hset(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
             for i in (2..args.len()).step_by(2) {
                 h.insert(args[i].clone(), args[i + 1].clone());
             }
-            let len = h.len() as i64;
+            let len = h.len();
             conn.db().set(&args[1], h.into(), None);
             Ok(len.into())
         },
@@ -303,7 +300,7 @@ pub async fn hsetnx(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
             for i in (2..args.len()).step_by(2) {
                 h.insert(args[i].clone(), args[i + 1].clone());
             }
-            let len = h.len() as i64;
+            let len = h.len();
             conn.db().set(&args[1], h.into(), None);
             Ok(len.into())
         },
@@ -321,7 +318,7 @@ pub async fn hstrlen(conn: &Connection, args: &[Bytes]) -> Result<Value, Error>
         &args[1],
         |v| match v {
             Value::Hash(h) => Ok(if let Some(v) = h.read().get(&args[2]) {
-                (v.len() as i64).into()
+                v.len().into()
             } else {
                 0.into()
             }),
@@ -399,7 +396,7 @@ mod test {
         let r = run_command(&c, &["hrandfield", "foo"]).await;
         match r {
             Ok(Value::Blob(x)) => {
-                let x = unsafe { std::str::from_utf8_unchecked(&x) };
+                let x = String::from_utf8_lossy(&x);
                 assert!(x == "f1".to_owned() || x == "f2".to_owned() || x == "f3".to_owned());
             }
             _ => assert!(false),

+ 9 - 9
src/cmd/list.rs

@@ -157,7 +157,7 @@ pub async fn linsert(conn: &Connection, args: &[Bytes]) -> Result<Value, Error>
                 }
 
                 if found {
-                    Ok((x.len() as i64).into())
+                    Ok(x.len().into())
                 } else {
                     Ok((-1).into())
                 }
@@ -176,7 +176,7 @@ pub async fn llen(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     conn.db().get_map_or(
         &args[1],
         |v| match v {
-            Value::List(x) => Ok((x.read().len() as i64).into()),
+            Value::List(x) => Ok(x.read().len().into()),
             _ => Err(Error::WrongType),
         },
         || Ok(0.into()),
@@ -302,18 +302,18 @@ pub async fn lpos(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
                     if *val == element {
                         // Match!
                         if let Some(count) = count {
-                            ret.push((i as i64).into());
+                            ret.push(i.into());
                             if ret.len() > count {
                                 return Ok(ret.into());
                             }
                         } else if let Some(rank) = rank {
-                            ret.push((i as i64).into());
+                            ret.push(i.into());
                             if ret.len() == rank {
                                 return Ok(ret[rank - 1].clone());
                             }
                         } else {
                             // return first match!
-                            return Ok((i as i64).into());
+                            return Ok(i.into());
                         }
                     }
                     if (i as i64) == max_len {
@@ -350,7 +350,7 @@ pub async fn lpush(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
                 for val in args.iter().skip(2) {
                     x.push_front(checksum::Value::new(val.clone()));
                 }
-                Ok((x.len() as i64).into())
+                Ok(x.len().into())
             }
             _ => Err(Error::WrongType),
         },
@@ -364,7 +364,7 @@ pub async fn lpush(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
                 h.push_front(checksum::Value::new(val.clone()));
             }
 
-            let len = h.len() as i64;
+            let len = h.len();
             conn.db().set(&args[1], h.into(), None);
             Ok(len.into())
         },
@@ -557,7 +557,7 @@ pub async fn rpush(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
                 for val in args.iter().skip(2) {
                     x.push_back(checksum::Value::new(val.clone()));
                 }
-                Ok((x.len() as i64).into())
+                Ok(x.len().into())
             }
             _ => Err(Error::WrongType),
         },
@@ -571,7 +571,7 @@ pub async fn rpush(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
                 h.push_back(checksum::Value::new(val.clone()));
             }
 
-            let len = h.len() as i64;
+            let len = h.len();
             conn.db().set(&args[1], h.into(), None);
             Ok(len.into())
         },

+ 2 - 2
src/cmd/set.rs

@@ -106,7 +106,7 @@ pub async fn scard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     conn.db().get_map_or(
         &args[1],
         |v| match v {
-            Value::Set(x) => Ok((x.read().len() as i64).into()),
+            Value::Set(x) => Ok(x.read().len().into()),
             _ => Err(Error::WrongType),
         },
         || Ok(0.into()),
@@ -150,7 +150,7 @@ pub async fn sinter(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
 
 pub async fn sintercard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     if let Ok(Value::Array(x)) = sinter(conn, args).await {
-        Ok((x.len() as i64).into())
+        Ok(x.len().into())
     } else {
         Ok(0.into())
     }

+ 2 - 2
src/cmd/string.rs

@@ -64,8 +64,8 @@ pub async fn setex(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
 
 pub async fn strlen(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     match conn.db().get(&args[1]) {
-        Value::Blob(x) => Ok((x.len() as i64).into()),
-        Value::String(x) => Ok((x.len() as i64).into()),
+        Value::Blob(x) => Ok(x.len().into()),
+        Value::String(x) => Ok(x.len().into()),
         Value::Null => Ok(0.into()),
         _ => Ok(Error::WrongType.into()),
     }

+ 11 - 12
src/cmd/transaction.rs

@@ -127,10 +127,7 @@ mod test {
         );
         assert_eq!(Ok(Value::Queued), run_command(&c, &["get", "foo"]).await);
         assert_eq!(Ok(Value::Ok), run_command(&c, &["discard"]).await);
-        assert_eq!(
-            Err(Error::NotInTx),
-            run_command(&c, &["exec"]).await
-        );
+        assert_eq!(Err(Error::NotInTx), run_command(&c, &["exec"]).await);
     }
 
     #[tokio::test]
@@ -168,11 +165,12 @@ mod test {
         let c = create_connection();
 
         assert_eq!(Ok(Value::Ok), run_command(&c, &["multi"]).await);
-        assert_eq!(Ok(Value::Queued), run_command(&c, &["brpop", "foo", "1000"]).await);
         assert_eq!(
-            Ok(Value::Array(vec![
-                Value::Null,
-            ])),
+            Ok(Value::Queued),
+            run_command(&c, &["brpop", "foo", "1000"]).await
+        );
+        assert_eq!(
+            Ok(Value::Array(vec![Value::Null,])),
             run_command(&c, &["exec"]).await
         );
     }
@@ -182,11 +180,12 @@ mod test {
         let c = create_connection();
 
         assert_eq!(Ok(Value::Ok), run_command(&c, &["multi"]).await);
-        assert_eq!(Ok(Value::Queued), run_command(&c, &["blpop", "foo", "1000"]).await);
         assert_eq!(
-            Ok(Value::Array(vec![
-                Value::Null,
-            ])),
+            Ok(Value::Queued),
+            run_command(&c, &["blpop", "foo", "1000"]).await
+        );
+        assert_eq!(
+            Ok(Value::Array(vec![Value::Null,])),
             run_command(&c, &["exec"]).await
         );
     }

+ 10 - 12
src/db/mod.rs

@@ -194,21 +194,19 @@ impl Db {
     }
 
     pub fn del(&self, keys: &[Bytes]) -> Value {
-        let mut deleted = 0;
         let mut expirations = self.expirations.lock().unwrap();
+
         keys.iter()
-            .map(|key| {
-                let mut entries = self.entries[self.get_slot(key)].write().unwrap();
-                if let Some(entry) = entries.remove(key) {
-                    expirations.remove(key);
-                    if entry.is_valid() {
-                        deleted += 1;
-                    }
-                }
+            .filter_map(|key| {
+                expirations.remove(key);
+                self.entries[self.get_slot(key)]
+                    .write()
+                    .unwrap()
+                    .remove(key)
             })
-            .for_each(drop);
-
-        deleted.into()
+            .filter(|key| key.is_valid())
+            .count()
+            .into()
     }
 
     pub fn exists(&self, keys: &[Bytes]) -> Value {

+ 5 - 5
src/macros.rs

@@ -95,8 +95,9 @@ macro_rules! dispatcher {
                 }
             }
         )+)+
-        use std::ops::Deref;
+
         use async_trait::async_trait;
+        use std::ops::Deref;
 
         #[async_trait]
         pub trait ExecutableCommand {
@@ -122,9 +123,9 @@ macro_rules! dispatcher {
 
         impl Dispatcher {
             pub fn new(args: &[Bytes]) -> Result<Self, Error> {
-                let command = unsafe { std::str::from_utf8_unchecked(&args[0]) };
+                let command = String::from_utf8_lossy(&args[0]).to_lowercase();
 
-                let command = match command.to_lowercase().as_str() {
+                let command = match command.as_str() {
                 $($(
                     stringify!($command) => Ok(Self::$command($command::Command::new())),
                 )+)+
@@ -193,8 +194,7 @@ macro_rules! check_arg {
     {$args: tt, $pos: tt, $command: tt} => {{
         match $args.get($pos) {
             Some(bytes) => {
-                let command = unsafe { std::str::from_utf8_unchecked(&bytes) };
-                command.to_uppercase() == $command
+                String::from_utf8_lossy(&bytes).to_uppercase() == $command
             },
             None => false,
         }

+ 7 - 1
src/value/mod.rs

@@ -86,7 +86,7 @@ impl TryFrom<&Value> for f64 {
     }
 }
 pub fn bytes_to_number<T: FromStr>(bytes: &Bytes) -> Result<T, Error> {
-    let x = unsafe { std::str::from_utf8_unchecked(bytes) };
+    let x = String::from_utf8_lossy(bytes);
     x.parse::<T>().map_err(|_| Error::NotANumber)
 }
 
@@ -119,6 +119,12 @@ value_try_from!(i32, Value::Integer);
 value_try_from!(i64, Value::Integer);
 value_try_from!(i128, Value::BigInteger);
 
+impl From<usize> for Value {
+    fn from(value: usize) -> Value {
+        Value::Integer(value as i64)
+    }
+}
+
 impl From<&str> for Value {
     fn from(value: &str) -> Value {
         Value::Blob(Bytes::copy_from_slice(value.as_bytes()))