Selaa lähdekoodia

Better dispatcher/command routing (#20)

A lot of the logic of the dispatcher/command router is behind an obscure
macro.

Some optimizations are:

  1. Prior each command was their own individual struct, sharing only a
     common trait. That was pointless. To improve this and remove all
     dynamic bits from runtime a single Command struct is available now.
     It is created in a Dispatcher object, one instance of command per
     command, with all their definition. Having the same struct removes
     the need of a dynamic trait.
  2. The command struct definition is outside of a macro, making the
     code more readable.
César D. Rodas 3 vuotta sitten
vanhempi
säilyke
80f8ab840f
7 muutettua tiedostoa jossa 224 lisäystä ja 208 poistoa
  1. 0 1
      Cargo.toml
  2. 1 3
      src/cmd/mod.rs
  3. 5 7
      src/cmd/transaction.rs
  4. 128 0
      src/dispatcher/command.rs
  5. 2 0
      src/dispatcher/mod.rs
  6. 69 166
      src/macros.rs
  7. 19 31
      src/server.rs

+ 0 - 1
Cargo.toml

@@ -11,7 +11,6 @@ redis-zero-protocol-parser = {path = "redis-protocol-parser"}
 tokio={version="1", features = ["full", "tracing"] }
 parking_lot="^0.11"
 tokio-util={version="^0.6", features = ["full"] }
-async-trait = "0.1.50"
 crc32fast="^1.2"
 futures = { version = "0.3.0", features = ["thread-pool"]}
 tokio-stream="0.1"

+ 1 - 3
src/cmd/mod.rs

@@ -57,8 +57,6 @@ mod test {
         let args: Vec<Bytes> = cmd.iter().map(|s| Bytes::from(s.to_string())).collect();
 
         let dispatcher = Dispatcher::new();
-        let handler = dispatcher.get_handler(&args)?;
-
-        handler.execute(conn, &args).await
+        dispatcher.execute(conn, &args).await
     }
 }

+ 5 - 7
src/cmd/transaction.rs

@@ -43,14 +43,12 @@ pub async fn exec(conn: &Connection, _: &[Bytes]) -> Result<Value, Error> {
     let mut results = vec![];
 
     if let Some(commands) = conn.get_queue_commands() {
+        let dispatcher = conn.all_connections().get_dispatcher();
         for args in commands.iter() {
-            let result = match conn.all_connections().get_dispatcher().get_handler(args) {
-                Ok(handler) => handler
-                    .execute(conn, args)
-                    .await
-                    .unwrap_or_else(|x| x.into()),
-                Err(err) => err.into(),
-            };
+            let result = dispatcher
+                .execute(conn, args)
+                .await
+                .unwrap_or_else(|x| x.into());
             results.push(result);
         }
     }

+ 128 - 0
src/dispatcher/command.rs

@@ -0,0 +1,128 @@
+//! # Dispatcher
+//!
+//! Here is where every command is defined. Each command has some definition and a handler. Their
+//! handler are rust functions.
+//!
+//! Each command is defined with the dispatcher macro, which generates efficient and developer
+//! friendly code.
+use crate::{
+    connection::{Connection, ConnectionStatus},
+    dispatcher,
+    error::Error,
+    value::Value,
+};
+use bytes::Bytes;
+use metered::{ErrorCount, HitCount, InFlight, ResponseTime, Throughput};
+use std::convert::TryInto;
+
+/// Command definition
+#[derive(Debug)]
+pub struct Command {
+    name: &'static str,
+    group: &'static str,
+    tags: &'static [&'static str],
+    min_args: i32,
+    key_start: i32,
+    key_stop: i32,
+    key_step: usize,
+    is_queueable: bool,
+    metrics: Metrics,
+}
+
+/// Metric struct for all command
+#[derive(Debug, Default, serde::Serialize)]
+pub struct Metrics {
+    /// Command hits
+    pub hit_count: HitCount,
+    /// Error count
+    pub error_count: ErrorCount,
+    /// How many concurrent executions are happening right now
+    pub in_flight: InFlight,
+    /// Response time
+    pub response_time: ResponseTime,
+    /// Throughput
+    pub throughput: Throughput,
+}
+
+impl Command {
+    /// Creates a new comamnd
+    pub fn new(
+        name: &'static str,
+        group: &'static str,
+        tags: &'static [&'static str],
+        min_args: i32,
+        key_start: i32,
+        key_stop: i32,
+        key_step: usize,
+        is_queueable: bool,
+    ) -> Self {
+        Self {
+            name,
+            group,
+            tags,
+            min_args,
+            key_start,
+            key_stop,
+            key_step,
+            is_queueable,
+            metrics: Metrics::default(),
+        }
+    }
+
+    /// Returns a reference to the metrics
+    pub fn metrics(&self) -> &Metrics {
+        &self.metrics
+    }
+
+    /// Can this command be executed in a pub-sub only mode?
+    pub fn is_pubsub_executable(&self) -> bool {
+        self.group == "pubsub" || self.name == "ping" || self.name == "reset"
+    }
+
+    /// Can this command be queued in a transaction or should it be executed right away?
+    pub fn is_queueable(&self) -> bool {
+        self.is_queueable
+    }
+
+    /// Returns all database keys from the command arguments
+    pub fn get_keys<'a>(&self, args: &'a [Bytes]) -> Vec<&'a Bytes> {
+        let start = self.key_start;
+        let stop = if self.key_stop > 0 {
+            self.key_stop
+        } else {
+            (args.len() as i32) + self.key_stop
+        };
+
+        if start == 0 {
+            return vec![];
+        }
+
+        let mut result = vec![];
+
+        for i in (start..stop + 1).step_by(self.key_step) {
+            result.push(&args[i as usize]);
+        }
+
+        result
+    }
+
+    /// Checks if a given number of args is expected by this command
+    pub fn check_number_args(&self, n: usize) -> bool {
+        if (self.min_args >= 0) {
+            n == (self.min_args as i32).try_into().unwrap_or(0)
+        } else {
+            let s: usize = (self.min_args as i32).abs().try_into().unwrap_or(0);
+            n >= s
+        }
+    }
+
+    /// Command group
+    pub fn group(&self) -> &'static str {
+        &self.group
+    }
+
+    /// Command name
+    pub fn name(&self) -> &'static str {
+        &self.name
+    }
+}

