Explorar o código

Added `sunion` and `sunionstore` support

Cesar Rodas %!s(int64=3) %!d(string=hai) anos
pai
achega
c4f107c9cc
Modificáronse 3 ficheiros con 84 adicións e 44 borrados
  1. 67 37
      src/cmd/set.rs
  2. 10 0
      src/dispatcher.rs
  3. 7 7
      src/value/checksum.rs

+ 67 - 37
src/cmd/set.rs

@@ -2,6 +2,22 @@ use crate::{connection::Connection, error::Error, value::Value};
 use bytes::Bytes;
 use std::collections::HashSet;
 
+fn store(conn: &Connection, key: &Bytes, values: &Vec<Value>) -> i64 {
+    #[allow(clippy::mutable_key_type)]
+    let mut x = HashSet::new();
+    let mut len = 0;
+
+    for val in values.iter() {
+        if let Value::Blob(blob) = val {
+            if x.insert(blob.clone()) {
+                len += 1;
+            }
+        }
+    }
+    conn.db().set(&key, x.into(), None);
+    len
+}
+
 async fn compare_sets<F1>(conn: &Connection, keys: &[Bytes], op: F1) -> Result<Value, Error>
 where
     F1: Fn(&mut HashSet<Bytes>, &HashSet<Bytes>) -> bool,
@@ -106,21 +122,7 @@ pub async fn sdiff(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
 
 pub async fn sdiffstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     if let Value::Array(values) = sdiff(conn, &args[1..]).await? {
-        #[allow(clippy::mutable_key_type)]
-        let mut x = HashSet::new();
-        let mut len = 0;
-
-        for val in values.iter() {
-            if let Value::Blob(blob) = val {
-                if x.insert(blob.clone()) {
-                    len += 1;
-                }
-            }
-        }
-
-        conn.db().set(&args[1], x.into(), None);
-
-        Ok(len.into())
+        Ok(store(conn, &args[1], &values).into())
     } else {
         Ok(0.into())
     }
@@ -151,21 +153,7 @@ pub async fn sintercard(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
 
 pub async fn sinterstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     if let Value::Array(values) = sinter(conn, &args[1..]).await? {
-        #[allow(clippy::mutable_key_type)]
-        let mut x = HashSet::new();
-        let mut len = 0;
-
-        for val in values.iter() {
-            if let Value::Blob(blob) = val {
-                if x.insert(blob.clone()) {
-                    len += 1;
-                }
-            }
-        }
-
-        conn.db().set(&args[1], x.into(), None);
-
-        Ok(len.into())
+        Ok(store(conn, &args[1], &values).into())
     } else {
         Ok(0.into())
     }
@@ -212,13 +200,7 @@ pub async fn smismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
                 let x = x.read();
                 Ok((&args[2..])
                     .iter()
-                    .map(|member| {
-                        if x.contains(member) {
-                            1
-                        } else {
-                            0
-                        }
-                    })
+                    .map(|member| if x.contains(member) { 1 } else { 0 })
                     .collect::<Vec<i32>>()
                     .into())
             }
@@ -228,6 +210,25 @@ pub async fn smismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
     )
 }
 
+pub async fn sunion(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    compare_sets(conn, &args[1..], |all_entries, elements| {
+        for element in elements.iter() {
+            all_entries.insert(element.clone());
+        }
+
+        true
+    })
+    .await
+}
+
+pub async fn sunionstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    if let Value::Array(values) = sunion(conn, &args[1..]).await? {
+        Ok(store(conn, &args[1], &values).into())
+    } else {
+        Ok(0.into())
+    }
+}
+
 #[cfg(test)]
 mod test {
     use crate::{
@@ -465,4 +466,33 @@ mod test {
             run_command(&c, &["smismember", "foo", "5", "6", "3"]).await
         );
     }
+
+    #[tokio::test]
+    async fn sunion() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "2", "c", "x"]).await,
+            run_command(&c, &["scard", "2"]).await
+        );
+
+        assert_eq!(
+            run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
+            run_command(&c, &["scard", "3"]).await
+        );
+
+        assert_eq!(
+            6,
+            if let Ok(Value::Array(x)) = run_command(&c, &["sunion", "1", "2", "3"]).await {
+                x.len()
+            } else {
+                0
+            }
+        );
+    }
 }

+ 10 - 0
src/dispatcher.rs

@@ -73,6 +73,16 @@ dispatcher! {
             [""],
             -3,
         },
+        sunion {
+            cmd::set::sunion,
+            [""],
+            -2,
+        },
+        sunionstore {
+            cmd::set::sunionstore,
+            [""],
+            -2,
+        },
     },
     list {
         blpop {

+ 7 - 7
src/value/checksum.rs

@@ -4,13 +4,13 @@ use crc32fast::Hasher as Crc32Hasher;
 use std::hash::{Hash, Hasher};
 
 fn calculate_checksum(bytes: &Bytes) -> Option<u32> {
-        if bytes.len() < 1024 {
-            None
-        } else {
-            let mut hasher = Crc32Hasher::new();
-            hasher.update(bytes);
-            Some(hasher.finalize())
-        }
+    if bytes.len() < 1024 {
+        None
+    } else {
+        let mut hasher = Crc32Hasher::new();
+        hasher.update(bytes);
+        Some(hasher.finalize())
+    }
 }
 
 pub struct Ref<'a> {