|
@@ -2,6 +2,48 @@ use crate::{connection::Connection, error::Error, value::checksum, value::Value}
|
|
|
use bytes::Bytes;
|
|
|
use std::collections::HashSet;
|
|
|
|
|
|
+async fn compare_sets<F1>(conn: &Connection, keys: &[Bytes], op: F1) -> Result<Value, Error>
|
|
|
+where
|
|
|
+ F1: Fn(&mut HashSet<checksum::Value>, &HashSet<checksum::Value>) -> bool,
|
|
|
+{
|
|
|
+ conn.db().get_map_or(
|
|
|
+ &keys[0],
|
|
|
+ |v| match v {
|
|
|
+ Value::Set(x) => {
|
|
|
+ #[allow(clippy::mutable_key_type)]
|
|
|
+ let mut all_entries = x.read().clone();
|
|
|
+ for key in keys[1..].iter() {
|
|
|
+ let mut do_break = false;
|
|
|
+ let _ = conn.db().get_map_or(
|
|
|
+ key,
|
|
|
+ |v| match v {
|
|
|
+ Value::Set(x) => {
|
|
|
+ if !op(&mut all_entries, &x.read()) {
|
|
|
+ do_break = true;
|
|
|
+ }
|
|
|
+ Ok(Value::Null)
|
|
|
+ }
|
|
|
+ _ => Err(Error::WrongType),
|
|
|
+ },
|
|
|
+ || Ok(Value::Null),
|
|
|
+ )?;
|
|
|
+ if do_break {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Ok(all_entries
|
|
|
+ .iter()
|
|
|
+ .map(|entry| entry.clone_value())
|
|
|
+ .collect::<Vec<Value>>()
|
|
|
+ .into())
|
|
|
+ }
|
|
|
+ _ => Err(Error::WrongType),
|
|
|
+ },
|
|
|
+ || Ok(Value::Array(vec![])),
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
pub async fn sadd(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
conn.db().get_map_or(
|
|
|
&args[1],
|
|
@@ -50,15 +92,91 @@ pub async fn scard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
)
|
|
|
}
|
|
|
|
|
|
+pub async fn sdiff(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ compare_sets(conn, &args[1..], |all_entries, elements| {
|
|
|
+ for element in elements.iter() {
|
|
|
+ if all_entries.contains(element) {
|
|
|
+ all_entries.remove(element);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ true
|
|
|
+ })
|
|
|
+ .await
|
|
|
+}
|
|
|
+
|
|
|
+pub async fn sdiffstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ if let Value::Array(values) = sdiff(conn, &args[1..]).await? {
|
|
|
+ #[allow(clippy::mutable_key_type)]
|
|
|
+ let mut x = HashSet::new();
|
|
|
+ let mut len = 0;
|
|
|
+
|
|
|
+ for val in values.iter() {
|
|
|
+ if let Value::Blob(blob) = val {
|
|
|
+ if x.insert(checksum::Value::new(blob.clone())) {
|
|
|
+ len += 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ conn.db().set(&args[1], x.into(), None);
|
|
|
+
|
|
|
+ Ok(len.into())
|
|
|
+ } else {
|
|
|
+ Ok(0.into())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub async fn sinter(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ compare_sets(conn, &args[1..], |all_entries, elements| {
|
|
|
+ all_entries.retain(|element| elements.contains(element));
|
|
|
+
|
|
|
+ for element in elements.iter() {
|
|
|
+ if !all_entries.contains(element) {
|
|
|
+ all_entries.remove(element);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ !all_entries.is_empty()
|
|
|
+ })
|
|
|
+ .await
|
|
|
+}
|
|
|
+
|
|
|
+pub async fn sintercard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ if let Ok(Value::Array(x)) = sinter(conn, args).await {
|
|
|
+ Ok((x.len() as i64).into())
|
|
|
+ } else {
|
|
|
+ Ok(0.into())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+pub async fn sinterstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
+ if let Value::Array(values) = sinter(conn, &args[1..]).await? {
|
|
|
+ #[allow(clippy::mutable_key_type)]
|
|
|
+ let mut x = HashSet::new();
|
|
|
+ let mut len = 0;
|
|
|
+
|
|
|
+ for val in values.iter() {
|
|
|
+ if let Value::Blob(blob) = val {
|
|
|
+ if x.insert(checksum::Value::new(blob.clone())) {
|
|
|
+ len += 1;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ conn.db().set(&args[1], x.into(), None);
|
|
|
+
|
|
|
+ Ok(len.into())
|
|
|
+ } else {
|
|
|
+ Ok(0.into())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
pub async fn sismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
|
|
|
conn.db().get_map_or(
|
|
|
&args[1],
|
|
|
|v| match v {
|
|
|
Value::Set(x) => {
|
|
|
- if x.read()
|
|
|
- .get(&checksum::Value::new(args[2].clone()))
|
|
|
- .is_some()
|
|
|
- {
|
|
|
+ if x.read().contains(&checksum::Value::new(args[2].clone())) {
|
|
|
Ok(1.into())
|
|
|
} else {
|
|
|
Ok(0.into())
|
|
@@ -132,6 +250,155 @@ mod test {
|
|
|
}
|
|
|
|
|
|
#[tokio::test]
|
|
|
+ async fn sdiff() {
|
|
|
+ let c = create_connection();
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
|
|
|
+ run_command(&c, &["scard", "1"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "2", "c"]).await,
|
|
|
+ run_command(&c, &["scard", "2"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
|
|
|
+ run_command(&c, &["scard", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ match run_command(&c, &["sdiff", "1", "2", "3"]).await {
|
|
|
+ Ok(Value::Array(v)) => {
|
|
|
+ assert_eq!(2, v.len());
|
|
|
+ if v[0] == Value::Blob("b".into()) {
|
|
|
+ assert_eq!(v[1], Value::Blob("d".into()));
|
|
|
+ } else {
|
|
|
+ assert_eq!(v[1], Value::Blob("b".into()));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _ => unreachable!(),
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn sdiffstore() {
|
|
|
+ let c = create_connection();
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
|
|
|
+ run_command(&c, &["scard", "1"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "2", "c"]).await,
|
|
|
+ run_command(&c, &["scard", "2"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
|
|
|
+ run_command(&c, &["scard", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ Ok(Value::Integer(2)),
|
|
|
+ run_command(&c, &["sdiffstore", "4", "1", "2", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ match run_command(&c, &["smembers", "4"]).await {
|
|
|
+ Ok(Value::Array(v)) => {
|
|
|
+ assert_eq!(2, v.len());
|
|
|
+ if v[0] == Value::Blob("b".into()) {
|
|
|
+ assert_eq!(v[1], Value::Blob("d".into()));
|
|
|
+ } else {
|
|
|
+ assert_eq!(v[1], Value::Blob("b".into()));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ _ => unreachable!(),
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn sinter() {
|
|
|
+ let c = create_connection();
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
|
|
|
+ run_command(&c, &["scard", "1"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "2", "c", "x"]).await,
|
|
|
+ run_command(&c, &["scard", "2"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
|
|
|
+ run_command(&c, &["scard", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ Ok(Value::Array(vec![Value::Blob("c".into())])),
|
|
|
+ run_command(&c, &["sinter", "1", "2", "3"]).await
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn sintercard() {
|
|
|
+ let c = create_connection();
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
|
|
|
+ run_command(&c, &["scard", "1"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "2", "c", "x"]).await,
|
|
|
+ run_command(&c, &["scard", "2"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
|
|
|
+ run_command(&c, &["scard", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ Ok(Value::Integer(1)),
|
|
|
+ run_command(&c, &["sintercard", "1", "2", "3"]).await
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn sinterstore() {
|
|
|
+ let c = create_connection();
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
|
|
|
+ run_command(&c, &["scard", "1"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "2", "c", "x"]).await,
|
|
|
+ run_command(&c, &["scard", "2"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
|
|
|
+ run_command(&c, &["scard", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ Ok(Value::Integer(1)),
|
|
|
+ run_command(&c, &["sinterstore", "foo", "1", "2", "3"]).await
|
|
|
+ );
|
|
|
+
|
|
|
+ assert_eq!(
|
|
|
+ Ok(Value::Array(vec![Value::Blob("c".into())])),
|
|
|
+ run_command(&c, &["smembers", "foo"]).await
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
async fn sismember() {
|
|
|
let c = create_connection();
|
|
|
|