connection.rs 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. use crate::{get_id, Error};
  2. use futures_util::{SinkExt, StreamExt};
  3. use nostr_rs_types::{
  4. relayer::{Auth, ROk},
  5. types::Addr,
  6. Request, Response,
  7. };
  8. use parking_lot::RwLock;
  9. use std::collections::HashMap;
  10. use tokio::{
  11. net::TcpStream,
  12. sync::mpsc::{channel, Receiver, Sender},
  13. };
  14. #[allow(unused_imports)]
  15. use tokio_tungstenite::{accept_async, tungstenite::Message, WebSocketStream};
  16. #[derive(Debug)]
  17. pub struct Connection {
  18. #[allow(unused)]
  19. pub(crate) conn_id: u128,
  20. sender: Sender<Response>,
  21. subscriptions: RwLock<HashMap<String, u128>>,
  22. }
  /// Capacity of the bounded channel created for each connection's outgoing
  /// responses (the value handed to `channel(..)`).
  const MAX_SUBSCRIPTIONS_BUFFER: usize = 100;
  24. impl Connection {
  25. #[cfg(test)]
  26. pub fn new_for_test() -> (Self, Receiver<Response>) {
  27. let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
  28. (
  29. Self {
  30. conn_id: 0,
  31. sender,
  32. subscriptions: RwLock::new(HashMap::new()),
  33. },
  34. receiver,
  35. )
  36. }
  37. pub async fn new(
  38. broadcast_request: Sender<(u128, Request)>,
  39. disconnection_notify: Option<Sender<u128>>,
  40. stream: TcpStream,
  41. ) -> Result<Self, Error> {
  42. let websocket = accept_async(stream).await?;
  43. let conn_id = get_id();
  44. let (sender, receiver) = channel(MAX_SUBSCRIPTIONS_BUFFER);
  45. Self::spawn(
  46. broadcast_request,
  47. websocket,
  48. receiver,
  49. disconnection_notify,
  50. conn_id,
  51. );
  52. let _ = sender.send(Auth::default().into()).await;
  53. Ok(Self {
  54. conn_id,
  55. sender,
  56. subscriptions: RwLock::new(HashMap::new()),
  57. })
  58. }
  59. #[allow(unused)]
  60. fn spawn(
  61. broadcast_request: Sender<(u128, Request)>,
  62. websocket: WebSocketStream<TcpStream>,
  63. mut receiver: Receiver<Response>,
  64. disconnection_notify: Option<Sender<u128>>,
  65. conn_id: u128,
  66. ) {
  67. tokio::spawn(async move {
  68. let mut _subscriptions: HashMap<String, (u128, Receiver<Response>)> = HashMap::new();
  69. let (mut writer, mut reader) = websocket.split();
  70. loop {
  71. tokio::select! {
  72. Some(msg) = receiver.recv() => {
  73. let msg = if let Ok(msg) = serde_json::to_string(&msg) {
  74. msg
  75. } else {
  76. continue;
  77. };
  78. if let Err(err) = writer.send(Message::Text(msg)).await {
  79. log::error!("Error sending message to client: {}", err);
  80. break;
  81. }
  82. }
  83. Some(msg) = reader.next() => {
  84. if let Ok(Message::Text(msg)) = msg {
  85. let msg: Result<Request, _> = serde_json::from_str(&msg);
  86. match msg {
  87. Ok(msg) => {
  88. let _ = broadcast_request.send((conn_id, msg)).await;
  89. },
  90. Err(err) => {
  91. log::error!("Error parsing message from client: {}", err);
  92. let reply: Response = ROk {
  93. id: Addr::default(),
  94. status: false,
  95. message: "Error parsing message".to_owned(),
  96. }.into();
  97. let reply = if let Ok(reply) = serde_json::to_string(&reply) {
  98. reply
  99. } else {
  100. continue;
  101. };
  102. if let Err(err) = writer.send(Message::Text(reply)).await {
  103. log::error!("Error sending message to client: {}", err);
  104. break;
  105. }
  106. }
  107. };
  108. }
  109. }
  110. else => {
  111. break;
  112. }
  113. }
  114. }
  115. if let Some(disconnection_notify) = disconnection_notify {
  116. let _ = disconnection_notify.try_send(conn_id);
  117. }
  118. });
  119. }
  120. #[inline]
  121. pub fn send(&self, response: Response) -> Result<(), Error> {
  122. self.sender
  123. .try_send(response)
  124. .map_err(|e| Error::TrySendError(Box::new(e)))
  125. }
  126. #[inline]
  127. pub fn get_sender(&self) -> Sender<Response> {
  128. self.sender.clone()
  129. }
  130. pub fn get_subscription_id(&self, id: &str) -> Option<u128> {
  131. let subscriptions = self.subscriptions.read();
  132. subscriptions.get(id).copied()
  133. }
  134. pub fn create_subscription(&self, id: String) -> (u128, Sender<Response>) {
  135. let mut subscriptions = self.subscriptions.write();
  136. let internal_id = subscriptions.entry(id).or_insert_with(get_id);
  137. (*internal_id, self.sender.clone())
  138. }
  139. }