Selaa lähdekoodia

Added spop and srandmember

Cesar Rodas 3 vuotta sitten
vanhempi
säilyke
fe8051c980
2 muutettua tiedostoa jossa 102 lisäystä ja 4 poistoa
  1. 92 4
      src/cmd/set.rs
  2. 10 0
      src/dispatcher.rs

+ 92 - 4
src/cmd/set.rs

@@ -1,8 +1,9 @@
-use crate::{connection::Connection, error::Error, value::Value};
+use crate::{connection::Connection, error::Error, value::bytes_to_number, value::Value};
 use bytes::Bytes;
-use std::collections::HashSet;
+use rand::Rng;
+use std::{cmp::min, collections::HashSet};
 
-fn store(conn: &Connection, key: &Bytes, values: &Vec<Value>) -> i64 {
+fn store(conn: &Connection, key: &Bytes, values: &[Value]) -> i64 {
     #[allow(clippy::mutable_key_type)]
     let mut x = HashSet::new();
     let mut len = 0;
@@ -14,7 +15,7 @@ fn store(conn: &Connection, key: &Bytes, values: &Vec<Value>) -> i64 {
             }
         }
     }
-    conn.db().set(&key, x.into(), None);
+    conn.db().set(key, x.into(), None);
     len
 }
 
@@ -210,6 +211,67 @@ pub async fn smismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Erro
     )
 }
 
+pub async fn spop(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    let rand = srandmember(conn, args).await?;
+    conn.db().get_map_or(
+        &args[1],
+        |v| match v {
+            Value::Set(x) => {
+                let mut x = x.write();
+                match &rand {
+                    Value::Blob(value) => {
+                        x.remove(value);
+                    }
+                    Value::Array(values) => {
+                        for value in values.iter() {
+                            if let Value::Blob(value) = value {
+                                x.remove(value);
+                            }
+                        }
+                    }
+                    _ => unreachable!(),
+                };
+                Ok(rand)
+            }
+            _ => Err(Error::WrongType),
+        },
+        || Ok(0.into()),
+    )
+}
+
+pub async fn srandmember(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
+    conn.db().get_map_or(
+        &args[1],
+        |v| match v {
+            Value::Set(x) => {
+                let mut rng = rand::thread_rng();
+                let set = x.read();
+
+                let mut items = set
+                    .iter()
+                    .map(|x| (x, rng.gen()))
+                    .collect::<Vec<(&Bytes, i128)>>();
+
+                items.sort_by(|a, b| a.1.cmp(&b.1));
+
+                if args.len() == 2 {
+                    let item = items[0].0.clone();
+                    Ok(Value::Blob(item))
+                } else {
+                    let len: usize = min(items.len(), bytes_to_number(&args[2])?);
+                    Ok(items[0..len]
+                        .iter()
+                        .map(|item| Value::Blob(item.0.clone()))
+                        .collect::<Vec<Value>>()
+                        .into())
+                }
+            }
+            _ => Err(Error::WrongType),
+        },
+        || Ok(0.into()),
+    )
+}
+
 pub async fn sunion(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
     compare_sets(conn, &args[1..], |all_entries, elements| {
         for element in elements.iter() {
@@ -468,6 +530,32 @@ mod test {
     }
 
     #[tokio::test]
+    async fn spop() {
+        let c = create_connection();
+
+        assert_eq!(
+            run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        let _ = run_command(&c, &["spop", "1"]).await;
+
+        assert_eq!(
+            Ok(Value::Integer(3)),
+            run_command(&c, &["scard", "1"]).await
+        );
+
+        if let Ok(Value::Array(x)) = run_command(&c, &["spop", "1", "2"]).await {
+            assert_eq!(2, x.len());
+        }
+
+        assert_eq!(
+            Ok(Value::Integer(1)),
+            run_command(&c, &["scard", "1"]).await
+        );
+    }
+
+    #[tokio::test]
     async fn sunion() {
         let c = create_connection();
 

+ 10 - 0
src/dispatcher.rs

@@ -73,6 +73,16 @@ dispatcher! {
             [""],
             -3,
         },
+        spop {
+            cmd::set::spop,
+            [""],
+            -2,
+        },
+        srandmember {
+            cmd::set::srandmember,
+            [""],
+            -2,
+        },
         sunion {
             cmd::set::sunion,
             [""],