Jelajahi Sumber

Added diff/inter methods

* sdiff
* sdiffstore
* sinter
* sintercard
* sinterstore
Cesar Rodas 3 tahun lalu
induk
melakukan
697febbc08
2 mengubah file dengan 296 tambahan dan 4 penghapusan
  1. 271 4
      src/cmd/set.rs
  2. 25 0
      src/dispatcher.rs

+ 271 - 4
src/cmd/set.rs

@@ -2,6 +2,48 @@ use crate::{connection::Connection, error::Error, value::checksum, value::Value}
 use bytes::Bytes;
 use std::collections::HashSet;
 
+async fn compare_sets<F1>(conn: &Connection, keys: &[Bytes], op: F1) -> Result<Value, Error>
+where
+    F1: Fn(&mut HashSet<checksum::Value>, &HashSet<checksum::Value>) -> bool,
+{
+    conn.db().get_map_or(
+        &keys[0],
+        |v| match v {
+            Value::Set(x) => {
+                #[allow(clippy::mutable_key_type)]
+                let mut all_entries = x.read().clone();
+                for key in keys[1..].iter() {
+                    let mut do_break = false;
+                    let _ = conn.db().get_map_or(
+                        key,
+                        |v| match v {
+                            Value::Set(x) => {
+                                if !op(&mut all_entries, &x.read()) {
+                                    do_break = true;
+                                }
+                                Ok(Value::Null)
+                            }
+                            _ => Err(Error::WrongType),
+                        },
+                        || Ok(Value::Null),
+                    )?;
+                    if do_break {
+                        break;
+                    }
+                }
+
+                Ok(all_entries
+                    .iter()
+                    .map(|entry| entry.clone_value())
+                    .collect::<Vec<Value>>()
+                    .into())
+            }
+            _ => Err(Error::WrongType),
+        },
+        || Ok(Value::Array(vec![])),
+    )
+}
+
 pub async fn sadd(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     conn.db().get_map_or(
         &args[1],
@@ -50,15 +92,91 @@ pub async fn scard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     )
 }
 
