set.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. use crate::{connection::Connection, error::Error, value::Value};
  2. use bytes::Bytes;
  3. use std::collections::HashSet;
  4. async fn compare_sets<F1>(conn: &Connection, keys: &[Bytes], op: F1) -> Result<Value, Error>
  5. where
  6. F1: Fn(&mut HashSet<Bytes>, &HashSet<Bytes>) -> bool,
  7. {
  8. conn.db().get_map_or(
  9. &keys[0],
  10. |v| match v {
  11. Value::Set(x) => {
  12. #[allow(clippy::mutable_key_type)]
  13. let mut all_entries = x.read().clone();
  14. for key in keys[1..].iter() {
  15. let mut do_break = false;
  16. let _ = conn.db().get_map_or(
  17. key,
  18. |v| match v {
  19. Value::Set(x) => {
  20. if !op(&mut all_entries, &x.read()) {
  21. do_break = true;
  22. }
  23. Ok(Value::Null)
  24. }
  25. _ => Err(Error::WrongType),
  26. },
  27. || Ok(Value::Null),
  28. )?;
  29. if do_break {
  30. break;
  31. }
  32. }
  33. Ok(all_entries
  34. .iter()
  35. .map(|entry| Value::Blob(entry.clone()))
  36. .collect::<Vec<Value>>()
  37. .into())
  38. }
  39. _ => Err(Error::WrongType),
  40. },
  41. || Ok(Value::Array(vec![])),
  42. )
  43. }
  44. pub async fn sadd(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  45. conn.db().get_map_or(
  46. &args[1],
  47. |v| match v {
  48. Value::Set(x) => {
  49. let mut x = x.write();
  50. let mut len = 0;
  51. for val in (&args[2..]).iter() {
  52. if x.insert(val.clone()) {
  53. len += 1;
  54. }
  55. }
  56. Ok(len.into())
  57. }
  58. _ => Err(Error::WrongType),
  59. },
  60. || {
  61. #[allow(clippy::mutable_key_type)]
  62. let mut x = HashSet::new();
  63. let mut len = 0;
  64. for val in (&args[2..]).iter() {
  65. if x.insert(val.clone()) {
  66. len += 1;
  67. }
  68. }
  69. conn.db().set(&args[1], x.into(), None);
  70. Ok(len.into())
  71. },
  72. )
  73. }
  74. pub async fn scard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  75. conn.db().get_map_or(
  76. &args[1],
  77. |v| match v {
  78. Value::Set(x) => Ok((x.read().len() as i64).into()),
  79. _ => Err(Error::WrongType),
  80. },
  81. || Ok(0.into()),
  82. )
  83. }
  84. pub async fn sdiff(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  85. compare_sets(conn, &args[1..], |all_entries, elements| {
  86. for element in elements.iter() {
  87. if all_entries.contains(element) {
  88. all_entries.remove(element);
  89. }
  90. }
  91. true
  92. })
  93. .await
  94. }
  95. pub async fn sdiffstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  96. if let Value::Array(values) = sdiff(conn, &args[1..]).await? {
  97. #[allow(clippy::mutable_key_type)]
  98. let mut x = HashSet::new();
  99. let mut len = 0;
  100. for val in values.iter() {
  101. if let Value::Blob(blob) = val {
  102. if x.insert(blob.clone()) {
  103. len += 1;
  104. }
  105. }
  106. }
  107. conn.db().set(&args[1], x.into(), None);
  108. Ok(len.into())
  109. } else {
  110. Ok(0.into())
  111. }
  112. }
  113. pub async fn sinter(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  114. compare_sets(conn, &args[1..], |all_entries, elements| {
  115. all_entries.retain(|element| elements.contains(element));
  116. for element in elements.iter() {
  117. if !all_entries.contains(element) {
  118. all_entries.remove(element);
  119. }
  120. }
  121. !all_entries.is_empty()
  122. })
  123. .await
  124. }
  125. pub async fn sintercard(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  126. if let Ok(Value::Array(x)) = sinter(conn, args).await {
  127. Ok((x.len() as i64).into())
  128. } else {
  129. Ok(0.into())
  130. }
  131. }
  132. pub async fn sinterstore(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  133. if let Value::Array(values) = sinter(conn, &args[1..]).await? {
  134. #[allow(clippy::mutable_key_type)]
  135. let mut x = HashSet::new();
  136. let mut len = 0;
  137. for val in values.iter() {
  138. if let Value::Blob(blob) = val {
  139. if x.insert(blob.clone()) {
  140. len += 1;
  141. }
  142. }
  143. }
  144. conn.db().set(&args[1], x.into(), None);
  145. Ok(len.into())
  146. } else {
  147. Ok(0.into())
  148. }
  149. }
  150. pub async fn sismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  151. conn.db().get_map_or(
  152. &args[1],
  153. |v| match v {
  154. Value::Set(x) => {
  155. if x.read().contains(&args[2]) {
  156. Ok(1.into())
  157. } else {
  158. Ok(0.into())
  159. }
  160. }
  161. _ => Err(Error::WrongType),
  162. },
  163. || Ok(0.into()),
  164. )
  165. }
  166. pub async fn smembers(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  167. conn.db().get_map_or(
  168. &args[1],
  169. |v| match v {
  170. Value::Set(x) => Ok(x
  171. .read()
  172. .iter()
  173. .map(|x| Value::Blob(x.clone()))
  174. .collect::<Vec<Value>>()
  175. .into()),
  176. _ => Err(Error::WrongType),
  177. },
  178. || Ok(Value::Array(vec![])),
  179. )
  180. }
  181. pub async fn smismember(conn: &Connection, args: &[Bytes]) -> Result<Value, Error> {
  182. conn.db().get_map_or(
  183. &args[1],
  184. |v| match v {
  185. Value::Set(x) => {
  186. let x = x.read();
  187. Ok((&args[2..])
  188. .iter()
  189. .map(|member| {
  190. if x.contains(member) {
  191. 1
  192. } else {
  193. 0
  194. }
  195. })
  196. .collect::<Vec<i32>>()
  197. .into())
  198. }
  199. _ => Err(Error::WrongType),
  200. },
  201. || Ok(0.into()),
  202. )
  203. }
  204. #[cfg(test)]
  205. mod test {
  206. use crate::{
  207. cmd::test::{create_connection, run_command},
  208. error::Error,
  209. value::Value,
  210. };
  211. #[tokio::test]
  212. async fn test_set_wrong_type() {
  213. let c = create_connection();
  214. let _ = run_command(&c, &["set", "foo", "1"]).await;
  215. assert_eq!(
  216. Err(Error::WrongType),
  217. run_command(&c, &["sadd", "foo", "1", "2", "3", "4", "5", "5"]).await,
  218. );
  219. }
  220. #[tokio::test]
  221. async fn sadd() {
  222. let c = create_connection();
  223. assert_eq!(
  224. Ok(Value::Integer(5)),
  225. run_command(&c, &["sadd", "foo", "1", "2", "3", "4", "5", "5"]).await,
  226. );
  227. assert_eq!(
  228. Ok(Value::Integer(1)),
  229. run_command(&c, &["sadd", "foo", "1", "2", "3", "4", "5", "6"]).await,
  230. );
  231. }
  232. #[tokio::test]
  233. async fn scard() {
  234. let c = create_connection();
  235. assert_eq!(
  236. run_command(&c, &["sadd", "foo", "1", "2", "3", "4", "5", "5"]).await,
  237. run_command(&c, &["scard", "foo"]).await
  238. );
  239. }
  240. #[tokio::test]
  241. async fn sdiff() {
  242. let c = create_connection();
  243. assert_eq!(
  244. run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
  245. run_command(&c, &["scard", "1"]).await
  246. );
  247. assert_eq!(
  248. run_command(&c, &["sadd", "2", "c"]).await,
  249. run_command(&c, &["scard", "2"]).await
  250. );
  251. assert_eq!(
  252. run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
  253. run_command(&c, &["scard", "3"]).await
  254. );
  255. match run_command(&c, &["sdiff", "1", "2", "3"]).await {
  256. Ok(Value::Array(v)) => {
  257. assert_eq!(2, v.len());
  258. if v[0] == Value::Blob("b".into()) {
  259. assert_eq!(v[1], Value::Blob("d".into()));
  260. } else {
  261. assert_eq!(v[1], Value::Blob("b".into()));
  262. }
  263. }
  264. _ => unreachable!(),
  265. };
  266. }
  267. #[tokio::test]
  268. async fn sdiffstore() {
  269. let c = create_connection();
  270. assert_eq!(
  271. run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
  272. run_command(&c, &["scard", "1"]).await
  273. );
  274. assert_eq!(
  275. run_command(&c, &["sadd", "2", "c"]).await,
  276. run_command(&c, &["scard", "2"]).await
  277. );
  278. assert_eq!(
  279. run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
  280. run_command(&c, &["scard", "3"]).await
  281. );
  282. assert_eq!(
  283. Ok(Value::Integer(2)),
  284. run_command(&c, &["sdiffstore", "4", "1", "2", "3"]).await
  285. );
  286. match run_command(&c, &["smembers", "4"]).await {
  287. Ok(Value::Array(v)) => {
  288. assert_eq!(2, v.len());
  289. if v[0] == Value::Blob("b".into()) {
  290. assert_eq!(v[1], Value::Blob("d".into()));
  291. } else {
  292. assert_eq!(v[1], Value::Blob("b".into()));
  293. }
  294. }
  295. _ => unreachable!(),
  296. };
  297. }
  298. #[tokio::test]
  299. async fn sinter() {
  300. let c = create_connection();
  301. assert_eq!(
  302. run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
  303. run_command(&c, &["scard", "1"]).await
  304. );
  305. assert_eq!(
  306. run_command(&c, &["sadd", "2", "c", "x"]).await,
  307. run_command(&c, &["scard", "2"]).await
  308. );
  309. assert_eq!(
  310. run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
  311. run_command(&c, &["scard", "3"]).await
  312. );
  313. assert_eq!(
  314. Ok(Value::Array(vec![Value::Blob("c".into())])),
  315. run_command(&c, &["sinter", "1", "2", "3"]).await
  316. );
  317. }
  318. #[tokio::test]
  319. async fn sintercard() {
  320. let c = create_connection();
  321. assert_eq!(
  322. run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
  323. run_command(&c, &["scard", "1"]).await
  324. );
  325. assert_eq!(
  326. run_command(&c, &["sadd", "2", "c", "x"]).await,
  327. run_command(&c, &["scard", "2"]).await
  328. );
  329. assert_eq!(
  330. run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
  331. run_command(&c, &["scard", "3"]).await
  332. );
  333. assert_eq!(
  334. Ok(Value::Integer(1)),
  335. run_command(&c, &["sintercard", "1", "2", "3"]).await
  336. );
  337. }
  338. #[tokio::test]
  339. async fn sinterstore() {
  340. let c = create_connection();
  341. assert_eq!(
  342. run_command(&c, &["sadd", "1", "a", "b", "c", "d"]).await,
  343. run_command(&c, &["scard", "1"]).await
  344. );
  345. assert_eq!(
  346. run_command(&c, &["sadd", "2", "c", "x"]).await,
  347. run_command(&c, &["scard", "2"]).await
  348. );
  349. assert_eq!(
  350. run_command(&c, &["sadd", "3", "a", "c", "e"]).await,
  351. run_command(&c, &["scard", "3"]).await
  352. );
  353. assert_eq!(
  354. Ok(Value::Integer(1)),
  355. run_command(&c, &["sinterstore", "foo", "1", "2", "3"]).await
  356. );
  357. assert_eq!(
  358. Ok(Value::Array(vec![Value::Blob("c".into())])),
  359. run_command(&c, &["smembers", "foo"]).await
  360. );
  361. }
  362. #[tokio::test]
  363. async fn sismember() {
  364. let c = create_connection();
  365. assert_eq!(
  366. run_command(&c, &["sadd", "foo", "1", "2", "3", "4", "5", "5"]).await,
  367. run_command(&c, &["scard", "foo"]).await
  368. );
  369. assert_eq!(
  370. Ok(Value::Integer(1)),
  371. run_command(&c, &["sismember", "foo", "5"]).await
  372. );
  373. assert_eq!(
  374. Ok(Value::Integer(0)),
  375. run_command(&c, &["sismember", "foo", "6"]).await
  376. );
  377. assert_eq!(
  378. Ok(Value::Integer(0)),
  379. run_command(&c, &["sismember", "foobar", "5"]).await
  380. );
  381. }
  382. #[tokio::test]
  383. async fn smismember() {
  384. let c = create_connection();
  385. assert_eq!(
  386. run_command(&c, &["sadd", "foo", "1", "2", "3", "4", "5", "5"]).await,
  387. run_command(&c, &["scard", "foo"]).await
  388. );
  389. assert_eq!(
  390. Ok(Value::Array(vec![
  391. Value::Integer(1),
  392. Value::Integer(0),
  393. Value::Integer(1),
  394. ])),
  395. run_command(&c, &["smismember", "foo", "5", "6", "3"]).await
  396. );
  397. }
  398. }