+ 2 - 0
src/dispatcher.rs → src/dispatcher/mod.rs

@@ -17,6 +17,8 @@ use std::convert::TryInto;
 use std::time::SystemTime;
 use std::time::UNIX_EPOCH;
 
+pub mod command;
+
 /// Returns the server time
 async fn do_time(_conn: &Connection, _args: &[Bytes]) -> Result<Value, Error> {
     let now = SystemTime::now();

+ 69 - 166
src/macros.rs

@@ -22,174 +22,17 @@ macro_rules! dispatcher {
                 $key_start:expr,
                 $key_stop:expr,
                 $key_step:expr,
-                $queueable:expr,
+                $is_queueable:expr,
             }),+$(,)?
         }),+$(,)?
     }=>  {
-        $($(
-            #[allow(non_snake_case, non_camel_case_types)]
-            pub mod $command {
-                //! # Command mod
-                //!
-                //! Each individual command is defined in their own namespace
-                use super::*;
-                use async_trait::async_trait;
-                use metered::measure;
-
-                /// Command definition
-                #[derive(Debug)]
-                pub struct Command {
-                    tags: &'static [&'static str],
-                    min_args: i32,
-                    key_start: i32,
-                    key_stop: i32,
-                    key_step: usize,
-                    metrics: Metrics,
-                }
-
-                impl Command {
-                    /// Creates a new comamnd
-                    pub fn new() -> Self {
-                        Self {
-                            tags: &[$($tag,)+],
-                            min_args: $min_args,
-                            key_start: $key_start,
-                            key_stop: $key_stop,
-                            key_step: $key_step,
-                            metrics: Metrics::default(),
-                        }
-                    }
-                }
-
-                #[async_trait]
-                impl ExecutableCommand for Command {
-                    async fn execute(&self, conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-                        let metrics = self.metrics();
-                        let hit_count = &metrics.hit_count;
-                        let error_count = &metrics.error_count;
-                        let in_flight = &metrics.in_flight;
-                        let response_time = &metrics.response_time;
-                        let throughput = &metrics.throughput;
-
-                        let status = conn.status();
-                        if status == ConnectionStatus::Multi && self.is_queueable() {
-                            conn.queue_command(args);
-                            conn.tx_keys(self.get_keys(args));
-                            return Ok(Value::Queued);
-                        } else if status == ConnectionStatus::Pubsub && ! self.is_pubsub_executable() {
-                            return Err(Error::PubsubOnly(stringify!($command).to_owned()));
-                        }
-
-                        measure!(hit_count, {
-                            measure!(response_time, {
-                                measure!(throughput, {
-                                    measure!(in_flight, {
-                                        measure!(error_count, $handler(conn, args).await)
-                                    })
-                                })
-                            })
-                        })
-                    }
-
-                    fn metrics(&self) -> &Metrics {
-                        &self.metrics
-                    }
-
-                    fn is_pubsub_executable(&self) -> bool {
-                        stringify!($ns) == "pubsub" || stringify!($command) == "ping" || stringify!($command) == "reset"
-                    }
-
-                    fn is_queueable(&self) -> bool {
-                        $queueable
-                    }
-
-                    fn get_keys<'a>(&self, args: &'a [Bytes]) -> Vec<&'a Bytes> {
-                        let start = self.key_start;
-                        let stop  = if self.key_stop > 0 {
-                            self.key_stop
-                        } else {
-                            (args.len() as i32) + self.key_stop
-                        };
-
-                        if start == 0 {
-                            return vec![];
-                        }
-
-                        let mut result = vec![];
-
-                        for i in (start .. stop+1).step_by(self.key_step) {
-                            result.push(&args[i as usize]);
-                        }
-
-                        result
-                    }
-
-                    fn check_number_args(&self, n: usize) -> bool {
-                        if ($min_args >= 0) {
-                            n == ($min_args as i32).try_into().unwrap_or(0)
-                        } else {
-                            let s: usize = ($min_args as i32).abs().try_into().unwrap_or(0);
-                            n >= s
-                        }
-                    }
-
-                    fn group(&self) -> &'static str {
-                        stringify!($ns)
-                    }
-
-                    fn name(&self) -> &'static str {
-                        stringify!($command)
-                    }
-                }
-            }
-        )+)+
-
-        use async_trait::async_trait;
-        use metered::{Throughput, HitCount, ErrorCount, InFlight, ResponseTime};
-
-        /// Executable command trait
-        #[async_trait]
-        pub trait ExecutableCommand {
-            /// Call the command handler
-            async fn execute(&self, conn: &Connection, args: &[Bytes]) -> Result<Value, Error>;
-
-            /// Returns a reference to the metrics
-            fn metrics(&self) -> &Metrics;
-
-            /// Can this command be queued in a transaction or should it be executed right away?
-            fn is_queueable(&self) -> bool;
-
-            /// Can this command be executed in a pub-sub only mode?
-            fn is_pubsub_executable(&self) -> bool;
-
-            /// Returns all database keys from the command arguments
-            fn get_keys<'a>(&self, args: &'a [Bytes]) -> Vec<&'a Bytes>;
-
-            /// Checks if a given number of args is expected by this command
-            fn check_number_args(&self, n: usize) -> bool;
-
-            /// Command group
-            fn group(&self) -> &'static str;
-
-            /// Command name
-            fn name(&self) -> &'static str;
-        }
-
-        /// Metric struct for all command
-        #[derive(Debug, Default, serde::Serialize)]
-        pub struct Metrics {
-            hit_count: HitCount,
-            error_count: ErrorCount,
-            in_flight: InFlight,
-            response_time: ResponseTime,
-            throughput: Throughput,
-        }
+        use futures::future::FutureExt;
 
         /// Metrics for all defined commands
         #[derive(serde::Serialize)]
         pub struct ServiceMetricRegistry<'a> {
             $($(
-            $command: &'a Metrics,
+            $command: &'a command::Metrics,
             )+)+
         }
 