+pub async fn sdiff(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    compare_sets(conn, &args[1..], |all_entries, elements| {
+        for element in elements.iter() {
+            if all_entries.contains(element) {
+                all_entries.remove(element);
+            }
+        }
+        true
+    })
+    .await
+}
+
+pub async fn sdiffstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    if let Value::Array(values) = sdiff(conn, &args[1..]).await? {
+        #[allow(clippy::mutable_key_type)]
+        let mut x = HashSet::new();
+        let mut len = 0;
+
+        for val in values.iter() {
+            if let Value::Blob(blob) = val {
+                if x.insert(checksum::Value::new(blob.clone())) {
+                    len += 1;
+                }
+            }
+        }
+
+        conn.db().set(&args[1], x.into(), None);
+
+        Ok(len.into())
+    } else {
+        Ok(0.into())
+    }
+}
+
+pub async fn sinter(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    compare_sets(conn, &args[1..], |all_entries, elements| {
+        all_entries.retain(|element| elements.contains(element));
+
+        for element in elements.iter() {
+            if !all_entries.contains(element) {
+                all_entries.remove(element);
+            }
+        }
+
+        !all_entries.is_empty()
+    })
+    .await
+}
+
+pub async fn sintercard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    if let Ok(Value::Array(x)) = sinter(conn, args).await {
+        Ok((x.len() as i64).into())
+    } else {
+        Ok(0.into())
+    }
+}
+
+pub async fn sinterstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    if let Value::Array(values) = sinter(conn, &args[1..]).await? {
+        #[allow(clippy::mutable_key_type)]
+        let mut x = HashSet::new();
+        let mut len = 0;
+
+        for val in values.iter() {
+            if let Value::Blob(blob) = val {
+                if x.insert(checksum::Value::new(blob.clone())) {
+                    len += 1;
+                }
+            }
+        }
+
+        conn.db().set(&args[1], x.into(), None);
+
+        Ok(len.into())
+    } else {
+        Ok(0.into())
+    }
+}
+
 pub async fn sismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     conn.db().get_map_or(
         &args[1],
         |v| match v {
             Value::Set(x) => {
-                if x.read()
-                    .get(&checksum::Value::new(args[2].clone()))
-                    .is_some()
-                {
+                if x.read().contains(&checksum::Value::new(args[2].clone())) {
                     Ok(1.into())
                 } else {
                     Ok(0.into())
@@ -132,6 +250,155 @@ mod test {
     }
 
     #[tokio::test]
+    async fn sdiff() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "2", "c"]).await,
+            run_command(&c, &["scard", "2"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
+            run_command(&c, &["scard", "3"]).await
+        );
+
+        match run_command(&c, &["sdiff", "1", "2", "3"]).await {
+            Ok(Value::Array(v)) => {
+                assert_eq!(2, v.len());
+                if v[0] == Value::Blob("b".into()) {
+                    assert_eq!(v[1], Value::Blob("d".into()));
+                } else {
+                    assert_eq!(v[1], Value::Blob("b".into()));
+                }
+            }
+            _ => unreachable!(),
+        };
+    }
+
+    #[tokio::test]
+    async fn sdiffstore() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "2", "c"]).await,
+            run_command(&c, &["scard", "2"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
+            run_command(&c, &["scard", "3"]).await
+        );
+
+        assert_eq!(
+            Ok(Value::Integer(2)),
+            run_command(&c, &["sdiffstore", "4", "1", "2", "3"]).await
+        );
+
+        match run_command(&c, &["smembers", "4"]).await {
+            Ok(Value::Array(v)) => {
+                assert_eq!(2, v.len());
+                if v[0] == Value::Blob("b".into()) {
+                    assert_eq!(v[1], Value::Blob("d".into()));
+                } else {
+                    assert_eq!(v[1], Value::Blob("b".into()));
+                }
+            }
+            _ => unreachable!(),
+        };
+    }
+
+    #[tokio::test]
+    async fn sinter() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "2", "c", "x"]).await,
+            run_command(&c, &["scard", "2"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
+            run_command(&c, &["scard", "3"]).await
+        );
+
+        assert_eq!(
+            Ok(Value::Array(vec![Value::Blob("c".into())])),
+            run_command(&c, &["sinter", "1", "2", "3"]).await
+        );
+    }
+
+    #[tokio::test]
+    async fn sintercard() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "2", "c", "x"]).await,
+            run_command(&c, &["scard", "2"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
+            run_command(&c, &["scard", "3"]).await
+        );
+
+        assert_eq!(
+            Ok(Value::Integer(1)),
+            run_command(&c, &["sintercard", "1", "2", "3"]).await
+        );
+    }
+
+    #[tokio::test]
+    async fn sinterstore() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "2", "c", "x"]).await,
+            run_command(&c, &["scard", "2"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
+            run_command(&c, &["scard", "3"]).await
+        );
+
+        assert_eq!(
+            Ok(Value::Integer(1)),
+            run_command(&c, &["sinterstore", "foo", "1", "2", "3"]).await
+        );
+
+        assert_eq!(
+            Ok(Value::Array(vec![Value::Blob("c".into())])),
+            run_command(&c, &["smembers", "foo"]).await
+        );
+    }
+
+    #[tokio::test]
     async fn sismember() {
         let c = create_connection();
 

+ 25 - 0
src/dispatcher.rs

@@ -33,6 +33,31 @@ dispatcher! {
             [""],
             2,
         },
+        sdiff {
+            cmd::set::sdiff,
+            [""],
+            -2,
+        },
+        sdiffstore {
+            cmd::set::sdiffstore,
+            [""],
+            -3,
+        },
+        sinter {
+            cmd::set::sinter,
+            [""],
+            -2,
+        },
+        sintercard {
+            cmd::set::sintercard,
+            [""],
+            -2,
+        },
+        sinterstore {
+            cmd::set::sinterstore,
+            [""],
+            -3,
+        },
         sismember {
             cmd::set::sismember,
             [""],