Переглянути джерело

Pubsub (#11)

* Prototype: Added pub-sub

The entire pub-sub is quite simple and rely heavely on Tokio's mpsc
channels.

The logic to subscribe to patterns is not yet done, but it will be added
in the future with heavy caching to make sure it is as efficient as it
can possibly be.

The main evented-loop will read off the client and from the internal
pub-sub channels to send a response.

* Add ConnectionStatus

This new enum will switch the different possible status for a client
connection (normal, pubsubm, multi or executing tx).

By having an enum instead of different booleans will make checking the
status much easier and will remove a bunch of methods

* WIP: Adding support for pattern subscription

Started support for pattern subscriptions

* Added command to list channels

* Fixed warnings

* Adding pubsub unit tests

* Work on unsubscribe features

* Splitting connection file into separated files

The connection file has the logic for many things, it should splitten
into separated files inside the same folder.

* Automatically unsubscribe on disconnect

* Working on punsubscribe and reset
César D. Rodas 3 роки тому
батько
коміт
e82b29149f

+ 1 - 0
Cargo.toml

@@ -17,6 +17,7 @@ futures = { version = "0.3.0", features = ["thread-pool"]}
 tokio-stream="0.1"
 seahash = "4"
 log="0.4"
+glob="^0.2"
 env_logger = "0.8.4"
 bytes = "1"
 rand = "0.8.0"

+ 5 - 0
src/cmd/client.rs

@@ -50,3 +50,8 @@ pub async fn ping(_conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
         _ => Err(Error::InvalidArgsCount("ping".to_owned())),
     }
 }
+
+pub async fn reset(conn: &Connection, _: &[Bytes]) -> Result<Value, Error> {
+    conn.reset();
+    Ok(Value::String("RESET".to_owned()))
+}

+ 3 - 3
src/cmd/hash.rs

@@ -382,7 +382,7 @@ mod test {
                         || x[0] == Value::Blob("f3".into())
                 )
             }
-            _ => assert!(false),
+            _ => unreachable!(),
         };
     }
 
@@ -397,9 +397,9 @@ mod test {
         match r {
             Ok(Value::Blob(x)) => {
                 let x = String::from_utf8_lossy(&x);
-                assert!(x == "f1".to_owned() || x == "f2".to_owned() || x == "f3".to_owned());
+                assert!(x == *"f1" || x == *"f2" || x == *"f3");
             }
-            _ => assert!(false),
+            _ => unreachable!(),
         };
     }
 

+ 7 - 3
src/cmd/list.rs

@@ -1,5 +1,9 @@
 use crate::{
-    check_arg, connection::Connection, error::Error, value::bytes_to_number, value::checksum,
+    check_arg,
+    connection::{Connection, ConnectionStatus},
+    error::Error,
+    value::bytes_to_number,
+    value::checksum,
     value::Value,
 };
 use bytes::Bytes;
@@ -65,7 +69,7 @@ pub async fn blpop(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
             };
         }
 
-        if Instant::now() >= timeout || conn.is_executing_transaction() {
+        if Instant::now() >= timeout || conn.status() == ConnectionStatus::ExecutingTx {
             break;
         }
 
@@ -87,7 +91,7 @@ pub async fn brpop(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
             };
         }
 
