mod.rs 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. use std::collections::HashMap;
  2. use axum::extract::ws::{Message, WebSocket};
  3. use cdk::nuts::nut17::NotificationPayload;
  4. use cdk::pub_sub::SubId;
  5. use cdk::ws::{
  6. notification_to_ws_message, NotificationInner, WsErrorBody, WsMessageOrResponse,
  7. WsMethodRequest, WsRequest,
  8. };
  9. use futures::StreamExt;
  10. use tokio::sync::mpsc;
  11. use uuid::Uuid;
  12. use crate::MintState;
  13. mod error;
  14. mod subscribe;
  15. mod unsubscribe;
  16. async fn process(
  17. context: &mut WsContext,
  18. body: WsRequest,
  19. ) -> Result<serde_json::Value, serde_json::Error> {
  20. let response = match body.method {
  21. WsMethodRequest::Subscribe(sub) => subscribe::handle(context, sub).await,
  22. WsMethodRequest::Unsubscribe(unsub) => unsubscribe::handle(context, unsub).await,
  23. }
  24. .map_err(WsErrorBody::from);
  25. let response: WsMessageOrResponse = (body.id, response).into();
  26. serde_json::to_value(response)
  27. }
  28. pub use error::WsError;
  29. pub struct WsContext {
  30. state: MintState,
  31. subscriptions: HashMap<SubId, tokio::task::JoinHandle<()>>,
  32. publisher: mpsc::Sender<(SubId, NotificationPayload<Uuid>)>,
  33. }
  34. /// Main function for websocket connections
  35. ///
  36. /// This function will handle all incoming websocket connections and keep them in their own loop.
  37. ///
  38. /// For simplicity sake this function will spawn tasks for each subscription and
  39. /// keep them in a hashmap, and will have a single subscriber for all of them.
  40. pub async fn main_websocket(mut socket: WebSocket, state: MintState) {
  41. let (publisher, mut subscriber) = mpsc::channel(100);
  42. let mut context = WsContext {
  43. state,
  44. subscriptions: HashMap::new(),
  45. publisher,
  46. };
  47. loop {
  48. tokio::select! {
  49. Some((sub_id, payload)) = subscriber.recv() => {
  50. if !context.subscriptions.contains_key(&sub_id) {
  51. // It may be possible an incoming message has come from a dropped Subscriptions that has not yet been
  52. // unsubscribed from the subscription manager, just ignore it.
  53. continue;
  54. }
  55. let notification = notification_to_ws_message(NotificationInner {
  56. sub_id,
  57. payload,
  58. });
  59. let message = match serde_json::to_string(&notification) {
  60. Ok(message) => message,
  61. Err(err) => {
  62. tracing::error!("Could not serialize notification: {}", err);
  63. continue;
  64. }
  65. };
  66. if let Err(err)= socket.send(Message::Text(message.into())).await {
  67. tracing::error!("Could not send websocket message: {}", err);
  68. break;
  69. }
  70. }
  71. Some(Ok(Message::Text(text))) = socket.next() => {
  72. let request = match serde_json::from_str::<WsRequest>(&text) {
  73. Ok(request) => request,
  74. Err(err) => {
  75. tracing::error!("Could not parse request: {}", err);
  76. continue;
  77. }
  78. };
  79. match process(&mut context, request).await {
  80. Ok(result) => {
  81. if let Err(err) = socket
  82. .send(Message::Text(result.to_string().into()))
  83. .await
  84. {
  85. tracing::error!("Could not send request: {}", err);
  86. break;
  87. }
  88. }
  89. Err(err) => {
  90. tracing::error!("Error serializing response: {}", err);
  91. break;
  92. }
  93. }
  94. }
  95. else => {
  96. }
  97. }
  98. }
  99. }