浏览代码

`cashu` improve: add keyset id type

thesimplekid 1 年之前
父节点
当前提交
4e3268a7c7
共有 3 个文件被更改,包括 182 次插入48 次删除
  1. 1 0
      crates/cashu/Cargo.toml
  2. 12 0
      crates/cashu/src/nuts/nut01.rs
  3. 169 48
      crates/cashu/src/nuts/nut02.rs

+ 1 - 0
crates/cashu/Cargo.toml

@@ -28,6 +28,7 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 url = { workspace = true }
 regex = "1.8.4"
+itertools = "0.11.0"
 
 [dev-dependencies]
 # tokio = {version = "1.27.0", features = ["rt", "macros"] }

+ 12 - 0
crates/cashu/src/nuts/nut01.rs

@@ -43,6 +43,12 @@ impl PublicKey {
     }
 }
 
+impl std::fmt::Display for PublicKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.to_hex())
+    }
+}
+
 #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
 #[serde(transparent)]
 pub struct SecretKey(#[serde(with = "crate::serde_utils::serde_secret_key")] k256::SecretKey);
@@ -72,6 +78,7 @@ impl SecretKey {
 }
 
 /// Mint Keys [NUT-01]
+// TODO: CHange this to Amount type
 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
 pub struct Keys(BTreeMap<u64, PublicKey>);
 
@@ -95,6 +102,11 @@ impl Keys {
             .map(|(k, v)| (k.to_owned(), hex::encode(v.0.to_sec1_bytes())))
             .collect()
     }
+
+    /// Iterate through the (`Amount`, `PublicKey`) entries in the Map
+    pub fn iter(&self) -> impl Iterator<Item = (&u64, &PublicKey)> {
+        self.0.iter()
+    }
 }
 
 impl From<mint::Keys> for Keys {

+ 169 - 48
crates/cashu/src/nuts/nut02.rs

@@ -6,20 +6,145 @@ use std::collections::HashSet;
 use base64::{engine::general_purpose, Engine as _};
 use bitcoin::hashes::sha256::Hash as Sha256;
 use bitcoin::hashes::Hash;
+use itertools::Itertools;
 use serde::{Deserialize, Serialize};
 
 use super::nut01::Keys;
 
+#[derive(Debug, PartialEq, Eq)]
+pub enum Error {
+    Base64(base64::DecodeError),
+    Length,
+}
+
+/// A keyset ID is an identifier for a specific keyset. It can be derived by
+/// anyone who knows the set of public keys of a mint. The keyset ID **CAN**
+/// be stored in a Cashu token such that the token can be used to identify
+/// which mint or keyset it was generated from.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Id([u8; Id::BYTES]);
+
+impl Id {
+    const BYTES: usize = 9;
+    const STRLEN: usize = 12;
+
+    pub fn try_from_base64(b64: &str) -> Result<Self, Error> {
+        use base64::{
+            engine::general_purpose::{STANDARD, URL_SAFE},
+            Engine as _,
+        };
+
+        if b64.len() != Self::STRLEN {
+            return Err(Error::Length);
+        }
+
+        if let Ok(bytes) = URL_SAFE.decode(b64) {
+            if bytes.len() == Self::BYTES {
+                return Ok(Self(
+                    <[u8; Self::BYTES]>::try_from(bytes.as_slice()).unwrap(),
+                ));
+            }
+        }
+
+        match STANDARD.decode(b64) {
+            Ok(bytes) if bytes.len() == Self::BYTES => Ok(Self(
+                <[u8; Self::BYTES]>::try_from(bytes.as_slice()).unwrap(),
+            )),
+            Ok(_) => Err(Error::Length),
+            Err(e) => Err(Error::Base64(e)),
+        }
+    }
+}
+
+impl std::fmt::Display for Id {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut output = String::with_capacity(Self::STRLEN);
+        general_purpose::STANDARD.encode_string(&self.0.as_slice(), &mut output);
+        f.write_str(&output)
+    }
+}
+
+impl std::convert::TryFrom<String> for Id {
+    type Error = Error;
+    fn try_from(value: String) -> Result<Self, Self::Error> {
+        Id::try_from_base64(&value)
+    }
+}
+
+impl serde::ser::Serialize for Id {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        serializer.serialize_str(&self.to_string())
+    }
+}
+
+impl<'de> serde::de::Deserialize<'de> for Id {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        struct IdVisitor;
+
+        impl<'de> serde::de::Visitor<'de> for IdVisitor {
+            type Value = Id;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
+                formatter.write_str("a 12-character Base64 string")
+            }
+
+            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
+            where
+                E: serde::de::Error,
+            {
+                Id::try_from_base64(v).map_err(|e| match e {
+                    Error::Length => E::custom(format!(
+                        "Invalid Length: Expected {}, got {}",
+                        Id::STRLEN,
+                        v.len()
+                    )),
+                    Error::Base64(e) => E::custom(e),
+                })
+            }
+        }
+
+        deserializer.deserialize_str(IdVisitor)
+    }
+}
+
+impl From<&Keys> for Id {
+    fn from(map: &Keys) -> Self {
+        /* NUT-02 § 2.2.2
+            1 - sort keyset by amount
+            2 - concatenate all (sorted) public keys to one string
+            3 - HASH_SHA256 the concatenated public keys
+            4 - take the first 12 characters of the base64-encoded hash
+        */
+
+        let pubkeys_concat = map
+            .iter()
+            .sorted_by(|(amt_a, _), (amt_b, _)| amt_a.cmp(amt_b))
+            .map(|(_, pubkey)| pubkey)
+            .join("");
+
+        let hash = Sha256::hash(pubkeys_concat.as_bytes());
+        let bytes = hash.to_byte_array();
+        // First 9 bytes of hash will encode as the first 12 Base64 characters later
+        Self(<[u8; Self::BYTES]>::try_from(&bytes[0..Self::BYTES]).unwrap())
+    }
+}
+
 /// Mint Keysets [NUT-02]
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
 pub struct Response {
     /// set of public keys that the mint generates
-    pub keysets: HashSet<String>,
+    pub keysets: HashSet<Id>,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
 pub struct KeySet {
-    pub id: String,
+    pub id: Id,
     pub keys: Keys,
 }
 
@@ -32,43 +157,23 @@ impl From<mint::KeySet> for KeySet {
     }
 }
 
