Browse Source

Added database implementation

The DB will be hashed in a vec of hashes, this quick "pre-hash" round
will make lookups slightly slower but will hopefully reduce reduce
lockings between threads.

To disable the "pre-hash" just modify `DB::get_slot` to return the same
number always.
Cesar Rodas 3 years ago
parent
commit
d578b99387
6 changed files with 118 additions and 27 deletions
  1. 1 0
      Cargo.toml
  2. 49 4
      src/db/mod.rs
  3. 21 4
      src/dispatcher.rs
  4. 28 8
      src/macros.rs
  5. 15 10
      src/main.rs
  6. 4 1
      src/value.rs

+ 1 - 0
Cargo.toml

@@ -12,4 +12,5 @@ tokio={version="1", features = ["full", "tracing"] }
 tokio-util={version="^0.6", features = ["full"] }
 futures = { version = "0.3.0", features = ["thread-pool"]}
 tokio-stream="0.1"
+seahash = "4"
 bytes = "1"

+ 49 - 4
src/db/mod.rs

@@ -1,9 +1,54 @@
-use crate::value::Value;
+use crate::{error::Error, value::Value};
+use bytes::Bytes;
+use seahash::hash;
 use std::collections::{BTreeMap, HashMap};
-use std::sync::{Arc, RwLock};
+use std::sync::RwLock;
 use tokio::time::Instant;
 
+#[derive(Debug)]
 pub struct Db {
-    entries: Arc<RwLock<HashMap<String, Value>>>,
-    expiration: Arc<RwLock<BTreeMap<(Instant, u64), String>>>,
+    entries: Vec<RwLock<HashMap<Bytes, Value>>>,
+    expirations: RwLock<BTreeMap<(Instant, u64), String>>,
+    slots: usize,
+}
+
+impl Db {
+    pub fn new(slots: usize) -> Self {
+        let mut entries = vec![];
+
+        for _i in 0..slots {
+            entries.push(RwLock::new(HashMap::new()));
+        }
+
+        Self {
+            entries,
+            expirations: RwLock::new(BTreeMap::new()),
+            slots,
+        }
+    }
+
+    fn get_slot(&self, key: &Bytes) -> usize {
+        (hash(key) as usize) % self.entries.len()
+    }
+
+    pub fn get(&self, key: &Value) -> Result<Value, Error> {
+        match key {
+            Value::Blob(key) => {
+                let entries = self.entries[self.get_slot(key)].read().unwrap();
+                Ok(entries.get(key).cloned().unwrap_or(Value::Null))
+            }
+            _ => Err(Error::WrongType),
+        }
+    }
+
+    pub fn set(&self, key: &Value, value: &Value) -> Result<Value, Error> {
+        match key {
+            Value::Blob(key) => {
+                let mut entries = self.entries[self.get_slot(key)].write().unwrap();
+                entries.insert(key.clone(), value.clone());
+                Ok(Value::OK)
+            }
+            _ => Err(Error::WrongType),
+        }
+    }
 }

+ 21 - 4
src/dispatcher.rs

@@ -1,8 +1,9 @@
-use crate::{dispatcher, value::Value};
+use crate::{db::Db, dispatcher, error::Error, value::Value};
+use std::convert::TryInto;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
 