-        if Instant::now() >= timeout || conn.is_executing_transaction() {
+        if Instant::now() >= timeout || conn.status() == ConnectionStatus::ExecutingTx {
             break;
         }
 

+ 24 - 3
src/cmd/mod.rs

@@ -2,6 +2,7 @@ pub mod client;
 pub mod hash;
 pub mod key;
 pub mod list;
+pub mod pubsub;
 pub mod set;
 pub mod string;
 pub mod transaction;
@@ -9,7 +10,7 @@ pub mod transaction;
 #[cfg(test)]
 mod test {
     use crate::{
-        connection::{Connection, Connections},
+        connection::{connections::Connections, Connection},
         db::Db,
         dispatcher::Dispatcher,
         error::Error,
@@ -21,21 +22,41 @@ mod test {
         ops::Deref,
         sync::Arc,
     };
+    use tokio::sync::mpsc::UnboundedReceiver;
 
     pub fn create_connection() -> Arc<Connection> {
-        let all_connections = Arc::new(Connections::new());
         let db = Arc::new(Db::new(1000));
+        let all_connections = Arc::new(Connections::new(db.clone()));
+
+        let client = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
+
+        all_connections.new_connection(db.clone(), client).1
+    }
+
+    pub fn create_connection_and_pubsub() -> (UnboundedReceiver<Value>, Arc<Connection>) {
+        let db = Arc::new(Db::new(1000));
+        let all_connections = Arc::new(Connections::new(db.clone()));
 
         let client = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
 
         all_connections.new_connection(db.clone(), client)
     }
 
+    pub fn create_new_connection_from_connection(
+        conn: &Connection,
+    ) -> (UnboundedReceiver<Value>, Arc<Connection>) {
+        let all_connections = conn.all_connections();
+
+        let client = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
+
+        all_connections.new_connection(all_connections.db(), client)
+    }
+
     pub async fn run_command(conn: &Connection, cmd: &[&str]) -> Result<Value, Error> {
         let args: Vec<Bytes> = cmd.iter().map(|s| Bytes::from(s.to_string())).collect();
 
         let handler = Dispatcher::new(&args)?;
 
-        handler.deref().execute(&conn, &args).await
+        handler.deref().execute(conn, &args).await
     }
 }

+ 257 - 0
src/cmd/pubsub.rs

@@ -0,0 +1,257 @@
+use crate::{check_arg, connection::Connection, error::Error, value::Value};
+use bytes::Bytes;
+use glob::Pattern;
+
+pub async fn publish(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    Ok(conn.pubsub().publish(&args[1], &args[2]).await.into())
+}
+
+pub async fn pubsub(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    match String::from_utf8_lossy(&args[1]).to_lowercase().as_str() {
+        "channels" => Ok(Value::Array(
+            conn.pubsub()
+                .channels()
+                .iter()
+                .map(|v| Value::Blob(v.clone()))
+                .collect(),
+        )),
+        "help" => Ok(Value::Array(vec![
+            Value::String("PUBSUB <subcommand> arg arg ... arg. Subcommands are:".to_owned()),
+            Value::String("CHANNELS [<pattern>] -- Return the currently active channels matching a pattern (default: all).".to_owned()),
+            Value::String("NUMPAT -- Return number of subscriptions to patterns.".to_owned()),
+            Value::String("NUMSUB [channel-1 .. channel-N] -- Returns the number of subscribers for the specified channels (excluding patterns, default: none).".to_owned()),
+        ])),
+        "numpat" => Ok(conn.pubsub().get_number_of_psubscribers().into()),
+        "numsub" => Ok(conn
+            .pubsub()
+            .get_number_of_subscribers(&args[2..])
+            .iter()
+            .map(|(channel, subs)| vec![Value::Blob(channel.clone()), (*subs).into()])
+            .flatten()
+            .collect::<Vec<Value>>()
+            .into()),
+        cmd => Ok(Value::Err(
+            "ERR".to_owned(),
+            format!(
+                "Unknown subcommand or wrong number of arguments for '{}'. Try PUBSUB HELP.",
+                cmd
+            ),
+        )),
+    }
+}
+
+pub async fn subscribe(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    let pubsub = conn.pubsub();
+
+    let channels = &args[1..];
+
+    if check_arg!(args, 0, "PSUBSCRIBE") {
+        pubsub.psubscribe(channels, conn)?;
+    } else {
+        pubsub.subscribe(channels, conn);
+    }
+
+    conn.start_pubsub()
+}
+
+pub async fn punsubscribe(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    let channels = if args.len() == 1 {
+        conn.pubsub_client().psubscriptions()
+    } else {
+        (&args[1..])
+            .iter()
+            .map(|channel| {
+                let channel = String::from_utf8_lossy(channel);
+                Pattern::new(&channel).map_err(|_| Error::InvalidPattern(channel.to_string()))
+            })
+            .collect::<Result<Vec<Pattern>, Error>>()?
+    };
+
+    Ok(conn.pubsub_client().punsubscribe(&channels, conn).into())
+}
+
+pub async fn unsubscribe(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    let channels = if args.len() == 1 {
+        conn.pubsub_client().subscriptions()
+    } else {
+        (&args[1..]).to_vec()
+    };
+
+    Ok(conn.pubsub_client().unsubscribe(&channels, conn).into())
+}
+
+#[cfg(test)]
+mod test {
+    use crate::{
+        cmd::test::{
+            create_connection_and_pubsub, create_new_connection_from_connection, run_command,
+        },
+        value::Value,
+    };
+    use tokio::sync::mpsc::UnboundedReceiver;
+
+    async fn test_subscription_confirmation_and_first_message(
+        msg: &str,
+        channel: &str,
+        recv: &mut UnboundedReceiver<Value>,
+    ) {
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                channel.into(),
+                1.into()
+            ])),
+            recv.recv().await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                Value::Blob("message".into()),
+                channel.into(),
+                msg.into()
+            ])),
+            recv.recv().await
+        );
+    }
+
+    #[tokio::test]
+    async fn test_subscribe_multiple_channels() {
+        let (mut recv, c1) = create_connection_and_pubsub();
+
+        assert_eq!(
+            Ok(Value::Ok),
+            run_command(&c1, &["subscribe", "foo", "bar"]).await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                "foo".into(),
+                1.into()
+            ])),
+            recv.recv().await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                "bar".into(),
+                2.into()
+            ])),
+            recv.recv().await
+        );
+    }
+
+    #[tokio::test]
+    async fn test_subscribe_multiple_channels_one_by_one() {
+        let (mut recv, c1) = create_connection_and_pubsub();
+
+        assert_eq!(Ok(Value::Ok), run_command(&c1, &["subscribe", "foo"]).await);
+
+        assert_eq!(Ok(Value::Ok), run_command(&c1, &["subscribe", "bar"]).await);
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                "foo".into(),
+                1.into()
+            ])),
+            recv.recv().await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                "bar".into(),
+                2.into()
+            ])),
+            recv.recv().await
+        );
+    }
+
+    #[tokio::test]
+    async fn test_unsubscribe_with_args() {
+        let (mut recv, c1) = create_connection_and_pubsub();
+
+        assert_eq!(
+            Ok(Value::Ok),
+            run_command(&c1, &["subscribe", "foo", "bar"]).await
+        );
+
+        assert_eq!(
+            Ok(Value::Integer(2)),
+            run_command(&c1, &["unsubscribe", "foo", "bar"]).await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                "foo".into(),
+                1.into()
+            ])),
+            recv.recv().await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "subscribe".into(),
+                "bar".into(),
+                2.into()
+            ])),
+            recv.recv().await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "unsubscribe".into(),
+                "foo".into(),
+                1.into()
+            ])),
+            recv.recv().await
+        );
+
+        assert_eq!(
+            Some(Value::Array(vec![
+                "unsubscribe".into(),
+                "bar".into(),
+                1.into()
+            ])),
+            recv.recv().await
+        );
+    }
+
+    #[tokio::test]
+    async fn pubsub_publish() {
+        let (mut sub1, c1) = create_connection_and_pubsub();
+        let (mut sub2, c2) = create_new_connection_from_connection(&c1);
+        let (_, c3) = create_new_connection_from_connection(&c1);
+
+        assert_eq!(Ok(Value::Ok), run_command(&c1, &["subscribe", "foo"]).await);
+        assert_eq!(Ok(Value::Ok), run_command(&c2, &["subscribe", "foo"]).await);
+
+        let msg = "foo - message";
+
+        let _ = run_command(&c3, &["publish", "foo", msg]).await;
+
+        test_subscription_confirmation_and_first_message(msg, "foo", &mut sub1).await;
+        test_subscription_confirmation_and_first_message(msg, "foo", &mut sub2).await;
+    }
+
+    #[tokio::test]
+    async fn pubsub_numpat() {
+        let (_, c1) = create_connection_and_pubsub();
+        let (_, c2) = create_new_connection_from_connection(&c1);
+
+        assert_eq!(
+            Ok(Value::Integer(0)),
+            run_command(&c1, &["pubsub", "numpat"]).await
+        );
+
+        let _ = run_command(&c2, &["psubscribe", "foo", "bar*", "xxx*"]).await;
+
+        assert_eq!(
+            Ok(Value::Integer(1)),
+            run_command(&c1, &["pubsub", "numpat"]).await
+        );
+    }
+}

