pubsub_server.rs 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. use crate::{connection::Connection, error::Error, value::Value};
  2. use bytes::Bytes;
  3. use glob::Pattern;
  4. use parking_lot::RwLock;
  5. use std::collections::HashMap;
  6. use tokio::sync::mpsc;
  7. type Sender = mpsc::UnboundedSender<Value>;
  8. type Subscription = HashMap<u128, Sender>;
  9. #[derive(Debug)]
  10. pub struct Pubsub {
  11. subscriptions: RwLock<HashMap<Bytes, Subscription>>,
  12. psubscriptions: RwLock<HashMap<Pattern, Subscription>>,
  13. number_of_psubscriptions: RwLock<i64>,
  14. }
  15. impl Pubsub {
  16. pub fn new() -> Self {
  17. Self {
  18. subscriptions: RwLock::new(HashMap::new()),
  19. psubscriptions: RwLock::new(HashMap::new()),
  20. number_of_psubscriptions: RwLock::new(0),
  21. }
  22. }
  23. pub fn channels(&self) -> Vec<Bytes> {
  24. self.subscriptions.read().keys().cloned().collect()
  25. }
  26. pub fn get_number_of_psubscribers(&self) -> i64 {
  27. *(self.number_of_psubscriptions.read())
  28. }
  29. pub fn get_number_of_subscribers(&self, channels: &[Bytes]) -> Vec<(Bytes, usize)> {
  30. let subscribers = self.subscriptions.read();
  31. let mut ret = vec![];
  32. for channel in channels.iter() {
  33. if let Some(subs) = subscribers.get(channel) {
  34. ret.push((channel.clone(), subs.len()));
  35. } else {
  36. ret.push((channel.clone(), 0));
  37. }
  38. }
  39. ret
  40. }
  41. pub fn psubscribe(&self, channels: &[Bytes], conn: &Connection) -> Result<(), Error> {
  42. let mut subscriptions = self.psubscriptions.write();
  43. for bytes_channel in channels.iter() {
  44. let channel = String::from_utf8_lossy(bytes_channel);
  45. let channel =
  46. Pattern::new(&channel).map_err(|_| Error::InvalidPattern(channel.to_string()))?;
  47. if let Some(subs) = subscriptions.get_mut(&channel) {
  48. subs.insert(conn.id(), conn.pubsub_client().sender());
  49. } else {
  50. let mut h = HashMap::new();
  51. h.insert(conn.id(), conn.pubsub_client().sender());
  52. subscriptions.insert(channel.clone(), h);
  53. }
  54. if !conn.pubsub_client().is_psubcribed() {
  55. let mut psubs = self.number_of_psubscriptions.write();
  56. conn.pubsub_client().make_psubcribed();
  57. *psubs += 1;
  58. }
  59. let _ = conn.pubsub_client().sender().send(
  60. vec![
  61. "psubscribe".into(),
  62. Value::Blob(bytes_channel.clone()),
  63. conn.pubsub_client().new_psubscription(&channel).into(),
  64. ]
  65. .into(),
  66. );
  67. }
  68. Ok(())
  69. }
  70. pub async fn publish(&self, channel: &Bytes, message: &Bytes) -> u32 {
  71. let mut i = 0;
  72. if let Some(subs) = self.subscriptions.read().get(channel) {
  73. for sender in subs.values() {
  74. let _ = sender.send(Value::Array(vec![
  75. "message".into(),
  76. Value::Blob(channel.clone()),
  77. Value::Blob(message.clone()),
  78. ]));
  79. i += 1;
  80. }
  81. }
  82. let str_channel = String::from_utf8_lossy(channel);
  83. for (pattern, subs) in self.psubscriptions.read().iter() {
  84. if !pattern.matches(&str_channel) {
  85. continue;
  86. }
  87. for sub in subs.values() {
  88. let _ = sub.send(Value::Array(vec![
  89. "pmessage".into(),
  90. pattern.as_str().into(),
  91. Value::Blob(channel.clone()),
  92. Value::Blob(message.clone()),
  93. ]));
  94. i += 1;
  95. }
  96. }
  97. i
  98. }
  99. pub fn punsubscribe(&self, channels: &[Pattern], conn: &Connection) -> u32 {
  100. let mut all_subs = self.psubscriptions.write();
  101. let conn_id = conn.id();
  102. let mut removed = 0;
  103. channels
  104. .iter()
  105. .map(|channel| {
  106. if let Some(subs) = all_subs.get_mut(channel) {
  107. if let Some(sender) = subs.remove(&conn_id) {
  108. let _ = sender.send(Value::Array(vec![
  109. "punsubscribe".into(),
  110. channel.as_str().into(),
  111. 1.into(),
  112. ]));
  113. removed += 1;
  114. }
  115. if subs.is_empty() {
  116. all_subs.remove(channel);
  117. }
  118. }
  119. })
  120. .for_each(drop);
  121. removed
  122. }
  123. pub fn subscribe(&self, channels: &[Bytes], conn: &Connection) {
  124. let mut subscriptions = self.subscriptions.write();
  125. channels
  126. .iter()
  127. .map(|channel| {
  128. if let Some(subs) = subscriptions.get_mut(channel) {
  129. subs.insert(conn.id(), conn.pubsub_client().sender());
  130. } else {
  131. let mut h = HashMap::new();
  132. h.insert(conn.id(), conn.pubsub_client().sender());
  133. subscriptions.insert(channel.clone(), h);
  134. }
  135. let _ = conn.pubsub_client().sender().send(
  136. vec![
  137. "subscribe".into(),
  138. Value::Blob(channel.clone()),
  139. conn.pubsub_client().new_subscription(channel).into(),
  140. ]
  141. .into(),
  142. );
  143. })
  144. .for_each(drop);
  145. }
  146. pub fn unsubscribe(&self, channels: &[Bytes], conn: &Connection) -> u32 {
  147. let mut all_subs = self.subscriptions.write();
  148. let conn_id = conn.id();
  149. let mut removed = 0;
  150. channels
  151. .iter()
  152. .map(|channel| {
  153. if let Some(subs) = all_subs.get_mut(channel) {
  154. if let Some(sender) = subs.remove(&conn_id) {
  155. let _ = sender.send(Value::Array(vec![
  156. "unsubscribe".into(),
  157. Value::Blob(channel.clone()),
  158. 1.into(),
  159. ]));
  160. removed += 1;
  161. }
  162. if subs.is_empty() {
  163. all_subs.remove(channel);
  164. }
  165. }
  166. })
  167. .for_each(drop);
  168. removed
  169. }
  170. }