Parcourir la source

Keyset ID: fix deserialization edge-case, add unit tests

ok300 il y a 4 mois
Parent
commit
09b5a55239
2 fichiers modifiés avec 60 ajouts et 60 suppressions
  1. 1 1
      crates/cdk/Cargo.toml
  2. 59 59
      crates/cdk/src/nuts/nut02.rs

+ 1 - 1
crates/cdk/Cargo.toml

@@ -35,7 +35,7 @@ reqwest = { version = "0.12", default-features = false, features = [
 ], optional = true }
 serde = { version = "1", default-features = false, features = ["derive"] }
 serde_json = "1"
-serde_with = "3.1"
+serde_with = "3"
 tracing = { version = "0.1", default-features = false, features = ["attributes", "log"] }
 thiserror = "1"
 futures = { version = "0.3.28", default-features = false, optional = true }

+ 59 - 59
crates/cdk/src/nuts/nut02.rs

@@ -18,7 +18,7 @@ use bitcoin::hashes::Hash;
 use bitcoin::key::Secp256k1;
 #[cfg(feature = "mint")]
 use bitcoin::secp256k1;
-use serde::{Deserialize, Deserializer, Serialize};
+use serde::{Deserialize, Serialize};
 use serde_with::{serde_as, VecSkipError};
 use thiserror::Error;
 
@@ -86,10 +86,11 @@ impl fmt::Display for KeySetVersion {
 
 /// A keyset ID is an identifier for a specific keyset. It can be derived by
 /// anyone who knows the set of public keys of a mint. The keyset ID **CAN**
-/// be stored in a Cashu token such that the token can be used to identify
+/// be stored in a Cashu token such that the token can be used to identify
 /// which mint or keyset it was generated from.
-#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
-#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
+#[serde(into = "String", try_from = "String")]
+#[cfg_attr(feature = "swagger", derive(utoipa::ToSchema), schema(as = String))]
 pub struct Id {
     version: KeySetVersion,
     id: [u8; Self::BYTELEN],
@@ -130,17 +131,16 @@ impl fmt::Display for Id {
     }
 }
 
-impl FromStr for Id {
-    type Err = Error;
+impl TryFrom<String> for Id {
+    type Error = Error;
 
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        // Check if the string length is valid
+    fn try_from(s: String) -> Result<Self, Self::Error> {
         if s.len() != 16 {
             return Err(Error::Length);
         }
 
         Ok(Self {
-            version: KeySetVersion::Version00,
+            version: KeySetVersion::from_byte(&hex::decode(&s[..2])?[0])?,
             id: hex::decode(&s[2..])?
                 .try_into()
                 .map_err(|_| Error::Length)?,
@@ -148,63 +148,29 @@ impl FromStr for Id {
     }
 }
 
-impl Serialize for Id {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        serializer.serialize_str(&self.to_string())
+impl FromStr for Id {
+    type Err = Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        Self::try_from(s.to_string())
     }
 }
 
-impl<'de> Deserialize<'de> for Id {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        struct IdVisitor;
-
-        impl<'de> serde::de::Visitor<'de> for IdVisitor {
-            type Value = Id;
-
-            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
-                formatter.write_str("Expecting a 14 char hex string")
-            }
-
-            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
-            where
-                E: serde::de::Error,
-            {
-                Id::from_str(v).map_err(|e| match e {
-                    Error::Length => E::custom(format!(
-                        "Invalid Length: Expected {}, got {}:
-                        {}",
-                        Id::STRLEN,
-                        v.len(),
-                        v
-                    )),
-                    _ => E::custom(e),
-                })
-            }
-        }
-
-        deserializer.deserialize_str(IdVisitor)
+impl From<Id> for String {
+    fn from(value: Id) -> Self {
+        value.to_string()
     }
 }
 
 impl From<&Keys> for Id {
+    /// As per NUT-02:
+    ///   1. sort public keys by their amount in ascending order
+    ///   2. concatenate all public keys to one string
+    ///   3. HASH_SHA256 the concatenated public keys
+    ///   4. take the first 14 characters of the hex-encoded hash
+    ///   5. prefix it with a keyset ID version byte
     fn from(map: &Keys) -> Self {
-        // REVIEW: Is it 16 or 14 bytes
-        /* NUT-02
-            1 - sort public keys by their amount in ascending order
-            2 - concatenate all public keys to one string
-            3 - HASH_SHA256 the concatenated public keys
-            4 - take the first 14 characters of the hex-encoded hash
-            5 - prefix it with a keyset ID version byte
-        */
-
         let mut keys: Vec<(&AmountStr, &super::PublicKey)> = map.iter().collect();
-
         keys.sort_by_key(|(amt, _v)| *amt);
 
         let pubkeys_concat: Vec<u8> = keys
@@ -400,12 +366,14 @@ impl From<&MintKeys> for Id {
 
 #[cfg(test)]
 mod test {
-
     use std::str::FromStr;
 
+    use rand::RngCore;
+
     use super::{KeySetInfo, Keys, KeysetResponse};
-    use crate::nuts::nut02::Id;
+    use crate::nuts::nut02::{Error, Id};
     use crate::nuts::KeysResponse;
+    use crate::util::hex;
 
     const SHORT_KEYSET_ID: &str = "00456a94ab4e1c46";
     const SHORT_KEYSET: &str = r#"
@@ -547,4 +515,36 @@ mod test {
 
         assert_eq!(keys_response.keysets.len(), 2);
     }
+
+    fn generate_random_id() -> Id {
+        let mut rand_bytes = vec![0u8; 8];
+        rand::thread_rng().fill_bytes(&mut rand_bytes[1..]);
+        Id::from_bytes(&rand_bytes)
+            .unwrap_or_else(|e| panic!("Failed to create Id from {}: {e}", hex::encode(rand_bytes)))
+    }
+
+    #[test]
+    fn test_id_serialization() {
+        let id = generate_random_id();
+        let id_str = id.to_string();
+
+        assert!(id_str.chars().all(|c| c.is_ascii_hexdigit()));
+        assert_eq!(16, id_str.len());
+        assert_eq!(id_str.to_lowercase(), id_str);
+    }
+
+    #[test]
+    fn test_id_deserialization() {
+        let id_from_short_str = Id::from_str("00123");
+        assert!(matches!(id_from_short_str, Err(Error::Length)));
+
+        let id_from_non_hex_str = Id::from_str(&SHORT_KEYSET_ID.replace('a', "x"));
+        assert!(matches!(id_from_non_hex_str, Err(Error::HexError(_))));
+
+        let id_invalid_version = Id::from_str(&SHORT_KEYSET_ID.replace("00", "99"));
+        assert!(matches!(id_invalid_version, Err(Error::UnknownVersion)));
+
+        let id_from_uppercase = Id::from_str(&SHORT_KEYSET_ID.to_uppercase());
+        assert!(id_from_uppercase.is_ok());
+    }
 }