+ 7 - 4
src/cmd/transaction.rs

@@ -1,4 +1,9 @@
-use crate::{connection::Connection, dispatcher::Dispatcher, error::Error, value::Value};
+use crate::{
+    connection::{Connection, ConnectionStatus},
+    dispatcher::Dispatcher,
+    error::Error,
+    value::Value,
+};
 use bytes::Bytes;
 
 pub async fn discard(conn: &Connection, _: &[Bytes]) -> Result<Value, Error> {
@@ -10,7 +15,7 @@ pub async fn multi(conn: &Connection, _: &[Bytes]) -> Result<Value, Error> {
 }
 
 pub async fn exec(conn: &Connection, _: &[Bytes]) -> Result<Value, Error> {
-    if !conn.in_transaction() {
+    if conn.status() != ConnectionStatus::Multi {
         return Err(Error::NotInTx);
     }
 
@@ -22,8 +27,6 @@ pub async fn exec(conn: &Connection, _: &[Bytes]) -> Result<Value, Error> {
     let db = conn.db();
     let locked_keys = conn.get_tx_keys();
 
-    conn.start_executing_transaction();
-
     db.lock_keys(&locked_keys);
 
     let mut results = vec![];

+ 69 - 0
src/connection/connections.rs

@@ -0,0 +1,69 @@
+use super::{pubsub_connection::PubsubClient, pubsub_server::Pubsub, Connection, ConnectionInfo};
+use crate::{db::Db, value::Value};
+use parking_lot::RwLock;
+use std::{collections::BTreeMap, net::SocketAddr, sync::Arc};
+
+use tokio::sync::mpsc;
+
+#[derive(Debug)]
+pub struct Connections {
+    connections: RwLock<BTreeMap<u128, Arc<Connection>>>,
+    db: Arc<Db>,
+    pubsub: Arc<Pubsub>,
+    counter: RwLock<u128>,
+}
+
+impl Connections {
+    pub fn new(db: Arc<Db>) -> Self {
+        Self {
+            counter: RwLock::new(0),
+            db,
+            pubsub: Arc::new(Pubsub::new()),
+            connections: RwLock::new(BTreeMap::new()),
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn db(&self) -> Arc<Db> {
+        self.db.clone()
+    }
+
+    pub fn pubsub(&self) -> Arc<Pubsub> {
+        self.pubsub.clone()
+    }
+
+    pub fn remove(self: Arc<Connections>, conn: Arc<Connection>) {
+        let id = conn.id();
+        self.connections.write().remove(&id);
+    }
+
+    pub fn new_connection(
+        self: &Arc<Connections>,
+        db: Arc<Db>,
+        addr: SocketAddr,
+    ) -> (mpsc::UnboundedReceiver<Value>, Arc<Connection>) {
+        let mut id = self.counter.write();
+        *id += 1;
+
+        let (pubsub_sender, pubsub_receiver) = mpsc::unbounded_channel();
+
+        let conn = Arc::new(Connection {
+            id: *id,
+            db: db.new_db_instance(*id),
+            addr,
+            all_connections: self.clone(),
+            current_db: 0,
+            info: RwLock::new(ConnectionInfo::new()),
+            pubsub_client: PubsubClient::new(pubsub_sender),
+        });
+
+        self.connections.write().insert(*id, conn.clone());
+        (pubsub_receiver, conn)
+    }
+
+    pub fn iter(&self, f: &mut dyn FnMut(Arc<Connection>)) {
+        for (_, value) in self.connections.read().iter() {
+            f(value.clone())
+        }
+    }
+}

+ 64 - 77
src/connection.rs → src/connection/mod.rs

@@ -1,16 +1,26 @@
 use crate::{db::Db, error::Error, value::Value};
 use bytes::Bytes;
 use parking_lot::RwLock;
-use std::{
-    collections::{BTreeMap, HashSet},
-    net::SocketAddr,
-    sync::Arc,
-};
+use std::{collections::HashSet, net::SocketAddr, sync::Arc};
 
-#[derive(Debug)]
-pub struct Connections {
-    connections: RwLock<BTreeMap<u128, Arc<Connection>>>,
-    counter: RwLock<u128>,
+use self::pubsub_server::Pubsub;
+
+pub mod connections;
+pub mod pubsub_connection;
+pub mod pubsub_server;
+
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+pub enum ConnectionStatus {
+    Multi,
+    ExecutingTx,
+    Pubsub,
+    Normal,
+}
+
+impl Default for ConnectionStatus {
+    fn default() -> Self {
+        ConnectionStatus::Normal
+    }
 }
 
 #[derive(Debug)]
@@ -18,8 +28,7 @@ pub struct ConnectionInfo {
     pub name: Option<String>,
     pub watch_keys: Vec<(Bytes, u128)>,
     pub tx_keys: HashSet<Bytes>,
-    pub in_transaction: bool,
-    pub in_executing_transaction: bool,
+    pub status: ConnectionStatus,
     pub commands: Option<Vec<Vec<Bytes>>>,
 }
 
@@ -28,50 +37,10 @@ pub struct Connection {
     id: u128,
     db: Db,
     current_db: u32,
-    connections: Arc<Connections>,
+    all_connections: Arc<connections::Connections>,
     addr: SocketAddr,
     info: RwLock<ConnectionInfo>,
-}
-
-impl Connections {
-    pub fn new() -> Self {
-        Self {
-            counter: RwLock::new(0),
-            connections: RwLock::new(BTreeMap::new()),
-        }
-    }
-
-    pub fn remove(self: Arc<Connections>, conn: Arc<Connection>) {
-        let id = conn.id();
-        self.connections.write().remove(&id);
-    }
-
-    pub fn new_connection(
-        self: &Arc<Connections>,
-        db: Arc<Db>,
-        addr: SocketAddr,
-    ) -> Arc<Connection> {
-        let mut id = self.counter.write();
-        *id += 1;
-
-        let conn = Arc::new(Connection {
-            id: *id,
-            db: db.new_db_instance(*id),
-            addr,
-            connections: self.clone(),
-            current_db: 0,
-            info: RwLock::new(ConnectionInfo::new()),
-        });
-
-        self.connections.write().insert(*id, conn.clone());
-        conn
-    }
-
-    pub fn iter(&self, f: &mut dyn FnMut(Arc<Connection>)) {
-        for (_, value) in self.connections.read().iter() {
-            f(value.clone())
-        }
-    }
+    pubsub_client: pubsub_connection::PubsubClient,
 }
 
 impl ConnectionInfo {
@@ -81,8 +50,7 @@ impl ConnectionInfo {
             watch_keys: vec![],
             tx_keys: HashSet::new(),
             commands: None,
-            in_transaction: false,
-            in_executing_transaction: false,
+            status: ConnectionStatus::Normal,
         }
     }
 }
@@ -92,18 +60,25 @@ impl Connection {
         &self.db
     }
 
+    pub fn pubsub(&self) -> Arc<Pubsub> {
+        self.all_connections.pubsub()
+    }
+
+    pub fn pubsub_client(&self) -> &pubsub_connection::PubsubClient {
+        &self.pubsub_client
+    }
+
     pub fn id(&self) -> u128 {
         self.id
     }
 
     pub fn stop_transaction(&self) -> Result<Value, Error> {
         let info = &mut self.info.write();
-        if info.in_transaction {
+        if info.status == ConnectionStatus::Multi {
             info.commands = None;
             info.watch_keys.clear();
             info.tx_keys.clear();
-            info.in_transaction = false;
-            info.in_executing_transaction = true;
+            info.status = ConnectionStatus::ExecutingTx;
 
             Ok(Value::Ok)
         } else {
@@ -113,31 +88,40 @@ impl Connection {
 
     pub fn start_transaction(&self) -> Result<Value, Error> {
         let mut info = self.info.write();
-        if !info.in_transaction {
-            info.in_transaction = true;
+        if info.status == ConnectionStatus::Normal {
+            info.status = ConnectionStatus::Multi;
             Ok(Value::Ok)
         } else {
             Err(Error::NestedTx)
         }
     }
 
-    /// We are inside a MULTI, most transactions are rather queued for later
-    /// execution instead of being executed right away.
-    pub fn in_transaction(&self) -> bool {
-        self.info.read().in_transaction
+    pub fn start_pubsub(&self) -> Result<Value, Error> {
+        let mut info = self.info.write();
+        match info.status {
+            ConnectionStatus::Normal | ConnectionStatus::Pubsub => {
+                info.status = ConnectionStatus::Pubsub;
+                Ok(Value::Ok)
+            }
+            _ => Err(Error::NestedTx),
+        }
     }
 
-    /// The commands are being executed inside a transaction (by EXEC). It is
-    /// important to keep track of this because some commands change their
-    /// behaviour.
-    pub fn is_executing_transaction(&self) -> bool {
-        self.info.read().in_executing_transaction
+    pub fn reset(&self) {
+        let mut info = self.info.write();
+        info.status = ConnectionStatus::Normal;
+        info.name = None;
+        info.watch_keys = vec![];
+        info.commands = None;
+        info.tx_keys = HashSet::new();
+
+        let pubsub = self.pubsub();
+        pubsub.unsubscribe(&self.pubsub_client.subscriptions(), self);
+        pubsub.punsubscribe(&self.pubsub_client.psubscriptions(), self);
     }
 
-    /// EXEC has been called and we need to keep track
-    pub fn start_executing_transaction(&self) {
-        let info = &mut self.info.write();
-        info.in_executing_transaction = true;
+    pub fn status(&self) -> ConnectionStatus {
+        self.info.read().status
     }
 
     pub fn watch_key(&self, keys: &[(&Bytes, u128)]) {
@@ -184,7 +168,7 @@ impl Connection {
     pub fn get_queue_commands(&self) -> Option<Vec<Vec<Bytes>>> {
         let info = &mut self.info.write();
         info.watch_keys = vec![];
-        info.in_transaction = false;
+        info.status = ConnectionStatus::ExecutingTx;
         info.commands.take()
     }
 
@@ -199,11 +183,14 @@ impl Connection {
     }
 
     pub fn destroy(self: Arc<Connection>) {
-        self.connections.clone().remove(self);
+        let pubsub = self.pubsub();
+        pubsub.unsubscribe(&self.pubsub_client.subscriptions(), &self);
+        pubsub.punsubscribe(&self.pubsub_client.psubscriptions(), &self);
+        self.all_connections.clone().remove(self);
     }
 
-    pub fn all_connections(&self) -> Arc<Connections> {
-        self.connections.clone()
+    pub fn all_connections(&self) -> Arc<connections::Connections> {
+        self.all_connections.clone()
     }
 
     pub fn name(&self) -> Option<String> {

+ 105 - 0
src/connection/pubsub_connection.rs

@@ -0,0 +1,105 @@
+use super::Connection;
+use crate::value::Value;
+use bytes::Bytes;
+use glob::Pattern;
+use parking_lot::RwLock;
+use std::collections::HashMap;
+use tokio::sync::mpsc;
+
+#[derive(Debug)]
+pub struct PubsubClient {
+    meta: RwLock<MetaData>,
+    sender: mpsc::UnboundedSender<Value>,
+}
+
+#[derive(Debug)]
+struct MetaData {
+    subscriptions: HashMap<Bytes, bool>,
+    psubscriptions: HashMap<Pattern, bool>,
+    is_psubcribed: bool,
+    id: usize,
+}
+
+impl PubsubClient {
+    pub fn new(sender: mpsc::UnboundedSender<Value>) -> Self {
+        Self {
+            meta: RwLock::new(MetaData {
+                subscriptions: HashMap::new(),
+                psubscriptions: HashMap::new(),
+                is_psubcribed: false,
+                id: 0,
+            }),
+            sender,
+        }
+    }
+
+    pub fn punsubscribe(&self, channels: &[Pattern], conn: &Connection) -> u32 {
+        let mut meta = self.meta.write();
+        channels
+            .iter()
+            .map(|channel| meta.psubscriptions.remove(channel))
+            .for_each(drop);
+        if meta.psubscriptions.len() + meta.subscriptions.len() == 0 {
+            drop(meta);
+            conn.reset();
+        }
+        conn.pubsub().punsubscribe(channels, conn)
+    }
+
+    pub fn unsubscribe(&self, channels: &[Bytes], conn: &Connection) -> u32 {
+        let mut meta = self.meta.write();
+        channels
+            .iter()
+            .map(|channel| meta.subscriptions.remove(channel))
+            .for_each(drop);
+        if meta.psubscriptions.len() + meta.subscriptions.len() == 0 {
+            drop(meta);
+            conn.reset();
+        }
+        conn.pubsub().unsubscribe(channels, conn)
+    }
+
+    pub fn subscriptions(&self) -> Vec<Bytes> {
+        self.meta
+            .read()
+            .subscriptions
+            .keys()
+            .cloned()
+            .collect::<Vec<Bytes>>()
+    }
+
+    pub fn psubscriptions(&self) -> Vec<Pattern> {
+        self.meta
+            .read()
+            .psubscriptions
+            .keys()
+            .cloned()
+            .collect::<Vec<Pattern>>()
+    }
+
+    pub fn new_subscription(&self, channel: &Bytes) -> usize {
+        let mut meta = self.meta.write();
+        meta.subscriptions.insert(channel.clone(), true);
+        meta.id += 1;
+        meta.id
+    }
+
+    pub fn new_psubscription(&self, channel: &Pattern) -> usize {
+        let mut meta = self.meta.write();
+        meta.psubscriptions.insert(channel.clone(), true);
+        meta.id += 1;
+        meta.id
+    }
+
+    pub fn is_psubcribed(&self) -> bool {
+        self.meta.read().is_psubcribed
+    }
+
+    pub fn make_psubcribed(&self) {
+        self.meta.write().is_psubcribed = true;
+    }
+
+    pub fn sender(&self) -> mpsc::UnboundedSender<Value> {
+        self.sender.clone()
+    }
+}

+ 195 - 0
src/connection/pubsub_server.rs

@@ -0,0 +1,195 @@
+use crate::{connection::Connection, error::Error, value::Value};
+use bytes::Bytes;
+use glob::Pattern;
+use parking_lot::RwLock;
+use std::collections::HashMap;
+use tokio::sync::mpsc;
+
+type Sender = mpsc::UnboundedSender<Value>;
+type Subscription = HashMap<u128, Sender>;
+
+#[derive(Debug)]
+pub struct Pubsub {
+    subscriptions: RwLock<HashMap<Bytes, Subscription>>,
+    psubscriptions: RwLock<HashMap<Pattern, Subscription>>,
+    number_of_psubscriptions: RwLock<i64>,
+}
+
+impl Pubsub {
+    pub fn new() -> Self {
+        Self {
+            subscriptions: RwLock::new(HashMap::new()),
+            psubscriptions: RwLock::new(HashMap::new()),
+            number_of_psubscriptions: RwLock::new(0),
+        }
+    }
+
+    pub fn channels(&self) -> Vec<Bytes> {
+        self.subscriptions.read().keys().cloned().collect()
+    }
+
+    pub fn get_number_of_psubscribers(&self) -> i64 {
+        *(self.number_of_psubscriptions.read())
+    }
+
+    pub fn get_number_of_subscribers(&self, channels: &[Bytes]) -> Vec<(Bytes, usize)> {
+        let subscribers = self.subscriptions.read();
+        let mut ret = vec![];
+        for channel in channels.iter() {
+            if let Some(subs) = subscribers.get(channel) {
+                ret.push((channel.clone(), subs.len()));
+            } else {
+                ret.push((channel.clone(), 0));
+            }
+        }
+
+        ret
+    }
+
+    pub fn psubscribe(&self, channels: &[Bytes], conn: &Connection) -> Result<(), Error> {
+        let mut subscriptions = self.psubscriptions.write();
+
+        for bytes_channel in channels.iter() {
+            let channel = String::from_utf8_lossy(bytes_channel);
+            let channel =
+                Pattern::new(&channel).map_err(|_| Error::InvalidPattern(channel.to_string()))?;
+
+            if let Some(subs) = subscriptions.get_mut(&channel) {
+                subs.insert(conn.id(), conn.pubsub_client().sender());
+            } else {
+                let mut h = HashMap::new();
+                h.insert(conn.id(), conn.pubsub_client().sender());
+                subscriptions.insert(channel.clone(), h);
+            }
+            if !conn.pubsub_client().is_psubcribed() {
+                let mut psubs = self.number_of_psubscriptions.write();
+                conn.pubsub_client().make_psubcribed();
+                *psubs += 1;
+            }
+
+            let _ = conn.pubsub_client().sender().send(
+                vec![
+                    "psubscribe".into(),
+                    Value::Blob(bytes_channel.clone()),
+                    conn.pubsub_client().new_psubscription(&channel).into(),
+                ]
+                .into(),
+            );
+        }
+
+        Ok(())
+    }
+
+    pub async fn publish(&self, channel: &Bytes, message: &Bytes) -> u32 {
+        let mut i = 0;
+
+        if let Some(subs) = self.subscriptions.read().get(channel) {
+            for sender in subs.values() {
+                let _ = sender.send(Value::Array(vec![
+                    "message".into(),
+                    Value::Blob(channel.clone()),
+                    Value::Blob(message.clone()),
+                ]));
+                i += 1;
+            }
+        }
+
+        let str_channel = String::from_utf8_lossy(channel);
+
+        for (pattern, subs) in self.psubscriptions.read().iter() {
+            if !pattern.matches(&str_channel) {
+                continue;
+            }
+
+            for sub in subs.values() {
+                let _ = sub.send(Value::Array(vec![
+                    "pmessage".into(),
+                    pattern.as_str().into(),
+                    Value::Blob(channel.clone()),
+                    Value::Blob(message.clone()),
+                ]));
+                i += 1;
+            }
+        }
+
+        i
+    }
+
+    pub fn punsubscribe(&self, channels: &[Pattern], conn: &Connection) -> u32 {
+        let mut all_subs = self.psubscriptions.write();
+        let conn_id = conn.id();
+        let mut removed = 0;
+        channels
+            .iter()
+            .map(|channel| {
+                if let Some(subs) = all_subs.get_mut(channel) {
+                    if let Some(sender) = subs.remove(&conn_id) {
+                        let _ = sender.send(Value::Array(vec![
+                            "punsubscribe".into(),
+                            channel.as_str().into(),
+                            1.into(),
+                        ]));
+                        removed += 1;
+                    }
+                    if subs.is_empty() {
+                        all_subs.remove(channel);
+                    }
+                }
+            })
+            .for_each(drop);
+
+        removed
+    }
+
+    pub fn subscribe(&self, channels: &[Bytes], conn: &Connection) {
+        let mut subscriptions = self.subscriptions.write();
+
+        channels
+            .iter()
+            .map(|channel| {
+                if let Some(subs) = subscriptions.get_mut(channel) {
+                    subs.insert(conn.id(), conn.pubsub_client().sender());
+                } else {
+                    let mut h = HashMap::new();
+                    h.insert(conn.id(), conn.pubsub_client().sender());
+                    subscriptions.insert(channel.clone(), h);
+                }
+
+                let _ = conn.pubsub_client().sender().send(
+                    vec![
+                        "subscribe".into(),
+                        Value::Blob(channel.clone()),
+                        conn.pubsub_client().new_subscription(channel).into(),
+                    ]
+                    .into(),
+                );
+            })
+            .for_each(drop);
+    }
+
+    pub fn unsubscribe(&self, channels: &[Bytes], conn: &Connection) -> u32 {
+        let mut all_subs = self.subscriptions.write();
+        let conn_id = conn.id();
+        let mut removed = 0;
+        channels
+            .iter()
+            .map(|channel| {
+                if let Some(subs) = all_subs.get_mut(channel) {
+                    if let Some(sender) = subs.remove(&conn_id) {
+                        let _ = sender.send(Value::Array(vec![
+                            "unsubscribe".into(),
+                            Value::Blob(channel.clone()),
+                            1.into(),
+                        ]));
+                        removed += 1;
+                    }
+                    if subs.is_empty() {
+                        all_subs.remove(channel);
+                    }
+                }
+            })
+            .for_each(drop);
+
+        removed
+    }
+}

+ 8 - 8
src/db/mod.rs

@@ -35,16 +35,16 @@ pub struct Db {
     /// Number of HashMaps that are available.
     slots: usize,
 
-    // A Database is attached to a conn_id. The entries and expiration data
-    // structures are shared between all connections.
-    //
-    // This particular database instace is attached to a conn_id, used to block
-    // all keys in case of a transaction.
+    /// A Database is attached to a conn_id. The entries and expiration data
+    /// structures are shared between all connections.
+    ///
+    /// This particular database instace is attached to a conn_id, used to block
+    /// all keys in case of a transaction.
     conn_id: u128,
 
-    // HashMap of all blocked keys by other connections. If a key appears in
-    // here and it is not being hold by the current connection, current
-    // connection must wait.
+    /// HashMap of all blocked keys by other connections. If a key appears in
+    /// here and it is not being hold by the current connection, current
+    /// connection must wait.
     tx_key_locks: Arc<RwLock<HashMap<Bytes, u128>>>,
 }
 

+ 72 - 1
src/dispatcher.rs

@@ -1,4 +1,10 @@
-use crate::{cmd, connection::Connection, dispatcher, error::Error, value::Value};
+use crate::{
+    cmd,
+    connection::{Connection, ConnectionStatus},
+    dispatcher,
+    error::Error,
+    value::Value,
+};
 use bytes::Bytes;
 use std::convert::TryInto;
 use std::time::SystemTime;
@@ -709,6 +715,15 @@ dispatcher! {
             0,
             true,
         },
+        reset {
+            cmd::client::reset,
+            [""],
+            1,
+            0,
+            0,
+            0,
+            false,
+        },
     },
     transaction {
         discard {
@@ -757,6 +772,62 @@ dispatcher! {
             true,
         },
     },
+    pubsub {
+        publish {
+            cmd::pubsub::publish,
+            [""],
+            3,
+            0,
+            0,
+            0,
+            true,
+        },
+        pubsub {
+            cmd::pubsub::pubsub,
+            [""],
+            -2,
+            0,
+            0,
+            0,
+            true,
+        },
+        psubscribe {
+            cmd::pubsub::subscribe,
+            [""],
+            -2,
+            0,
+            0,
+            0,
+            true,
+        },
+        punsubscribe {
+            cmd::pubsub::punsubscribe,
+            [""],
+            -1,
+            0,
+            0,
+            0,
+            true,
+        },
+        subscribe {
+            cmd::pubsub::subscribe,
+            [""],
+            -2,
+            0,
+            0,
+            0,
+            true,
+        },
+        unsubscribe {
+            cmd::pubsub::unsubscribe,
+            [""],
+            -1,
+            0,
+            0,
+            0,
+            true,
+        },
+    },
     server {
         time {
             do_time,

+ 4 - 0
src/error.rs

@@ -4,10 +4,12 @@ use crate::value::Value;
 pub enum Error {
     CommandNotFound(String),
     InvalidArgsCount(String),
+    InvalidPattern(String),
     Protocol(String, String),
     WrongArgument(String, String),
     NotFound,
     OutOfRange,
+    PubsubOnly(String),
     Syntax,
     NotANumber,
     NotInTx,
@@ -27,6 +29,7 @@ impl From<Error> for Value {
         let err_msg = match value {
             Error::CommandNotFound(x) => format!("unknown command `{}`", x),
             Error::InvalidArgsCount(x) => format!("wrong number of arguments for '{}' command", x),
+            Error::InvalidPattern(x) => format!("'{}' is not a valid pattern", x),
             Error::Protocol(x, y) => format!("Protocol error: expected '{}', got '{}'", x, y),
             Error::NotInTx => " without MULTI".to_owned(),
             Error::NotANumber => "value is not an integer or out of range".to_owned(),
@@ -34,6 +37,7 @@ impl From<Error> for Value {
             Error::Syntax => "syntax error".to_owned(),
             Error::NotFound => "no such key".to_owned(),
             Error::NestedTx => "calls can not be nested".to_owned(),
+            Error::PubsubOnly(x) => format!("Can't execute '{}': only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT / RESET are allowed in this context", x),
             Error::WrongArgument(x, y) => format!(
                 "Unknown subcommand or wrong number of arguments for '{}'. Try {} HELP.",
                 y, x

+ 10 - 1
src/macros.rs

@@ -42,15 +42,22 @@ macro_rules! dispatcher {
                 #[async_trait]
                 impl ExecutableCommand for Command {
                     async fn execute(&self, conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-                        if conn.in_transaction() && self.is_queueable() {
+                        let status = conn.status();
+                        if status == ConnectionStatus::Multi && self.is_queueable() {
                             conn.queue_command(args);
                             conn.tx_keys(self.get_keys(args));
                             Ok(Value::Queued)
+                        } else if status == ConnectionStatus::Pubsub && ! self.is_pubsub_executable() {
+                            Err(Error::PubsubOnly(stringify!($command).to_owned()))
                         } else {
                             $handler(conn, args).await
                         }
                     }
 
+                    fn is_pubsub_executable(&self) -> bool {
+                        stringify!($ns) == "pubsub" || stringify!($command) == "ping" || stringify!($command) == "reset"
+                    }
+
                     fn is_queueable(&self) -> bool {
                         $queueable
                     }
@@ -105,6 +112,8 @@ macro_rules! dispatcher {
 
             fn is_queueable(&self) -> bool;
 
+            fn is_pubsub_executable(&self) -> bool;
+
             fn get_keys<'a>(&self, args: &'a [Bytes]) -> Vec<&'a Bytes>;
 
             fn check_number_args(&self, n: usize) -> bool;

+ 38 - 14
src/server.rs

@@ -1,4 +1,9 @@
-use crate::{connection::Connections, db::Db, dispatcher::Dispatcher, value::Value};
+use crate::{
+    connection::{connections::Connections, ConnectionStatus},
+    db::Db,
+    dispatcher::Dispatcher,
+    value::Value,
+};
 use bytes::{Buf, Bytes, BytesMut};
 use futures::SinkExt;
 use log::{info, trace, warn};
@@ -51,7 +56,7 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
     info!("Listening on: {}", addr);
 
     let db = Arc::new(Db::new(1000));
-    let all_connections = Arc::new(Connections::new());
+    let all_connections = Arc::new(Connections::new(db.clone()));
 
     let db_for_purging = db.clone();
     tokio::spawn(async move {
@@ -64,34 +69,53 @@ pub async fn serve(addr: String) -> Result<(), Box<dyn Error>> {
     loop {
         match listener.accept().await {
             Ok((socket, addr)) => {
-                let conn = all_connections.new_connection(db.clone(), addr);
+                let (mut pubsub, conn) = all_connections.new_connection(db.clone(), addr);
 
                 tokio::spawn(async move {
                     let mut transport = Framed::new(socket, RedisParser);
 
                     trace!("New connection {}", conn.id());
 
-                    while let Some(result) = transport.next().await {
-                        match result {
-                            Ok(args) => match Dispatcher::new(&args) {
+                    loop {
+                        tokio::select! {
+                            Some(msg) = pubsub.recv() => {
+                                if transport.send(msg).await.is_err() {
+                                    break;
+                                }
+                            }
+                            result = transport.next() => match result {
+                            Some(Ok(args)) => match Dispatcher::new(&args) {
                                 Ok(handler) => {
-                                    let r = handler
+                                    match handler
                                         .execute(&conn, &args)
-                                        .await
-                                        .unwrap_or_else(|x| x.into());
-                                    if transport.send(r).await.is_err() {
-                                        break;
-                                    }
-                                }
+                                        .await {
+                                            Ok(result) => {
+                                                if conn.status() == ConnectionStatus::Pubsub {
+                                                    continue;
+                                                }
+                                                if transport.send(result).await.is_err() {
+                                                    break;
+                                                }
+                                            },
+                                            Err(err) => {
+                                                if transport.send(err.into()).await.is_err() {
+                                                    break;
+                                                }
+                                            }
+                                        };
+
+                                },
                                 Err(err) => {
                                     if transport.send(err.into()).await.is_err() {
                                         break;
                                     }
                                 }
                             },
-                            Err(e) => {
+                            Some(Err(e)) => {
                                 warn!("error on decoding from socket; error = {:?}", e);
                                 break;
+                            },
+                            None => break,
                             }
                         }
                     }

+ 1 - 0
src/value/mod.rs

@@ -114,6 +114,7 @@ impl<'a> From<&ParsedValue<'a>> for Value {
 
 value_try_from!(f64, Value::Float);
 value_try_from!(i32, Value::Integer);
+value_try_from!(u32, Value::Integer);
 value_try_from!(i64, Value::Integer);
 value_try_from!(i128, Value::BigInteger);