-impl Keys {
-    pub fn id(&self) -> String {
-        /* 1 - sort keyset by amount
-         * 2 - concatenate all (sorted) public keys to one string
-         * 3 - HASH_SHA256 the concatenated public keys
-         * 4 - take the first 12 characters of the hash
-         */
-
-        let pubkeys_concat = self
-            .keys()
-            .values()
-            .map(|pubkey| hex::encode(k256::PublicKey::from(pubkey).to_sec1_bytes()))
-            .collect::<Vec<String>>()
-            .join("");
-
-        let hash = general_purpose::STANDARD.encode(Sha256::hash(pubkeys_concat.as_bytes()));
-
-        hash[0..12].to_string()
-    }
-}
-
 pub mod mint {
     use std::collections::BTreeMap;
 
-    use base64::{engine::general_purpose, Engine as _};
     use bitcoin::hashes::sha256::Hash as Sha256;
     use bitcoin_hashes::Hash;
     use bitcoin_hashes::HashEngine;
+    use itertools::Itertools;
     use k256::SecretKey;
-    use serde::Deserialize;
     use serde::Serialize;
 
+    use super::Id;
+
     use crate::nuts::nut01::mint::{KeyPair, Keys};
 
     #[derive(Debug, Clone, PartialEq, Eq, Serialize)]
     pub struct KeySet {
-        pub id: String,
+        pub id: Id,
         pub keys: Keys,
     }
 
@@ -104,30 +209,44 @@ pub mod mint {
                 map.insert(amount, keypair);
             }
 
+            let keys = Keys(map);
+
             Self {
-                id: Self::id(&map),
-                keys: Keys(map),
+                id: (&keys).into(),
+                keys,
             }
         }
+    }
 
-        fn id(map: &BTreeMap<u64, KeyPair>) -> String {
-            /* 1 - sort keyset by amount
-             * 2 - concatenate all (sorted) public keys to one string
-             * 3 - HASH_SHA256 the concatenated public keys
-             * 4 - take the first 12 characters of the hash
-             */
-
-            let pubkeys_concat = map
-                .values()
-                .map(|keypair| {
-                    hex::encode(k256::PublicKey::from(&keypair.public_key).to_sec1_bytes())
-                })
-                .collect::<Vec<String>>()
-                .join("");
+    impl From<KeySet> for Id {
+        fn from(keyset: KeySet) -> Id {
+            let keys: super::KeySet = keyset.into();
 
-            let hash = general_purpose::STANDARD.encode(Sha256::hash(pubkeys_concat.as_bytes()));
+            Id::from(&keys.keys)
+        }
+    }
 
-            hash[0..12].to_string()
+    impl From<&Keys> for Id {
+        fn from(map: &Keys) -> Self {
+            /* NUT-02 § 2.2.2
+                1 - sort keyset by amount
+                2 - concatenate all (sorted) public keys to one string
+                3 - HASH_SHA256 the concatenated public keys
+                4 - take the first 12 characters of the base64-encoded hash
+            */
+
+            let keys: super::Keys = map.clone().into();
+
+            let pubkeys_concat = keys
+                .iter()
+                .sorted_by(|(amt_a, _), (amt_b, _)| amt_a.cmp(amt_b))
+                .map(|(_, pubkey)| pubkey)
+                .join("");
+
+            let hash = Sha256::hash(pubkeys_concat.as_bytes());
+            let bytes = hash.to_byte_array();
+            // First 9 bytes of hash will encode as the first 12 Base64 characters later
+            Self(<[u8; Self::BYTES]>::try_from(&bytes[0..Self::BYTES]).unwrap())
         }
     }
 }
@@ -135,6 +254,8 @@ pub mod mint {
 #[cfg(test)]
 mod test {
 
+    use crate::nuts::nut02::Id;
+
     use super::Keys;
 
     const SHORT_KEYSET_ID: &str = "esom3oyNLLit";
@@ -221,14 +342,14 @@ mod test {
     fn deserialization_and_id_generation() {
         let keys: Keys = serde_json::from_str(SHORT_KEYSET).unwrap();
 
-        let id = keys.id();
+        let id: Id = (&keys).into();
 
-        assert_eq!(id, SHORT_KEYSET_ID);
+        assert_eq!(id, Id::try_from_base64(SHORT_KEYSET_ID).unwrap());
 
         let keys: Keys = serde_json::from_str(KEYSET).unwrap();
 
-        let id = keys.id();
+        let id: Id = (&keys).into();
 
-        assert_eq!(id, KEYSET_ID);
+        assert_eq!(id, Id::try_from_base64(KEYSET_ID).unwrap());
     }
 }