Browse Source

Improve blocking tasks (#53)

Introduce a generic way to spawn blocking tasks, with an optional timeout
and a retry mechanism.

The spawned task subscribes to a set of keys and is awakened every time
the data version changes in any of the watched keys. The worker function
returns true or false; false means it wants to be rescheduled.

This mechanism improves CPU usage over the current approach, which is to
sleep and retry every few milliseconds, which is suboptimal to say the
least.
César D. Rodas 2 years ago
parent
commit
4a79d00983
5 changed files with 228 additions and 141 deletions
  1. 5 0
      src/cmd/client.rs
  2. 124 121
      src/cmd/list.rs
  3. 41 9
      src/connection/mod.rs
  4. 58 8
      src/db/mod.rs
  5. 0 3
      tests/unit/type/list.tcl

+ 5 - 0
src/cmd/client.rs

@@ -59,6 +59,11 @@ pub async fn client(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
             };
 
             Ok(if other_conn.unblock(reason) {
+                other_conn.append_response(if reason == UnblockReason::Error {
+                    Error::UnblockByError.into()
+                } else {
+                    Value::Null
+                });
                 1.into()
             } else {
                 0.into()

+ 124 - 121
src/cmd/list.rs

@@ -10,8 +10,12 @@ use crate::{
     value::Value,
 };
 use bytes::Bytes;
-use std::collections::VecDeque;
-use tokio::time::{sleep, Duration, Instant};
+use futures::{stream::FuturesUnordered, Future, StreamExt};
+use std::{collections::VecDeque, ops::Deref, sync::Arc};
+use tokio::{
+    sync::broadcast::{self, error::RecvError, Receiver},
+    time::{sleep, Duration, Instant},
+};
 
 #[allow(clippy::needless_range_loop)]
 /// Removes an element from a list
@@ -73,24 +77,85 @@ fn remove_element(
 }
 
 #[inline]
-/// Handles the timeout/sleep logic for all blocking commands.
-async fn handle_timeout(conn: &Connection, timeout: Option<Instant>) -> Result<bool, Error> {
+async fn wait_for_event(receiver: &mut Receiver<()>) -> () {
+    let _ = receiver.recv().await;
+    ()
+}
+
+#[inline]
+async fn schedule_blocking_task<F, T>(
+    conn: Arc<Connection>,
+    keys_to_watch: Vec<Bytes>,
+    worker: F,
+    args: Vec<Bytes>,
+    timeout: Option<Instant>,
+) where
+    F: Fn(Arc<Connection>, Vec<Bytes>, usize) -> T + Send + Sync + 'static,
+    T: Future<Output = Result<Value, Error>> + Send + Sync + 'static,
+{
+    let (mut timeout_sx, mut timeout_rx) = broadcast::channel::<()>(1);
+    conn.block();
+
     if let Some(timeout) = timeout {
-        if Instant::now() >= timeout {
-            conn.unblock(UnblockReason::Timeout);
-            return Ok(true);
-        }
+        // setup timeout triggering event
+        let conn_for_timeout = conn.clone();
+        let keys_to_watch_for_timeout = keys_to_watch.clone();
+        let block_id = conn.get_block_id();
+        tokio::spawn(async move {
+            sleep(timeout - Instant::now()).await;
+            if conn_for_timeout.get_block_id() != block_id {
+                // Timeout trigger event is no longer relevant
+                return;
+            }
+            conn_for_timeout.unblock(UnblockReason::Timeout);
+            conn_for_timeout.append_response(Value::Null);
+            // Notify timeout event to the worker thread
+            timeout_sx.send(());
+        });
     }
 
-    if let Some(reason) = conn.has_been_unblocked_externally() {
-        match reason {
-            UnblockReason::Error => Err(Error::UnblockByError),
-            _ => Ok(true),
+    tokio::spawn(async move {
+        let db = conn.db();
+
+        let mut changes_watchers = db.subscribe_to_key_changes(&keys_to_watch);
+        let mut externally_unblock_watcher = conn.get_unblocked_subscription();
+
+        let mut attempt = 1;
+
+        loop {
+            // Run task
+            match worker(conn.clone(), args.to_vec(), attempt).await {
+                Ok(Value::Ignore | Value::Null) => {}
+                Ok(result) => {
+                    conn.append_response(result);
+                    conn.unblock(UnblockReason::Finished);
+                }
+                Err(x) => {
+                    conn.append_response(x.into());
+                    conn.unblock(UnblockReason::Finished);
+                }
+            }
+
+            attempt += 1;
+
+            if !conn.is_blocked() {
+                break;
+            }
+
+            let mut futures = changes_watchers
+                .iter_mut()
+                .map(|c| wait_for_event(c))
+                .collect::<FuturesUnordered<_>>();
+
+            futures.push(wait_for_event(&mut timeout_rx));
+            if let Some(ref mut externally) = &mut externally_unblock_watcher {
+                futures.push(wait_for_event(externally));
+            }
+
+            // wait until a key changes or a timeout event occurs
+            let _ = futures.next().await;
         }
-    } else {
-        sleep(Duration::from_millis(5)).await;
-        Ok(false)
-    }
+    });
 }
 
 /// Parses timeout and returns an instant or none if it should wait forever.
@@ -119,56 +184,34 @@ fn parse_timeout(arg: &Bytes) -> Result<Option<Instant>, Error> {
 /// popped from the head of the first list that is non-empty, with the given keys being checked in
 /// the order that they are given.
 pub async fn blpop(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-    let blpop_task = |conn: &Connection, args: &[Bytes]| -> Result<Value, Error> {
+    let blpop_task = |conn: Arc<Connection>, args: Vec<Bytes>, attempt| async move {
         for key in (1..args.len() - 1) {
             let key = &args[key];
-            match remove_element(&conn, key, None, true)? {
-                Value::Null => (),
-                n => return Ok(vec![Value::new(&key), n].into()),
+            match remove_element(&conn, key, None, true) {
+                Ok(Value::Null) => (),
+                Ok(n) => return Ok(vec![Value::new(&key), n].into()),
+                Err(x) => {
+                    if attempt == 1 {
+                        return Err(x);
+                    }
+                }
             };
         }
         Ok(Value::Null)
     };
 
     if conn.is_executing_tx() {
-        return blpop_task(conn, args);
+        return blpop_task(conn.clone(), args.to_vec(), 1).await;
     }
 
     let timeout = parse_timeout(&args[args.len() - 1])?;
     let conn = conn.clone();
     let args = args.to_vec();
+    let keys_to_watch = (&args[1..args.len() - 1]).to_vec();
 
     conn.block();
 
-    tokio::spawn(async move {
-        loop {
-            match blpop_task(&conn, &args) {
-                Ok(Value::Null) => {}
-                Ok(x) => {
-                    conn.append_response(x);
-                    conn.unblock(UnblockReason::Finished);
-                    break;
-                }
-                Err(x) => {
-                    conn.append_response(x.into());
-                    conn.unblock(UnblockReason::Finished);
-                    break;
-                }
-            }
-
-            match handle_timeout(&conn, timeout).await {
-                Ok(true) => {
-                    conn.append_response(Value::Null);
-                    break;
-                }
-                Err(x) => {
-                    conn.append_response(x.into());
-                    break;
-                }
-                _ => {}
-            }
-        }
-    });
+    schedule_blocking_task(conn.clone(), keys_to_watch, blpop_task, args, timeout).await;
 
     Ok(Value::Ignore)
 }
@@ -190,38 +233,16 @@ pub async fn blmove(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     }
 
     let timeout = parse_timeout(&args[5])?;
-    conn.block();
-
-    let conn = conn.clone();
-    let args = args.to_vec();
-    tokio::spawn(async move {
-        loop {
-            match lmove(&conn, &args).await {
-                Ok(Value::Null) => (),
-                Ok(n) => {
-                    conn.append_response(n);
-                    conn.unblock(UnblockReason::Finished);
-                    break;
-                }
-                Err(x) => {
-                    conn.append_response(x.into());
-                    conn.unblock(UnblockReason::Finished);
-                    break;
-                }
-            };
-            match handle_timeout(&conn, timeout).await {
-                Ok(true) => {
-                    conn.append_response(Value::Null);
-                    break;
-                }
-                Err(x) => {
-                    conn.append_response(x.into());
-                    break;
-                }
-                _ => {}
-            }
-        }
-    });
+    let keys_to_watch = (&args[1..=2]).to_vec();
+
+    schedule_blocking_task(
+        conn.clone(),
+        keys_to_watch,
+        |conn, args, _| async move { lmove(&conn, &args).await },
+        args.to_vec(),
+        timeout,
+    )
+    .await;
 
     Ok(Value::Ignore)
 }
@@ -252,55 +273,37 @@ pub async fn brpoplpush(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
 /// popped from the tail of the first list that is non-empty, with the given keys being checked in
 /// the order that they are given.
 pub async fn brpop(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
-    let brpop_task = |conn: &Connection, args: &[Bytes]| -> Result<Value, Error> {
+    let brpop_task = |conn: Arc<Connection>, args: Vec<Bytes>, attempt| async move {
         for key in (1..args.len() - 1) {
             let key = &args[key];
-            match remove_element(&conn, key, None, false)? {
-                Value::Null => (),
-                n => return Ok(vec![Value::new(&key), n].into()),
+            match remove_element(&conn, key, None, false) {
+                Ok(Value::Null) => (),
+                Ok(n) => return Ok(vec![Value::new(&key), n].into()),
+                Err(x) => {
+                    if attempt == 1 {
+                        return Err(x);
+                    }
+                }
             };
         }
         Ok(Value::Null)
     };
+
     if conn.is_executing_tx() {
-        return brpop_task(conn, args);
+        return brpop_task(conn.clone(), args.to_vec(), 1).await;
     }
 
     let timeout = parse_timeout(&args[args.len() - 1])?;
-    let conn = conn.clone();
-    let args = args.to_vec();
-
-    conn.block();
-
-    tokio::spawn(async move {
-        loop {
-            match brpop_task(&conn, &args) {
-                Ok(Value::Null) => {}
-                Ok(x) => {
-                    conn.append_response(x);
-                    conn.unblock(UnblockReason::Finished);
-                    break;
-                }
-                Err(x) => {
-                    conn.append_response(x.into());
-                    conn.unblock(UnblockReason::Finished);
-                    break;
-                }
-            }
-
-            match handle_timeout(&conn, timeout).await {
-                Ok(true) => {
-                    conn.append_response(Value::Null);
-                    break;
-                }
-                Err(x) => {
-                    conn.append_response(x.into());
-                    break;
-                }
-                _ => {}
-            }
-        }
-    });
+    let keys_to_watch = (&args[1..args.len() - 1]).to_vec();
+
+    schedule_blocking_task(
+        conn.clone(),
+        keys_to_watch,
+        brpop_task,
+        args.to_vec(),
+        timeout,
+    )
+    .await;
 
     Ok(Value::Ignore)
 }

+ 41 - 9
src/connection/mod.rs

@@ -1,10 +1,10 @@
 //! # Connection module
+use self::pubsub_server::Pubsub;
 use crate::{db::Db, error::Error, value::Value};
 use bytes::Bytes;
 use parking_lot::RwLock;
 use std::{collections::HashSet, sync::Arc};
-
-use self::pubsub_server::Pubsub;
+use tokio::sync::broadcast::{self, Receiver, Sender};
 
 pub mod connections;
 pub mod pubsub_connection;
@@ -31,7 +31,7 @@ impl Default for ConnectionStatus {
     }
 }
 
-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
 /// Reason while a client was unblocked
 pub enum UnblockReason {
     /// Timeout
@@ -53,6 +53,8 @@ pub struct ConnectionInfo {
     status: ConnectionStatus,
     commands: Option<Vec<Vec<Bytes>>>,
     is_blocked: bool,
+    blocked_notification: Option<Sender<()>>,
+    block_id: usize,
     unblock_reason: Option<UnblockReason>,
 }
 
@@ -77,7 +79,9 @@ impl ConnectionInfo {
             tx_keys: HashSet::new(),
             commands: None,
             status: ConnectionStatus::Normal,
+            blocked_notification: None,
             is_blocked: false,
+            block_id: 0,
             unblock_reason: None,
         }
     }
@@ -129,29 +133,57 @@ impl Connection {
 
     /// Block the connection
     pub fn block(&self) {
+        let notification = broadcast::channel(1);
         let mut info = self.info.write();
         info.is_blocked = true;
+        info.blocked_notification = Some(notification.0);
+        info.block_id += 1;
         info.unblock_reason = None;
     }
 
+    /// Returns the current block task ID number. This is an internal ID to
+    /// identify each blocking command as unique
+    #[inline]
+    pub fn get_block_id(&self) -> Option<usize> {
+        let info = self.info.read();
+        if info.is_blocked {
+            Some(info.block_id)
+        } else {
+            None
+        }
+    }
+
+    /// Returns a receiver that will be notified if the client is externally unblocked
+    #[inline]
+    pub fn get_unblocked_subscription(&self) -> Option<Receiver<()>> {
+        self.info
+            .read()
+            .blocked_notification
+            .as_ref()
+            .map(|notification| notification.subscribe())
+    }
+
     /// Unblock connection
     pub fn unblock(&self, reason: UnblockReason) -> bool {
         let mut info = self.info.write();
         if info.is_blocked {
+            let notification = info.blocked_notification.as_ref().map(|s| s.clone());
             info.is_blocked = false;
             info.unblock_reason = Some(reason);
+            info.blocked_notification = None;
+            drop(info); // drop write lock
+
+            if let Some(s) = notification {
+                // Notify connection about this change
+                s.send(());
+            }
+
             true
         } else {
             false
         }
     }
 
-    /// If the current connection has been externally unblocked
-    #[inline]
-    pub fn has_been_unblocked_externally(&self) -> Option<UnblockReason> {
-        self.info.read().unblock_reason
-    }
-
     /// Is the current connection blocked?
     #[inline]
     pub fn is_blocked(&self) -> bool {

+ 58 - 8
src/db/mod.rs

@@ -16,6 +16,7 @@ use bytes::{BufMut, Bytes, BytesMut};
 use core::num;
 use entry::{new_version, Entry};
 use expiration::ExpirationDb;
+use futures::Future;
 use glob::Pattern;
 use log::trace;
 use num_traits::CheckedAdd;
@@ -30,7 +31,10 @@ use std::{
     sync::Arc,
     thread,
 };
-use tokio::time::{Duration, Instant};
+use tokio::{
+    sync::broadcast::{self, Receiver, Sender},
+    time::{Duration, Instant},
+};
 
 mod entry;
 mod expiration;
@@ -69,6 +73,11 @@ pub struct Db {
     /// Data structure to store all expiring keys
     expirations: Arc<Mutex<ExpirationDb>>,
 
+    /// Key change subscriptions map. This map contains all the senders for
+    /// key subscriptions. If a key does not exist here, it means that no one
+    /// wants to be notified of changes to that key.
+    change_subscriptions: Arc<RwLock<HashMap<Bytes, Sender<()>>>>,
+
     /// Number of HashMaps that are available.
     number_of_slots: usize,
 
@@ -81,7 +90,7 @@ pub struct Db {
     /// A Database is attached to a conn_id. The slots and expiration data
     /// structures are shared between all connections, regardless of conn_id.
     ///
-    /// This particular database instace is attached to a conn_id, which is used
+    /// This particular database instance is attached to a conn_id, which is used
     /// to lock keys exclusively for transactions and other atomic operations.
     conn_id: u128,
 
@@ -101,6 +110,7 @@ impl Db {
         Self {
             slots: Arc::new(slots),
             expirations: Arc::new(Mutex::new(ExpirationDb::new())),
+            change_subscriptions: Arc::new(RwLock::new(HashMap::new())),
             conn_id: 0,
             db_id: new_version(),
             tx_key_locks: Arc::new(RwLock::new(HashMap::new())),
@@ -118,6 +128,7 @@ impl Db {
             slots: self.slots.clone(),
             tx_key_locks: self.tx_key_locks.clone(),
             expirations: self.expirations.clone(),
+            change_subscriptions: self.change_subscriptions.clone(),
             conn_id,
             db_id: self.db_id,
             number_of_slots: self.number_of_slots,
@@ -154,7 +165,7 @@ impl Db {
 
     /// Locks keys exclusively
     ///
-    /// The locked keys are only accesible (read or write) by the connection
+    /// The locked keys are only accessible (read or write) by the connection
     /// that locked them, any other connection must wait until the locking
     /// connection releases them.
     ///
@@ -590,7 +601,7 @@ impl Db {
         let slot1 = self.get_slot(source);
         let slot2 = self.get_slot(target);
 
-        if slot1 == slot2 {
+        let result = if slot1 == slot2 {
             let mut slot = self.slots[slot1].write();
 
             if override_value == Override::No && slot.get(target).is_some() {
@@ -615,7 +626,14 @@ impl Db {
             } else {
                 Err(Error::NotFound)
             }
+        };
+
+        if result.is_ok() {
+            self.bump_version(source);
+            self.bump_version(target);
         }
+
+        result
     }
 
     /// Removes keys from the database
@@ -671,7 +689,7 @@ impl Db {
 
     /// get_map_or
     ///
-    /// Instead of returning an entry of the database, to avoid clonning, this function will
+    /// Instead of returning an entry of the database, to avoid cloning, this function will
     /// execute a callback function with the entry as a parameter. If no record is found another
     /// callback function is going to be executed, dropping the lock before doing so.
     ///
@@ -680,7 +698,7 @@ impl Db {
     /// entry itself.
     ///
     /// This function is useful to read non-scalar values from the database. Non-scalar values are
-    /// forbidden to clone, attempting cloning will endup in an error (Error::WrongType)
+    /// forbidden to clone, attempting cloning will end-up in an error (Error::WrongType)
     pub fn get_map_or<F1, F2>(&self, key: &Bytes, found: F1, not_found: F2) -> Result<Value, Error>
     where
         F1: FnOnce(&Value) -> Result<Value, Error>,
@@ -702,12 +720,44 @@ impl Db {
     /// Updates the entry version of a given key
     pub fn bump_version(&self, key: &Bytes) -> bool {
         let mut slot = self.slots[self.get_slot(key)].write();
-        slot.get_mut(key)
+        let to_return = slot
+            .get_mut(key)
             .filter(|x| x.is_valid())
             .map(|entry| {
                 entry.bump_version();
             })
-            .is_some()
+            .is_some();
+        drop(slot);
+        if to_return {
+            let senders = self.change_subscriptions.read();
+            if let Some(sender) = senders.get(key) {
+                if sender.receiver_count() == 0 {
+                    // Garbage collection
+                    drop(senders);
+                    self.change_subscriptions.write().remove(key);
+                } else {
+                    // Notify
+                    let _ = sender.send(());
+                }
+            }
+        }
+        to_return
+    }
+
+    /// Subscribe to key changes.
+    pub fn subscribe_to_key_changes(&self, keys: &[Bytes]) -> Vec<Receiver<()>> {
+        let mut subscriptions = self.change_subscriptions.write();
+        keys.iter()
+            .map(|key| {
+                if let Some(sender) = subscriptions.get(key) {
+                    sender.subscribe()
+                } else {
+                    let (sender, receiver) = broadcast::channel(1);
+                    subscriptions.insert(key.clone(), sender);
+                    receiver
+                }
+            })
+            .collect()
     }
 
     /// Returns the version of a given key

+ 0 - 3
tests/unit/type/list.tcl

@@ -423,8 +423,6 @@ start_server {
 
       r rpush list1{t} foo
 
-      after 50
-
       assert_equal {} [r lrange list1{t} 0 -1]
       assert_equal {} [r lrange list2{t} 0 -1]
       assert_equal {foo} [r lrange list3{t} 0 -1]
@@ -484,7 +482,6 @@ start_server {
         $watching_client get somekey{t}
         $watching_client read
         r lpush srclist{t} element
-        after 50
         $watching_client exec
         $watching_client read
     } {}