|
@@ -2,6 +2,22 @@ use crate::{connection::Connection, error::Error, value::Value};
|
|
|
use bytes::Bytes;
|
|
|
use std::collections::HashSet;
|
|
|
|
|
|
+fn store(conn: &Connection, key: &Bytes, values: &Vec<Value>) -> i64 {
|
|
|
+ #[allow(clippy::mutable_key_type)]
|
|
|
+ let mut x = HashSet::new();
|
|
|
+ let mut len = 0;
|
|
|
+
|
|
|
+ for val in values.iter() {
|
|
|
+ if let Value::Blob(blob) = val {
|
|
|
+ if x.insert(blob.clone()) {
|
|
|
+ len += 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ conn.db().set(&key, x.into(), None);
|
|
|
+ len
|
|
|
+}
|
|
|
+
|
|
|
async fn compare_sets<F1>(conn: &Connection, keys: &[Bytes], op: F1) -> Result<Value, Error>
|
|
|
where
|
|
|
F1: Fn(&mut HashSet<Bytes>, &HashSet<Bytes>) -> bool,
|
|
@@ -106,21 +122,7 @@ pub async fn sdiff(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
|
|
|
pub async fn sdiffstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
if let Value::Array(values) = sdiff(conn, &args[1..]).await? {
|
|
|
- #[allow(clippy::mutable_key_type)]
|
|
|
- let mut x = HashSet::new();
|
|
|
- let mut len = 0;
|
|
|
-
|
|
|
- for val in values.iter() {
|
|
|
- if let Value::Blob(blob) = val {
|
|
|
- if x.insert(blob.clone()) {
|
|
|
- len += 1;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- conn.db().set(&args[1], x.into(), None);
|
|
|
-
|
|
|
- Ok(len.into())
|
|
|
+ Ok(store(conn, &args[1], &values).into())
|
|
|
} else {
|
|
|
Ok(0.into())
|
|
|
}
|
|
@@ -151,21 +153,7 @@ pub async fn sintercard(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
|
|
|
|
|
|
pub async fn sinterstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
if let Value::Array(values) = sinter(conn, &args[1..]).await? {
|
|
|
- #[allow(clippy::mutable_key_type)]
|
|
|
- let mut x = HashSet::new();
|
|
|
- let mut len = 0;
|
|
|
-
|
|
|
- for val in values.iter() {
|
|
|
- if let Value::Blob(blob) = val {
|
|
|
- if x.insert(blob.clone()) {
|
|
|
- len += 1;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- conn.db().set(&args[1], x.into(), None);
|
|
|
-
|
|
|
- Ok(len.into())
|
|
|
+ Ok(store(conn, &args[1], &values).into())
|
|
|
} else {
|
|
|
Ok(0.into())
|
|
|
}
|
|
@@ -212,13 +200,7 @@ pub async fn smismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
|
|
|
let x = x.read();
|
|
|
Ok((&args[2..])
|
|
|
.iter()
|
|
|
- .map(|member| {
|
|
|
- if x.contains(member) {
|
|
|
- 1
|
|
|
- } else {
|
|
|
- 0
|
|
|
- }
|
|
|
- })
|
|
|
+ .map(|member| if x.contains(member) { 1 } else { 0 })
|
|
|
.collect::<Vec<i32>>()
|
|
|
.into())
|
|
|
}
|
|
@@ -228,6 +210,25 @@ pub async fn smismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
|
|
|
)
|
|
|
}
|
|
|
|
|
|
+pub async fn sunion(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ compare_sets(conn, &args[1..], |all_entries, elements| {
|
|
|
+ for element in elements.iter() {
|
|
|
+ all_entries.insert(element.clone());
|
|
|
+ }
|
|
|
+
|
|
|
+ true
|
|
|
+ })
|
|
|
+ .await
|
|
|
+}
|
|
|
+
|
|
|
+pub async fn sunionstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ if let Value::Array(values) = sunion(conn, &args[1..]).await? {
|
|
|
+ Ok(store(conn, &args[1], &values).into())
|
|
|
+ } else {
|
|
|
+ Ok(0.into())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
use crate::{
|
|
@@ -465,4 +466,33 @@ mod test {
|
|
|
run_command(&c, &["smismember", "foo", "5", "6", "3"]).await
|
|
|
);
|
|
|
}
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn sunion() {
|
|
|
+ let c = create_connection();
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
|
|
|
+ run_command(&c, &["scard", "1"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "2", "c", "x"]).await,
|
|
|
+ run_command(&c, &["scard", "2"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
|
|
|
+ run_command(&c, &["scard", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ 6,
|
|
|
+ if let Ok(Value::Array(x)) = run_command(&c, &["sunion", "1", "2", "3"]).await {
|
|
|
+ x.len()
|
|
|
+ } else {
|
|
|
+ 0
|
|
|
+ }
|
|
|
+ );
|
|
|
+ }
|
|
|
}
|