Explorar el Código

Add better parsing.

Added a global Message that will parse all messages, either server
replies or client requests.
Cesar Rodas hace 2 años
padre
commit
ce649b542b

+ 4 - 84
crates/types/src/client/request.rs

@@ -1,9 +1,5 @@
 use crate::types;
-use serde::{
-    de::{self, Deserializer},
-    ser::{self, SerializeSeq, Serializer},
-    Deserialize, Serialize,
-};
+use serde::ser::{self, SerializeSeq, Serializer};
 
 #[derive(Debug, Clone, Default)]
 pub struct Request {
@@ -11,84 +7,6 @@ pub struct Request {
     pub filters: Vec<types::Filter>,
 }
 
-impl<'de> de::Deserialize<'de> for Request {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        #[derive(Deserialize, Serialize, Debug)]
-        #[serde(untagged)]
-        enum StringOrFilter {
-            Filter(types::Filter),
-            String(String),
-        }
-
-        impl<'de> StringOrFilter {
-            pub fn to_filter<D>(&self, element: usize) -> Result<types::Filter, D::Error>
-            where
-                D: Deserializer<'de>,
-            {
-                match self {
-                    Self::Filter(f) => Ok(f.clone()),
-                    _ => Err(de::Error::custom(format!(
-                        "Expecting a filter in element {}, got a filter object",
-                        element,
-                    ))),
-                }
-            }
-
-            pub fn to_string<D>(&self, element: usize) -> Result<String, D::Error>
-            where
-                D: Deserializer<'de>,
-            {
-                match self {
-                    Self::String(string) => Ok(string.to_owned()),
-                    _ => Err(de::Error::custom(format!(
-                        "Expecting an string in element {}, got a filter object",
-                        element,
-                    ))),
-                }
-            }
-        }
-
-        let s: Vec<StringOrFilter> = Deserialize::deserialize(deserializer)?;
-        if s.len() < 2 {
-            return Err(de::Error::custom(
-                "Array too small, it must have at least 2 elements",
-            ));
-        }
-        let header = s[0].to_string::<D>(0)?;
-        let subscription_id = s[1]
-            .to_string::<D>(1)?
-            .as_str()
-            .try_into()
-            .map_err(|e: types::subscription_id::Error| de::Error::custom(e.to_string()))?;
-
-        if header != "REQ" {
-            return Err(de::Error::custom(format!(
-                "Invalid header, got {} and expected REQ",
-                header
-            )));
-        }
-
-        let mut index = 1;
-        Ok(Request {
-            subscription_id,
-            filters: if s.len() > 2 {
-                s[2..]
-                    .iter()
-                    .map(|filter| {
-                        index += 1;
-                        filter.to_filter::<D>(index)
-                    })
-                    .collect::<Result<Vec<types::Filter>, D::Error>>()?
-            } else {
-                vec![]
-            },
-        })
-    }
-}
-
 impl ser::Serialize for Request {
     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
     where
@@ -107,6 +25,7 @@ impl ser::Serialize for Request {
 #[cfg(test)]
 mod test {
     use super::*;
+    use crate::Message;
     use chrono::Utc;
 
     #[test]
@@ -132,7 +51,8 @@ mod test {
         };
 
         let serialized = serde_json::to_string(&r).expect("valid json string");
-        let obj: Request = serde_json::from_str(&serialized).expect("valid request");
+        let obj: Message = serde_json::from_str(&serialized).expect("valid request");
+        let obj = obj.get_request().expect("request");
         assert_eq!(r.subscription_id, obj.subscription_id);
         assert_eq!(r.filters.len(), obj.filters.len());
     }

+ 129 - 0
crates/types/src/lib.rs

@@ -1,2 +1,131 @@
+use serde::{
+    de::{self, Deserializer},
+    Deserialize, Serialize,
+};
+use thiserror::Error;
+
 pub mod client;
 pub mod types;
+
+#[derive(Debug, Error)]
+pub enum Error {}
+
+#[derive(Serialize, Debug, Clone)]
+pub enum Message {
+    Request(client::Request),
+    Close(types::SubscriptionId),
+    Notice(String),
+}
+
+impl Message {
+    pub fn get_request(&self) -> Option<&client::Request> {
+        match self {
+            Self::Request(x) => Some(x),
+            _ => None,
+        }
+    }
+
+    pub fn get_notice(&self) -> Option<&str> {
+        match self {
+            Self::Notice(x) => Some(x),
+            _ => None,
+        }
+    }
+
+    pub fn get_close_subscription_id(&self) -> Option<&types::SubscriptionId> {
+        match self {
+            Self::Close(x) => Some(x),
+            _ => None,
+        }
+    }
+}
+
+impl<'de> de::Deserialize<'de> for Message {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let array: Vec<serde_json::Value> = Deserialize::deserialize(deserializer)?;
+        if array.is_empty() {
+            return Err(de::Error::custom(
+                "Invalid array length, expecting at least one",
+            ));
+        }
+
+        let tag = array
+            .get(0)
+            .ok_or_else(|| de::Error::custom("Invalid array length, expecting at least one"))?
+            .as_str()
+            .ok_or_else(|| de::Error::custom("Invalid type for element 0 of the array"))?;
+
+        match tag {
+            "NOTICE" => Ok(Self::Notice(
+                serde_json::from_value(
+                    array
+                        .get(1)
+                        .ok_or_else(|| de::Error::custom("Missing element 1 of the array"))?
+                        .clone(),
+                )
+                .map_err(|e: serde_json::Error| de::Error::custom(e.to_string()))?,
+            )),
+            "CLOSE" => Ok(Self::Close(
+                serde_json::from_value(
+                    array
+                        .get(1)
+                        .ok_or_else(|| de::Error::custom("Missing element 1 of the array"))?
+                        .clone(),
+                )
+                .map_err(|e: serde_json::Error| de::Error::custom(e.to_string()))?,
+            )),
+            "REQ" => {
+                let subscription_id = array
+                    .get(1)
+                    .ok_or_else(|| de::Error::custom("Missing element 1 in the array"))?
+                    .as_str()
+                    .ok_or_else(|| {
+                        de::Error::custom("Invalid type for element 1, expecting a string")
+                    })?
+                    .try_into()
+                    .map_err(|e: types::subscription_id::Error| {
+                        de::Error::custom(format!("Invalid subscription id: {}", e))
+                    })?;
+
+                Ok(Self::Request(client::Request {
+                    subscription_id,
+                    filters: if array.len() > 2 {
+                        serde_json::from_value::<Vec<types::Filter>>(serde_json::Value::Array(
+                            array[2..].to_owned(),
+                        ))
+                        .map_err(|e: serde_json::Error| de::Error::custom(e.to_string()))?
+                    } else {
+                        vec![]
+                    },
+                }))
+            }
+            tag => Err(de::Error::custom(format!("{} is not a support tag", tag))),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn unsupported() {
+        let json = "[\"CLOSEX\", \"foo\"]";
+        let message: Result<Message, _> = serde_json::from_str(json);
+        assert!(message.is_err());
+    }
+
+    #[test]
+    fn close() {
+        let json = "[\"CLOSE\", \"foo\"]";
+        let message: Message = serde_json::from_str(json).expect("valid message");
+        let subscription_id = message
+            .get_close_subscription_id()
+            .expect("valid subscription_id");
+
+        assert_eq!("foo".to_owned(), subscription_id.to_string());
+    }
+}

+ 13 - 5
crates/types/src/types/addr.rs

@@ -78,17 +78,17 @@ impl Addr {
     }
 }
 
-impl TryFrom<&str> for Addr {
+impl TryFrom<String> for Addr {
     type Error = Error;
 
-    fn try_from(value: &str) -> Result<Self, Self::Error> {
-        if let Ok(bytes) = hex::decode(value) {
+    fn try_from(value: String) -> Result<Self, Self::Error> {
+        if let Ok(bytes) = hex::decode(&value) {
             return Ok(Self {
                 bytes,
                 typ: Type::Unknown,
             });
         }
-        let (hrp, bytes, _) = bech32::decode(value)?;
+        let (hrp, bytes, _) = bech32::decode(&value)?;
         let typ = match hrp.to_lowercase().as_str() {
             "npub" => Ok(Type::NPub),
             "nsec" => Ok(Type::NSec),
@@ -103,12 +103,20 @@ impl TryFrom<&str> for Addr {
     }
 }
 
+impl TryFrom<&str> for Addr {
+    type Error = Error;
+
+    fn try_from(value: &str) -> Result<Self, Self::Error> {
+        value.to_owned().try_into()
+    }
+}
+
 impl<'de> Deserialize<'de> for Addr {
     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
     where
         D: Deserializer<'de>,
     {
-        let s = <&str>::deserialize(deserializer)?;
+        let s = <String>::deserialize(deserializer)?;
         let t = s
             .try_into()
             .map_err(|e: Error| de::Error::custom(e.to_string()))?;

+ 11 - 4
crates/types/src/types/subscription_id.rs

@@ -21,7 +21,7 @@ impl<'de> Deserialize<'de> for SubscriptionId {
     where
         D: Deserializer<'de>,
     {
-        let s = <&str>::deserialize(deserializer)?;
+        let s = <String>::deserialize(deserializer)?;
         let t = s
             .try_into()
             .map_err(|e: Error| de::Error::custom(e.to_string()))?;
@@ -29,13 +29,20 @@ impl<'de> Deserialize<'de> for SubscriptionId {
     }
 }
 
-impl TryFrom<&str> for SubscriptionId {
+impl TryFrom<String> for SubscriptionId {
     type Error = Error;
-    fn try_from(s: &str) -> Result<Self, Self::Error> {
+    fn try_from(s: String) -> Result<Self, Self::Error> {
         if s.as_bytes().len() > 64 {
             return Err(Error::TooLong);
         }
-        Ok(SubscriptionId(s.into()))
+        Ok(SubscriptionId(s))
+    }
+}
+
+impl TryFrom<&str> for SubscriptionId {
+    type Error = Error;
+    fn try_from(s: &str) -> Result<Self, Self::Error> {
+        s.to_owned().try_into()
     }
 }