Ver código fonte

Use generic incr() function

Cesar Rodas 3 anos atrás
pai
commit
2bb704f677
4 arquivos alterados com 62 adições e 18 exclusões
  1. 20 7
      src/cmd/string.rs
  2. 18 11
      src/db/mod.rs
  3. 10 0
      src/dispatcher.rs
  4. 14 0
      src/value.rs

+ 20 - 7
src/cmd/string.rs

@@ -1,14 +1,19 @@
 use crate::{connection::Connection, error::Error, value::Value};
 use bytes::Bytes;
-use std::convert::TryInto;
+use std::{convert::TryInto, ops::Neg};
+
+pub fn incr(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    conn.db().incr(&args[1], 1)
+}
 
 pub fn incr_by(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     let by: i64 = (&Value::Blob(args[2].to_owned())).try_into()?;
     conn.db().incr(&args[1], by)
 }
 
-pub fn incr(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-    conn.db().incr(&args[1], 1)
+pub fn incr_by_float(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    let by: f64 = (&Value::Blob(args[2].to_owned())).try_into()?;
+    conn.db().incr(&args[1], by)
 }
 
 pub fn decr(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
@@ -17,7 +22,7 @@ pub fn decr(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
 
 pub fn decr_by(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     let by: i64 = (&Value::Blob(args[2].to_owned())).try_into()?;
-    conn.db().incr(&args[1], -1 * by)
+    conn.db().incr(&args[1], by.neg())
 }
 
 pub fn get(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
@@ -28,15 +33,23 @@ pub fn getdel(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     Ok(conn.db().getdel(&args[1]))
 }
 
+pub fn getset(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    Ok(conn.db().getset(&args[1], &Value::Blob(args[2].to_owned())))
+}
+
 pub fn mget(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     Ok(conn.db().get_multi(&args[1..]))
 }
 
-
 pub fn set(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     Ok(conn.db().set(&args[1], &Value::Blob(args[2].to_owned())))
 }
 
-pub fn getset(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-    Ok(conn.db().getset(&args[1], &Value::Blob(args[2].to_owned())))
+pub fn strlen(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    match conn.db().get(&args[1]) {
+        Value::Blob(x) => Ok((x.len() as i64).into()),
+        Value::String(x) => Ok((x.len() as i64).into()),
+        Value::Null => Ok(0_i64.into()),
+        _ => Err(Error::WrongType),
+    }
 }

+ 18 - 11
src/db/mod.rs

@@ -7,7 +7,8 @@ use log::trace;
 use seahash::hash;
 use std::{
     collections::{BTreeMap, HashMap},
-    convert::TryInto,
+    convert::{TryFrom, TryInto},
+    ops::AddAssign,
     sync::RwLock,
 };
 use tokio::time::{Duration, Instant};
@@ -57,22 +58,28 @@ impl Db {
         id
     }
 
-    pub fn incr(&self, key: &Bytes, incr_by: i64) -> Result<Value, Error> {
+    pub fn incr<
+        T: ToString + AddAssign + for<'a> TryFrom<&'a Value, Error = Error> + Into<Value> + Copy,
+    >(
+        &self,
+        key: &Bytes,
+        incr_by: T,
+    ) -> Result<Value, Error> {
         let mut entries = self.entries[self.get_slot(key)].write().unwrap();
         match entries.get_mut(key) {
             Some(x) => {
                 let value = x.get();
-                let mut number: i64 = value.try_into()?;
+                let mut number: T = value.try_into()?;
 
                 number += incr_by;
 
-                x.change_value(format!("{}", number).as_str().into());
+                x.change_value(number.to_string().as_str().into());
 
                 Ok(number.into())
             }
             None => {
-                entries.insert(key.clone(), Entry::new(incr_by.into()));
-                Ok((incr_by as i64).into())
+                entries.insert(key.clone(), Entry::new(incr_by.to_string().as_str().into()));
+                Ok((incr_by as T).into())
             }
         }
     }
@@ -140,9 +147,10 @@ impl Db {
         keys.iter()
             .map(|key| {
                 let entries = self.entries[self.get_slot(key)].read().unwrap();
-                entries.get(key)
-                    .map_or(Value::Null, |x| x.get().clone())
-            }).collect::<Vec<Value>>().into()
+                entries.get(key).map_or(Value::Null, |x| x.get().clone())
+            })
+            .collect::<Vec<Value>>()
+            .into()
     }
 
     pub fn getset(&self, key: &Bytes, value: &Value) -> Value {
@@ -155,8 +163,7 @@ impl Db {
 
     pub fn getdel(&self, key: &Bytes) -> Value {
         let mut entries = self.entries[self.get_slot(key)].write().unwrap();
-        entries.remove(key)
-            .map_or(Value::Null, |x| x.get().clone())
+        entries.remove(key).map_or(Value::Null, |x| x.get().clone())
     }
 
     pub fn set(&self, key: &Bytes, value: &Value) -> Value {

+ 10 - 0
src/dispatcher.rs

@@ -115,6 +115,11 @@ dispatcher! {
             ["write" "denyoom" "fast"],
             3,
         },
+        incrbyfloat {
+            cmd::string::incr_by_float,
+            ["write" "denyoom" "fast"],
+            3,
+        },
         mget {
             cmd::string::mget,
             ["random" "loading" "stale"],
@@ -125,6 +130,11 @@ dispatcher! {
             ["random" "loading" "stale"],
             -3,
         },
+        strlen {
+            cmd::string::strlen,
+            ["random" "fast"],
+            2,
+        }
     },
     connection {
         client {

+ 14 - 0
src/value.rs

@@ -34,6 +34,7 @@ impl From<&Value> for Vec<u8> {
             }
             Value::Integer(x) => format!(":{}\r\n", x).into(),
             Value::BigInteger(x) => format!("({}\r\n", x).into(),
+            Value::Float(x) => format!(",{}\r\n", x).into(),
             Value::Blob(x) => {
                 let s = format!("${}\r\n", x.len());
                 let mut s: BytesMut = s.as_str().as_bytes().into();
@@ -63,6 +64,18 @@ impl TryFrom<&Value> for i64 {
     }
 }
 
+impl TryFrom<&Value> for f64 {
+    type Error = Error;
+
+    fn try_from(val: &Value) -> Result<Self, Self::Error> {
+        match val {
+            Value::Float(x) => Ok(*x),
+            Value::Blob(x) => bytes_to_number::<f64>(&x),
+            Value::String(x) => x.parse::<f64>().map_err(|_| Error::NotANumber),
+            _ => Err(Error::NotANumber),
+        }
+    }
+}
 pub fn bytes_to_number<T: FromStr>(bytes: &Bytes) -> Result<T, Error> {
     let x = unsafe { std::str::from_utf8_unchecked(bytes) };
     x.parse::<T>().map_err(|_| Error::NotANumber)
@@ -92,6 +105,7 @@ impl<'a> From<&ParsedValue<'a>> for Value {
     }
 }
 
+value_try_from!(f64, Value::Float);
 value_try_from!(i64, Value::Integer);
 value_try_from!(i128, Value::BigInteger);