-fn do_time(_args: &[Value]) -> Result<Value, String> {
+fn do_time(_db: &Db, _args: &[Value]) -> Result<Value, Error> {
     let now = SystemTime::now();
     let since_the_epoch = now.duration_since(UNIX_EPOCH).expect("Time went backwards");
     let seconds = format!("{}", since_the_epoch.as_secs());
@@ -11,7 +12,7 @@ fn do_time(_args: &[Value]) -> Result<Value, String> {
     Ok(vec![seconds.as_str(), millis.as_str()].into())
 }
 
-fn do_command(_args: &[Value]) -> Result<Value, String> {
+fn do_command(_db: &Db, _args: &[Value]) -> Result<Value, Error> {
     let now = SystemTime::now();
     let since_the_epoch = now.duration_since(UNIX_EPOCH).expect("Time went backwards");
     let in_ms: i128 =
@@ -19,17 +20,33 @@ fn do_command(_args: &[Value]) -> Result<Value, String> {
     Ok(format!("{}", in_ms).as_str().into())
 }
 
+fn get(db: &Db, args: &[Value]) -> Result<Value, Error> {
+    db.get(&args[1])
+}
+
+fn set(db: &Db, args: &[Value]) -> Result<Value, Error> {
+    db.set(&args[1], &args[2])
+}
+
 dispatcher! {
     command  {
         do_command,
         ["random" "loading" "stale"],
+        1,
     },
     get {
-        do_command,
+        get,
+        ["random" "loading" "stale"],
+        2,
+    },
+    set {
+        set,
         ["random" "loading" "stale"],
+        -3,
     },
     time {
         do_time,
         ["random" "loading" "stale"],
+        1,
     },
 }

+ 28 - 8
src/macros.rs

@@ -4,6 +4,7 @@ macro_rules! dispatcher {
         $($command:ident {
             $handler:ident,
             [$($tag:tt)+],
+            $min_args:expr,
         },)+$(,)?
     }=>  {
         $(
@@ -13,19 +14,30 @@ macro_rules! dispatcher {
 
                 pub struct Command {
                     pub tags: &'static [&'static str],
+                    pub min_args: i32,
                 }
 
                 impl Command {
                     pub fn new() -> Self {
                         Self {
                             tags: &[$($tag,)+],
+                            min_args: $min_args,
                         }
                     }
                 }
 
                 impl ExecutableCommand for Command {
-                    fn execute(&self, args: &[Value]) -> Result<Value, String> {
-                        $handler(args)
+                    fn execute(&self, db: &Db, args: &[Value]) -> Result<Value, Error> {
+                        $handler(db, args)
+                    }
+
+                    fn check_number_args(&self, n: usize) -> bool {
+                        if ($min_args >= 0) {
+                            n == ($min_args as i32).try_into().unwrap()
+                        } else {
+                            let s: usize = ($min_args as i32).abs().try_into().unwrap();
+                            n >= s
+                        }
                     }
 
                     fn name(&self) -> &'static str {
@@ -37,7 +49,9 @@ macro_rules! dispatcher {
         use std::ops::Deref;
 
         pub trait ExecutableCommand {
-            fn execute(&self, args: &[Value]) -> Result<Value, String>;
+            fn execute(&self, db: &Db, args: &[Value]) -> Result<Value, Error>;
+
+            fn check_number_args(&self, n: usize) -> bool;
 
             fn name(&self) -> &'static str;
         }
@@ -50,18 +64,24 @@ macro_rules! dispatcher {
         }
 
         impl Dispatcher {
-            pub fn new(command: &Value) -> Result<Self, String> {
-                let command = match command {
+            pub fn new(args: &[Value]) -> Result<Self, Error> {
+                let command = match &args[0] {
                     Value::String(x) => Ok(x.as_str()),
                     Value::Blob(x) => Ok(unsafe { std::str::from_utf8_unchecked(&x) }),
-                    _ => Err("Invalid type"),
+                    _ => Err(Error::ProtocolError("$".to_string(), "*".to_string())),
                 }?;
 
-                match command.to_lowercase().as_str() {
+                let command = match command.to_lowercase().as_str() {
                 $(
                     stringify!($command) => Ok(Self::$command($command::Command::new())),
                 )+
-                    _ => Err(format!("Command ({}) not found", command)),
+                    _ => Err(Error::CommandNotFound(command.into())),
+                }?;
+
+                if ! command.check_number_args(args.len()) {
+                    Err(Error::InvalidArgsCount(command.name().into()))
+                } else {
+                    Ok(command)
                 }
             }
         }

+ 15 - 10
src/main.rs

@@ -1,5 +1,6 @@
-mod dispatcher;
 mod db;
+mod dispatcher;
+mod error;
 mod macros;
 mod value;
 
@@ -11,11 +12,8 @@ use std::convert::TryFrom;
 use std::env;
 use std::error::Error;
 use std::ops::Deref;
-use std::{
-    io,
-    sync::{Arc, Mutex},
-};
-use tokio::net::{TcpListener, TcpStream};
+use std::{io, sync::Arc};
+use tokio::net::TcpListener;
 use tokio_stream::StreamExt;
 use tokio_util::codec::{Decoder, Encoder, Framed};
 use value::Value;
@@ -29,21 +27,28 @@ async fn main() -> Result<(), Box<dyn Error>> {
     let listener = TcpListener::bind(&addr).await?;
     println!("Listening on: {}", addr);
 
+    let db = Arc::new(db::Db::new(12));
+
     loop {
         match listener.accept().await {
             Ok((socket, _)) => {
+                let db = db.clone();
                 tokio::spawn(async move {
                     let mut transport = Framed::new(socket, RedisParser);
 
                     while let Some(result) = transport.next().await {
                         match result {
-                            Ok(Value::Array(args)) => match Dispatcher::new(&args[0]) {
+                            Ok(Value::Array(args)) => match Dispatcher::new(&args) {
                                 Ok(handler) => {
-                                    let r = handler.deref().execute(&args);
-                                    transport.send(r.unwrap()).await;
+                                    let r = handler
+                                        .deref()
+                                        .execute(&db, &args)
+                                        .unwrap_or_else(|x| x.into());
+                                    transport.send(r).await;
                                 }
                                 Err(err) => {
-                                    println!("invalid command {:?}", err);
+                                    let err: Value = err.into();
+                                    transport.send(err).await;
                                 }
                             },
                             Ok(x) => {

+ 4 - 1
src/value.rs

@@ -1,6 +1,6 @@
 use crate::{value_try_from, value_vec_try_from};
-use redis_zero_parser::Value as ParsedValue;
 use bytes::{Bytes, BytesMut};
+use redis_zero_parser::Value as ParsedValue;
 use std::convert::TryFrom;
 
 #[derive(Debug, PartialEq, Clone)]
@@ -14,6 +14,7 @@ pub enum Value {
     Float(f64),
     BigInteger(i128),
     Null,
+    OK,
 }
 
 impl From<&Value> for Vec<u8> {
@@ -37,6 +38,8 @@ impl From<&Value> for Vec<u8> {
                 s.extend_from_slice(b"\r\n");
                 s.to_vec()
             }
+            Value::Err(x, y) => format!("-{} {}\r\n", x, y).into(),
+            Value::OK => "+OK\r\n".into(),
             _ => b"*-1\r\n".to_vec(),
         }
     }