@@ -200,7 +43,7 @@ macro_rules! dispatcher {
         #[derive(Debug)]
         pub struct Dispatcher {
             $($(
-                $command: $command::Command,
+                $command: command::Command,
             )+)+
         }
 
@@ -209,7 +52,16 @@ macro_rules! dispatcher {
             pub fn new() -> Self {
                 Self {
                     $($(
-                        $command: $command::Command::new(),
+                        $command: command::Command::new(
+                            stringify!($command),
+                            stringify!($ns),
+                            &[$($tag,)+],
+                            $min_args,
+                            $key_start,
+                            $key_stop,
+                            $key_step,
+                            $is_queueable,
+                        ),
                     )+)+
                 }
             }
@@ -224,7 +76,7 @@ macro_rules! dispatcher {
             }
 
             /// Returns the handlers for defined commands.
-            pub fn get_all_commands(&self) -> Vec<&(dyn ExecutableCommand + Send + Sync + 'static)> {
+            pub fn get_all_commands(&self) -> Vec<&command::Command> {
                 vec![
                 $($(
                     &self.$command,
@@ -233,8 +85,9 @@ macro_rules! dispatcher {
             }
 
             /// Returns a command handler for a given command
-            pub fn get_handler_for_command(&self, command: &str) -> Result<&(dyn ExecutableCommand + Send + Sync + 'static), Error> {
-                match command {
+            #[inline(always)]
+            pub fn get_handler_for_command(&self, command: &str) -> Result<&command::Command, Error> {
+                match command.to_lowercase().as_str() {
                 $($(
                     stringify!($command) => Ok(&self.$command),
                 )+)+
@@ -247,7 +100,8 @@ macro_rules! dispatcher {
             /// Before returning the command handler this function will make sure the minimum
             /// required arguments are provided. This pre-validation ensures each command handler
             /// has fewer logic when reading the provided arguments.
-            pub fn get_handler(&self, args: &[Bytes]) -> Result<&(dyn ExecutableCommand + Send + Sync + 'static), Error> {
+            #[inline(always)]
+            pub fn get_handler(&self, args: &[Bytes]) -> Result<&command::Command, Error> {
                 let command = String::from_utf8_lossy(&args[0]).to_lowercase();
                 let command = self.get_handler_for_command(&command)?;
                 if ! command.check_number_args(args.len()) {
@@ -256,6 +110,55 @@ macro_rules! dispatcher {
                     Ok(command)
                 }
             }
+
+            /// Returns the command handler
+            ///
+            /// Before returning the command handler this function will make sure the minimum
+            /// required arguments are provided. This pre-validation ensures each command handler
+            /// has fewer logic when reading the provided arguments.
+            #[inline(always)]
+            pub fn execute<'a>(&'a self, conn: &'a Connection, args: &'a [Bytes]) -> futures::future::BoxFuture<'a, Result<Value, Error>> {
+                async move {
+                    let command = String::from_utf8_lossy(&args[0]);
+                    match command.to_lowercase().as_str() {
+                        $($(
+                            stringify!($command) => {
+                                let command = &self.$command;
+                                if ! command.check_number_args(args.len()) {
+                                    Err(Error::InvalidArgsCount(command.name().into()))
+                                } else {
+                                    let metrics = command.metrics();
+                                    let hit_count = &metrics.hit_count;
+                                    let error_count = &metrics.error_count;
+                                    let in_flight = &metrics.in_flight;
+                                    let response_time = &metrics.response_time;
+                                    let throughput = &metrics.throughput;
+
+                                    let status = conn.status();
+                                    if status == ConnectionStatus::Multi && command.is_queueable() {
+                                        conn.queue_command(args);
+                                        conn.tx_keys(command.get_keys(args));
+                                        return Ok(Value::Queued);
+                                    } else if status == ConnectionStatus::Pubsub && ! command.is_pubsub_executable() {
+                                        return Err(Error::PubsubOnly(stringify!($command).to_owned()));
+                                    }
+
+                                    metered::measure!(hit_count, {
+                                        metered::measure!(response_time, {
+                                            metered::measure!(throughput, {
+                                                metered::measure!(in_flight, {
+                                                    metered::measure!(error_count, $handler(conn, args).await)
+                                                })
+                                            })
+                                        })
+                                    })
+                                }
+                            }
+                        )+)+,
+                        _ => Err(Error::CommandNotFound(command.into())),
+                    }
+                }.boxed()
+            }
         }
     }
 }

+ 19 - 31
src/server.rs

@@ -145,38 +145,26 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
                                 }
                             }
                             result = transport.next() => match result {
-                            Some(Ok(args)) => match all_connections.get_dispatcher().get_handler(&args) {
-                                Ok(handler) => {
-                                    match handler
-                                        .execute(&conn, &args)
-                                        .await {
-                                            Ok(result) => {
-                                                if conn.status() == ConnectionStatus::Pubsub {
-                                                    continue;
-                                                }
-                                                if transport.send(result).await.is_err() {
-                                                    break;
-                                                }
-                                            },
-                                            Err(err) => {
-                                                if transport.send(err.into()).await.is_err() {
-                                                    break;
-                                                }
-                                            }
-                                        };
-
-                                },
-                                Err(err) => {
-                                    if transport.send(err.into()).await.is_err() {
-                                        break;
+                                Some(Ok(args)) => match all_connections.get_dispatcher().execute(&conn, &args).await {
+                                    Ok(result) => {
+                                        if conn.status() == ConnectionStatus::Pubsub {
+                                            continue;
+                                        }
+                                        if transport.send(result).await.is_err() {
+                                            break;
+                                        }
+                                    },
+                                    Err(err) => {
+                                        if transport.send(err.into()).await.is_err() {
+                                            break;
+                                        }
                                     }
-                                }
-                            },
-                            Some(Err(e)) => {
-                                warn!("error on decoding from socket; error = {:?}", e);
-                                break;
-                            },
-                            None => break,
+                                },
+                                Some(Err(e)) => {
+                                    warn!("error on decoding from socket; error = {:?}", e);
+                                    break;
+                                },
+                                None => break,
                             }
                         }
                     }