Răsfoiți Sursa

feat: Update Id::try_from to return a u32 and remove redundant logic in nut13::derive_path_from_keyset_id (#452)

* fix: return u32 from existing Id::TryFrom and add lossless u64 versions

* remove TryFrom<u64> for Id and it's inverse

* fix: remove unit test and fix nut13::derive_path_from_keyset_id

* test: derive_path_from_keyset_id

* fix: convert Id::TryFrom to Id::From

* docs: comment calling out From<Id> for u32 as a one-way function
vnprc 4 luni în urmă
părinte
comite
f9d500e9a8
2 a modificat fișierele cu 44 adăugiri și 7 ștergeri
  1. 22 6
      crates/cdk/src/nuts/nut02.rs
  2. 22 1
      crates/cdk/src/nuts/nut13.rs

+ 22 - 6
crates/cdk/src/nuts/nut02.rs

@@ -112,16 +112,25 @@ impl Id {
             id: bytes[1..].try_into()?,
         })
     }
+
+    /// [`Id`] as bytes
+    pub fn as_bytes(&self) -> [u8; Self::BYTELEN + 1] {
+        let mut bytes = [0u8; Self::BYTELEN + 1];
+        bytes[0] = self.version.to_byte();
+        bytes[1..].copy_from_slice(&self.id);
+        bytes
+    }
 }
 
-impl TryFrom<Id> for u64 {
-    type Error = Error;
-    fn try_from(value: Id) -> Result<Self, Self::Error> {
-        let hex_bytes: [u8; 8] = value.to_bytes().try_into().map_err(|_| Error::Length)?;
+// Used to generate a compressed unique identifier as part of the NUT13 spec
+// This is a one-way function
+impl From<Id> for u32 {
+    fn from(value: Id) -> Self {
+        let hex_bytes: [u8; 8] = value.as_bytes();
 
         let int = u64::from_be_bytes(hex_bytes);
 
-        Ok(int % (2_u64.pow(31) - 1))
+        (int % (2_u64.pow(31) - 1)) as u32
     }
 }
 
@@ -490,11 +499,18 @@ mod test {
     fn test_to_int() {
         let id = Id::from_str("009a1f293253e41e").unwrap();
 
-        let id_int = u64::try_from(id).unwrap();
+        let id_int = u32::from(id);
         assert_eq!(864559728, id_int)
     }
 
     #[test]
+    fn test_id_from_invalid_byte_length() {
+        let three_bytes = [0x01, 0x02, 0x03];
+        let result = Id::from_bytes(&three_bytes);
+        assert!(result.is_err(), "Expected an invalid byte length error");
+    }
+
+    #[test]
     fn test_keyset_bytes() {
         let id = Id::from_str("009a1f293253e41e").unwrap();
 

+ 22 - 1
crates/cdk/src/nuts/nut13.rs

@@ -170,7 +170,8 @@ impl PreMintSecrets {
 }
 
 fn derive_path_from_keyset_id(id: Id) -> Result<DerivationPath, Error> {
-    let index = (u64::try_from(id)? % (2u64.pow(31) - 1)) as u32;
+    let index = u32::from(id);
+
     let keyset_child_number = ChildNumber::from_hardened_idx(index)?;
     Ok(DerivationPath::from(vec![
         ChildNumber::from_hardened_idx(129372)?,
@@ -184,6 +185,7 @@ mod tests {
     use std::str::FromStr;
 
     use bip39::Mnemonic;
+    use bitcoin::bip32::DerivationPath;
     use bitcoin::Network;
 
     use super::*;
@@ -232,4 +234,23 @@ mod tests {
             assert_eq!(r, SecretKey::from_hex(test_r).unwrap())
         }
     }
+
+    #[test]
+    fn test_derive_path_from_keyset_id() {
+        let test_cases = [
+            ("009a1f293253e41e", "m/129372'/0'/864559728'"),
+            ("0000000000000000", "m/129372'/0'/0'"),
+            ("00ffffffffffffff", "m/129372'/0'/33554431'"),
+        ];
+
+        for (id_hex, expected_path) in test_cases {
+            let id = Id::from_str(id_hex).unwrap();
+            let path = derive_path_from_keyset_id(id).unwrap();
+            assert_eq!(
+                DerivationPath::from_str(expected_path).unwrap(),
+                path,
+                "Path derivation failed for ID {id_hex}"
+            );
+        }
+    }
 }