Parcourir la source

Added sorted-set commands

Cesar Rodas il y a 3 ans
Parent
commit
b0259c52a3
7 fichiers modifiés avec 573 ajouts et 107 suppressions
  1. 1 0
      src/cmd/mod.rs
  2. 201 0
      src/cmd/sorted_set.rs
  3. 29 0
      src/dispatcher/mod.rs
  4. 31 2
      src/value/mod.rs
  5. 0 105
      src/value/sorted_set.rs
  6. 114 0
      src/value/sorted_set/insert.rs
  7. 197 0
      src/value/sorted_set/mod.rs

+ 1 - 0
src/cmd/mod.rs

@@ -11,6 +11,7 @@ pub mod metrics;
 pub mod pubsub;
 pub mod server;
 pub mod set;
+pub mod sorted_set;
 pub mod string;
 pub mod transaction;
 

+ 201 - 0
src/cmd/sorted_set.rs

@@ -0,0 +1,201 @@
+//! # Sorted Set command handlers
+use std::collections::VecDeque;
+
+use crate::{
+    connection::Connection,
+    error::Error,
+    value::{
+        bytes_to_number, bytes_to_range_floatord,
+        sorted_set::{IOption, IResult},
+    },
+    value::{sorted_set::SortedSet, Value},
+};
+use bytes::Bytes;
+use float_ord::FloatOrd;
+
+/// Adds all the specified members with the specified scores to the sorted set
+/// stored at key. It is possible to specify multiple score / member pairs. If a
+/// specified member is already a member of the sorted set, the score is updated
+/// and the element reinserted at the right position to ensure the correct
+/// ordering.
+///
+/// If key does not exist, a new sorted set with the specified members as sole
+/// members is created, like if the sorted set was empty. If the key exists but
+/// does not hold a sorted set, an error is returned.
+///
+/// The score values should be the string representation of a double precision
+/// floating point number. +inf and -inf values are valid values as well.
+pub async fn zadd(conn: &Connection, mut args: VecDeque<Bytes>) -> Result<Value, Error> {
+    let key = args.pop_front().ok_or(Error::Syntax)?;
+    let option = IOption::new(&mut args)?;
+    if args.is_empty() {
+        return Err(Error::InvalidArgsCount("ZADD".to_owned()));
+    }
+    if args.len() % 2 != 0 {
+        return Err(Error::Syntax);
+    }
+    if args.len() != 2 && option.incr {
+        return Err(Error::OptsNotCompatible(
+            "INCR option supports a single increment-element pair".to_owned(),
+        ));
+    }
+    let result = conn
+        .db()
+        .get(&key)
+        .map_mut(|v| match v {
+            Value::SortedSet(x) => {
+                let mut insert: usize = 0;
+                let mut updated: usize = 0;
+
+                loop {
+                    let score = match args.pop_front() {
+                        Some(x) => bytes_to_number::<f64>(&x)?,
+                        None => break,
+                    };
+                    let value = args.pop_front().ok_or(Error::Syntax)?;
+                    match x.insert(FloatOrd(score), value, &option) {
+                        IResult::Inserted => insert += 1,
+                        IResult::Updated => updated += 1,
+                        _ => {}
+                    }
+                }
+
+                Ok(if option.return_change {
+                    updated
+                } else {
+                    insert
+                }
+                .into())
+            }
+            _ => Err(Error::WrongType),
+        })
+        .unwrap_or_else(|| {
+            let mut x = SortedSet::new();
+            let mut insert: usize = 0;
+            let mut updated: usize = 0;
+
+            loop {
+                let score = match args.pop_front() {
+                    Some(x) => bytes_to_number::<f64>(&x)?,
+                    None => break,
+                };
+                let value = args.pop_front().ok_or(Error::Syntax)?;
+                match x.insert(FloatOrd(score), value, &option) {
+                    IResult::Inserted => insert += 1,
+                    IResult::Updated => updated += 1,
+                    _ => {}
+                }
+            }
+
+            conn.db().set(key.clone(), x.into(), None);
+
+            Ok(if option.return_change {
+                updated
+            } else {
+                insert
+            }
+            .into())
+        })?;
+
+    conn.db().bump_version(&key);
+
+    Ok(result)
+}
+
+/// Returns the sorted set cardinality (number of elements) of the sorted set
+/// stored at key.
+pub async fn zcard(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Error> {
+    conn.db()
+        .get(&args[0])
+        .map(|v| match v {
+            Value::SortedSet(x) => Ok(x.len().into()),
+            _ => Err(Error::WrongType),
+        })
+        .unwrap_or(Ok(0.into()))
+}
+
+/// Returns the number of elements in the sorted set at key with a score between
+/// min and max.
+///
+/// The min and max arguments have the same semantic as described for
+/// ZRANGEBYSCORE.
+pub async fn zcount(conn: &Connection, args: VecDeque<Bytes>) -> Result<Value, Error> {
+    let min = bytes_to_range_floatord(&args[1])?;
+    let max = bytes_to_range_floatord(&args[2])?;
+    conn.db()
+        .get(&args[0])
+        .map(|v| match v {
+            Value::SortedSet(x) => Ok(x.count_values_by_score_range(min, max).into()),
+            _ => Err(Error::WrongType),
+        })
+        .unwrap_or(Ok(0.into()))
+}
+
+#[cfg(test)]
+mod test {
+    use crate::{
+        cmd::test::{create_connection, run_command},
+        error::Error,
+    };
+
+    #[tokio::test]
+    async fn test_set_wrong_type() {
+        let c = create_connection();
+
+        let _ = run_command(&c, &["set", "foo", "1"]).await;
+
+        assert_eq!(
+            Err(Error::WrongType),
+            run_command(&c, &["zadd", "foo", "5", "bar", "1", "foo"]).await,
+        );
+    }
+
+    #[tokio::test]
+    async fn test_zadd() {
+        let c = create_connection();
+
+        assert_eq!(
+            Ok(2.into()),
+            run_command(&c, &["zadd", "foo", "5", "bar", "1", "foo"]).await,
+        );
+        assert_eq!(
+            Ok(0.into()),
+            run_command(&c, &["zadd", "foo", "5", "bar", "1", "foo"]).await,
+        );
+        assert_eq!(Ok(2.into()), run_command(&c, &["zcard", "foo"]).await,);
+    }
+
+    #[tokio::test]
+    async fn test_zcount() {
+        let c = create_connection();
+
+        assert_eq!(
+            Ok(3.into()),
+            run_command(
+                &c,
+                &["zadd", "foo", "5", "bar", "1", "foo", "5.9", "foobar"]
+            )
+            .await,
+        );
+        assert_eq!(
+            Ok(0.into()),
+            run_command(&c, &["zadd", "foo", "nx", "511", "bar", "10", "foo"]).await,
+        );
+        assert_eq!(
+            Ok(2.into()),
+            run_command(&c, &["zcount", "foo", "1", "5"]).await,
+        );
+        assert_eq!(
+            Ok(1.into()),
+            run_command(&c, &["zcount", "foo", "1", "(5"]).await,
+        );
+        assert_eq!(
+            Ok(0.into()),
+            run_command(&c, &["zcount", "foo", "(1", "(5"]).await,
+        );
+        assert_eq!(
+            Ok(3.into()),
+            run_command(&c, &["zcount", "foo", "-inf", "+inf"]).await,
+        );
+    }
+}

+ 29 - 0
src/dispatcher/mod.rs

@@ -19,6 +19,35 @@ pub mod command;
 
 // Returns the server time
 dispatcher! {
+    sorted_set {
+        ZADD {
+            cmd::sorted_set::zadd,
+            [Flag::Write Flag::DenyOom Flag::Fast],
+            -4,
+            1,
+            1,
+            1,
+            true,
+        },
+        ZCARD {
+            cmd::sorted_set::zcard,
+            [Flag::ReadOnly Flag::Fast],
+            2,
+            1,
+            1,
+            1,
+            true,
+        },
+        ZCOUNT {
+            cmd::sorted_set::zcount,
+            [Flag::ReadOnly Flag::Fast],
+            4,
+            1,
+            1,
+            1,
+            true,
+        },
+    },
     set {
         SADD {
             cmd::set::sadd,

+ 31 - 2
src/value/mod.rs

@@ -14,7 +14,7 @@ use float_ord::FloatOrd;
 use redis_zero_protocol_parser::Value as ParsedValue;
 use sha2::{Digest, Sha256};
 use std::{
-    collections::{HashMap, HashSet, VecDeque},
+    collections::{Bound, HashMap, HashSet, VecDeque},
     convert::{TryFrom, TryInto},
     str::FromStr,
 };
@@ -33,7 +33,7 @@ pub enum Value {
     /// Set. This type cannot be serialized
     Set(HashSet<Bytes>),
     /// Sorted set
-    SortedSet(sorted_set::SortedSet<FloatOrd<f64>, Bytes>),
+    SortedSet(sorted_set::SortedSet),
     /// Vector/Array of values
     Array(Vec<Value>),
     /// Bytes/Strings/Binary data
@@ -229,6 +229,29 @@ pub fn bytes_to_int<T: FromStr>(bytes: &[u8]) -> Result<T, Error> {
         .map_err(|_| Error::NotANumberType("an integer".to_owned()))
 }
 
+/// Converts bytes to a Range number
+pub fn bytes_to_range<T: FromStr>(bytes: &[u8]) -> Result<Bound<T>, Error> {
+    match bytes {
+        b"-inf" | b"+inf" | b"inf" => Ok(Bound::Unbounded),
+        _ => {
+            if bytes[0] == b'(' {
+                Ok(Bound::Excluded(bytes_to_number::<T>(&(bytes[1..]))?))
+            } else {
+                Ok(Bound::Included(bytes_to_number::<T>(bytes)?))
+            }
+        }
+    }
+}
+
+/// Converts bytes to a Range of float FloatOrd numbers
+pub fn bytes_to_range_floatord(bytes: &[u8]) -> Result<Bound<FloatOrd<f64>>, Error> {
+    match bytes_to_range(bytes)? {
+        Bound::Included(n) => Ok(Bound::Included(FloatOrd(n))),
+        Bound::Excluded(n) => Ok(Bound::Excluded(FloatOrd(n))),
+        Bound::Unbounded => Ok(Bound::Unbounded),
+    }
+}
+
 impl<'a> From<&ParsedValue<'a>> for Value {
     fn from(value: &ParsedValue) -> Self {
         match value {
@@ -285,6 +308,12 @@ impl From<&str> for Value {
     }
 }
 
+impl From<sorted_set::SortedSet> for Value {
+    fn from(value: sorted_set::SortedSet) -> Self {
+        Value::SortedSet(value)
+    }
+}
+
 impl From<HashMap<Bytes, Bytes>> for Value {
     fn from(value: HashMap<Bytes, Bytes>) -> Value {
         Value::Hash(value)

+ 0 - 105
src/value/sorted_set.rs

@@ -1,105 +0,0 @@
-//! # Sorted Set module
-use std::{
-    collections::{btree_map::Iter, BTreeMap, HashMap},
-    hash::Hash,
-};
-
-/// Sorted set structure
-#[derive(Debug, Clone)]
-pub struct SortedSet<S: Clone + PartialEq + Ord, V: Clone + Eq + Hash> {
-    set: HashMap<V, (S, usize)>,
-    order: BTreeMap<S, V>,
-    position_updated: bool,
-}
-
-impl<S: PartialEq + Clone + Ord, V: Eq + Clone + Hash> PartialEq for SortedSet<S, V> {
-    fn eq(&self, other: &SortedSet<S, V>) -> bool {
-        self.order == other.order
-    }
-}
-
-impl<S: PartialEq + Clone + Ord, V: Eq + Clone + Hash> SortedSet<S, V> {
-    /// Creates a new instance
-    pub fn new() -> Self {
-        Self {
-            set: HashMap::new(),
-            order: BTreeMap::new(),
-            position_updated: true,
-        }
-    }
-
-    /// Clears the map, removing all elements.
-    pub fn clear(&mut self) {
-        self.set.clear();
-        self.order.clear();
-    }
-
-    /// Gets an iterator over the entries of the map, sorted by score.
-    pub fn iter(&self) -> Iter<'_, S, V> {
-        self.order.iter()
-    }
-
-    /// Adds a value to the set.
-    /// If the set did not have this value present, true is returned.
-    ///
-    /// If the set did have this value present, false is returned.
-    pub fn insert(&mut self, score: S, value: &V) -> bool {
-        if self.set.get(value).is_none() {
-            self.set.insert(value.clone(), (score.clone(), 0));
-            self.order.insert(score, value.clone());
-            self.position_updated = false;
-            true
-        } else {
-            false
-        }
-    }
-
-    /// Returns a reference to the score in the set, if any, that is equal to the given value.
-    pub fn get_score(&self, value: &V) -> Option<&S> {
-        self.set.get(value).map(|(value, _)| value)
-    }
-
-    /// Returns all the values sorted by their score
-    pub fn get_values(&self) -> Vec<V> {
-        self.order.values().cloned().collect()
-    }
-
-    /// Adds the position in the set to each value based on their score
-    fn update_value_position(&mut self) {
-        let mut i = 0;
-        for element in self.order.values() {
-            if let Some(value) = self.set.get_mut(element) {
-                value.1 = i;
-            }
-            i += 1;
-        }
-        self.position_updated = true;
-    }
-
-    /// Return the position into the set based on their score
-    pub fn get_value_pos(&mut self, value: &V) -> Option<usize> {
-        if self.position_updated {
-            Some(self.set.get(value)?.1)
-        } else {
-            self.update_value_position();
-            Some(self.set.get(value)?.1)
-        }
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use super::*;
-
-    #[test]
-    fn basic_usage() {
-        let mut set: SortedSet<i64, i64> = SortedSet::new();
-        assert!(set.insert(1, &2));
-        assert!(set.insert(0, &3));
-        assert!(!set.insert(33, &3));
-        assert_eq!(vec![3, 2], set.get_values());
-        assert_eq!(Some(1), set.get_value_pos(&2));
-        assert_eq!(Some(0), set.get_value_pos(&3));
-        assert_eq!(None, set.get_value_pos(&5));
-    }
-}

+ 114 - 0
src/value/sorted_set/insert.rs

@@ -0,0 +1,114 @@
+use crate::error::Error;
+use bytes::Bytes;
+use std::{collections::VecDeque, fmt::Debug};
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub(crate) enum IPolicy {
+    NX,
+    XX,
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub(crate) enum UPolicyScore {
+    LT,
+    GT,
+}
+
+#[derive(Debug, Default, Clone)]
+/// Insert option
+pub struct IOption {
+    pub(crate) insert_policy: Option<IPolicy>,
+    pub(crate) update_policy_score: Option<UPolicyScore>,
+    /// Modify the return value from the number of new elements added, to the
+    /// total number of elements changed (CH is an abbreviation of changed).
+    /// Changed elements are new elements added and elements already existing
+    /// for which the score was updated. So elements specified in the command
+    /// line having the same score as they had in the past are not counted.
+    /// Note: normally the return value of ZADD only counts the number of new
+    /// elements added.
+    pub return_change: bool,
+    /// Increments instead of adding
+    pub incr: bool,
+}
+
+impl IOption {
+    /// Creates a new instance
+    pub fn new(args: &mut VecDeque<Bytes>) -> Result<Self, Error> {
+        let mut update_policy = None;
+        let mut update_policy_score = None;
+        let mut return_change = false;
+        let mut incr = false;
+        loop {
+            match args.get(0) {
+                Some(t) => {
+                    let command = String::from_utf8_lossy(t);
+                    match command.to_uppercase().as_str() {
+                        "NX" => {
+                            if update_policy == Some(IPolicy::XX) {
+                                return Err(Error::OptsNotCompatible("XX AND NX".to_owned()));
+                            }
+                            update_policy = Some(IPolicy::NX);
+                            args.pop_front();
+                        }
+                        "XX" => {
+                            if update_policy == Some(IPolicy::NX) {
+                                return Err(Error::OptsNotCompatible("XX AND NX".to_owned()));
+                            }
+                            update_policy = Some(IPolicy::XX);
+                            args.pop_front();
+                        }
+                        "LT" => {
+                            if update_policy == Some(IPolicy::NX)
+                                || update_policy_score == Some(UPolicyScore::GT)
+                            {
+                                return Err(Error::OptsNotCompatible(
+                                    "GT, LT, and/or NX".to_owned(),
+                                ));
+                            }
+                            update_policy_score = Some(UPolicyScore::LT);
+                            args.pop_front();
+                        }
+                        "GT" => {
+                            if update_policy == Some(IPolicy::NX)
+                                || update_policy_score == Some(UPolicyScore::LT)
+                            {
+                                return Err(Error::OptsNotCompatible(
+                                    "GT, LT, and/or NX".to_owned(),
+                                ));
+                            }
+                            update_policy_score = Some(UPolicyScore::GT);
+                            args.pop_front();
+                        }
+                        "CH" => {
+                            return_change = true;
+                            args.pop_front();
+                        }
+                        "INCR" => {
+                            incr = true;
+                            args.pop_front();
+                        }
+                        _ => break,
+                    }
+                }
+                None => break,
+            }
+        }
+        Ok(Self {
+            insert_policy: update_policy,
+            update_policy_score,
+            return_change,
+            incr,
+        })
+    }
+}
+
+/// Insert result
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum IResult {
+    /// No operation has taken place
+    NoOp,
+    /// A new element has been added
+    Inserted,
+    /// An element has been updated
+    Updated,
+}

+ 197 - 0
src/value/sorted_set/mod.rs

@@ -0,0 +1,197 @@
+//! # Sorted Set module
+use bytes::Bytes;
+use float_ord::FloatOrd;
+use std::{
+    collections::{BTreeMap, HashMap},
+    fmt::Debug,
+    ops::Bound,
+};
+
+mod insert;
+
+pub use insert::{IOption, IResult};
+use insert::{IPolicy, UPolicyScore};
+
+/// Sorted set structure
+#[derive(Debug, Clone)]
+pub struct SortedSet {
+    set: HashMap<Bytes, (FloatOrd<f64>, usize)>,
+    order: BTreeMap<(FloatOrd<f64>, Bytes), usize>,
+}
+
+impl PartialEq for SortedSet {
+    fn eq(&self, other: &SortedSet) -> bool {
+        self.order == other.order
+    }
+}
+
+impl SortedSet {
+    /// Creates a new instance
+    pub fn new() -> Self {
+        Self {
+            set: HashMap::new(),
+            order: BTreeMap::new(),
+        }
+    }
+
+    /// Clears the map, removing all elements.
+    pub fn clear(&mut self) {
+        self.set.clear();
+        self.order.clear();
+    }
+
+    /// Returns the number of elements in the set
+    pub fn len(&self) -> usize {
+        self.set.len()
+    }
+
+    /// Adds a value to the set.
+    /// If the set did not have this value present, true is returned.
+    ///
+    /// If the set did have this value present, false is returned.
+    pub fn insert(&mut self, score: FloatOrd<f64>, value: Bytes, option: &IOption) -> IResult {
+        if let Some((current_score, _)) = self.set.get(&value).cloned() {
+            if option.insert_policy == Some(IPolicy::NX) {
+                return IResult::NoOp;
+            }
+            let cmp = current_score.cmp(&score);
+            let update_based_on_score =
+                option
+                    .update_policy_score
+                    .map_or(true, |policy| match policy {
+                        UPolicyScore::LT => cmp == std::cmp::Ordering::Less,
+                        UPolicyScore::GT => cmp == std::cmp::Ordering::Greater,
+                    });
+
+            if !update_based_on_score {
+                return IResult::NoOp;
+            }
+            // remove the previous order entry
+            self.order.remove(&(current_score, value.clone()));
+
+            let score = if option.incr {
+                FloatOrd(current_score.0 + score.0)
+            } else {
+                score
+            };
+
+            // update and insert the new order entry
+            self.set.insert(value.clone(), (score, 0));
+            self.order.insert((score, value), 0);
+
+            self.update_value_position();
+            IResult::Updated
+        } else {
+            if option.insert_policy == Some(IPolicy::XX) {
+                return IResult::NoOp;
+            }
+            self.set.insert(value.clone(), (score, 0));
+            self.order.insert((score, value), 0);
+            self.update_value_position();
+            IResult::Inserted
+        }
+    }
+
+    /// Returns a reference to the score in the set, if any, that is equal to the given value.
+    pub fn get_score(&self, value: &Bytes) -> Option<FloatOrd<f64>> {
+        self.set.get(value).map(|(value, _)| *value)
+    }
+
+    /// Returns all the values sorted by their score
+    pub fn get_values(&self) -> Vec<Bytes> {
+        self.order.keys().map(|(_, value)| value.clone()).collect()
+    }
+
+    #[inline]
+    fn convert_to_range(
+        min: Bound<FloatOrd<f64>>,
+        max: Bound<FloatOrd<f64>>,
+    ) -> (Bound<(FloatOrd<f64>, Bytes)>, Bound<(FloatOrd<f64>, Bytes)>) {
+        let min_bytes = Bytes::new();
+        let max_bytes = Bytes::copy_from_slice(&vec![255u8; 4096]);
+
+        (
+            match min {
+                Bound::Included(value) => Bound::Included((value, min_bytes.clone())),
+                Bound::Excluded(value) => Bound::Excluded((value, max_bytes.clone())),
+                Bound::Unbounded => Bound::Unbounded,
+            },
+            match max {
+                Bound::Included(value) => Bound::Included((value, max_bytes)),
+                Bound::Excluded(value) => Bound::Excluded((value, min_bytes)),
+                Bound::Unbounded => Bound::Unbounded,
+            },
+        )
+    }
+
+    /// Get total number of values in a score range
+    pub fn count_values_by_score_range(
+        &self,
+        min: Bound<FloatOrd<f64>>,
+        max: Bound<FloatOrd<f64>>,
+    ) -> usize {
+        self.order.range(Self::convert_to_range(min, max)).count()
+    }
+
+    /// Get values in a score range
+    pub fn get_values_by_score_range(
+        &self,
+        min: Bound<FloatOrd<f64>>,
+        max: Bound<FloatOrd<f64>>,
+    ) -> Vec<Bytes> {
+        self.order
+            .range(Self::convert_to_range(min, max))
+            .map(|(k, _)| k.1.clone())
+            .collect()
+    }
+
+    /// Adds the position in the set to each value based on their score
+    #[inline]
+    fn update_value_position(&mut self) {
+        let mut i = 0;
+        for ((_, key), value) in self.order.iter_mut() {
+            *value = i;
+            if let Some(value) = self.set.get_mut(key) {
+                value.1 = i;
+            }
+            i += 1;
+        }
+    }
+
+    /// Return the position into the set based on their score
+    pub fn get_value_pos(&self, value: &Bytes) -> Option<usize> {
+        Some(self.set.get(value)?.1)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn basic_usage() {
+        let mut set: SortedSet = SortedSet::new();
+        let mut op = IOption::default();
+        op.insert_policy = Some(IPolicy::NX);
+
+        assert_eq!(
+            set.insert(FloatOrd(1.0), "2".into(), &op),
+            IResult::Inserted
+        );
+        assert_eq!(
+            set.insert(FloatOrd(0.0), "3".into(), &op),
+            IResult::Inserted
+        );
+        assert_eq!(set.insert(FloatOrd(33.1), "3".into(), &op), IResult::NoOp);
+
+        op.insert_policy = None;
+        op.incr = true;
+        assert_eq!(set.insert(FloatOrd(2.0), "2".into(), &op), IResult::Updated);
+
+        //assert_eq!(vec![3, 2], set.get_values());
+        assert_eq!(Some(FloatOrd(3.0)), set.get_score(&"2".into()));
+        assert_eq!(Some(1), set.get_value_pos(&"2".into()));
+        assert_eq!(Some(0), set.get_value_pos(&"3".into()));
+        assert_eq!(None, set.get_value_pos(&"5".into()));
+    }
+}