Explorar el Código

Restore order (#1517)

* feat(wallet): handle out of order response

* feat: add mint check that response is in order
tsk hace 2 semanas
padre
commit
093f5764b1
Se han modificado 2 ficheros con 354 adiciones y 9 borrados
  1. 31 0
      crates/cdk/src/mint/mod.rs
  2. 323 9
      crates/cdk/src/wallet/mod.rs

+ 31 - 0
crates/cdk/src/mint/mod.rs

@@ -929,6 +929,14 @@ impl Mint {
             let mut outputs = Vec::with_capacity(output_len);
             let mut signatures = Vec::with_capacity(output_len);
 
+            // Build a position map to track original request order for verification
+            let position_map: HashMap<PublicKey, usize> = request
+                .outputs
+                .iter()
+                .enumerate()
+                .map(|(idx, output)| (output.blinded_secret, idx))
+                .collect();
+
             let blinded_message: Vec<PublicKey> =
                 request.outputs.iter().map(|b| b.blinded_secret).collect();
 
@@ -950,6 +958,29 @@ impl Mint {
                 }
             }
 
+            // Verify response outputs maintain the same relative order as the request
+            // This ensures the NUT-09 spec requirement that outputs[i] corresponds to signatures[i]
+            let mut last_position: Option<usize> = None;
+            for output in &outputs {
+                let current_position =
+                    position_map.get(&output.blinded_secret).ok_or_else(|| {
+                        tracing::error!("Restore response contains output not in original request");
+                        Error::Internal
+                    })?;
+
+                if let Some(last_pos) = last_position {
+                    if *current_position <= last_pos {
+                        tracing::error!(
+                            "Restore response outputs are out of order: position {} after {}",
+                            current_position,
+                            last_pos
+                        );
+                        return Err(Error::Internal);
+                    }
+                }
+                last_position = Some(*current_position);
+            }
+
             Ok(RestoreResponse {
                 outputs,
                 signatures: signatures.clone(),

+ 323 - 9
crates/cdk/src/wallet/mod.rs

@@ -493,37 +493,60 @@ impl Wallet {
                     continue;
                 }
 
+                // Build a map from blinded_secret to signature for O(1) lookup
+                // This ensures we match signatures to secrets correctly regardless of response order
+                let signature_map: HashMap<_, _> = response
+                    .outputs
+                    .iter()
+                    .zip(response.signatures.iter())
+                    .map(|(output, sig)| (output.blinded_secret, sig.clone()))
+                    .collect();
+
                 // Enumerate secrets to track their original index (which corresponds to counter value)
+                // and match signatures by blinded_secret to ensure correct pairing
                 let matched_secrets: Vec<_> = premint_secrets
                     .secrets
                     .iter()
                     .enumerate()
-                    .filter(|(_, p)| response.outputs.contains(&p.blinded_message))
+                    .filter_map(|(idx, p)| {
+                        signature_map
+                            .get(&p.blinded_message.blinded_secret)
+                            .map(|sig| (idx, p, sig.clone()))
+                    })
                     .collect();
 
                 // Update highest counter based on matched indices
-                if let Some(&(max_idx, _)) = matched_secrets.last() {
+                if let Some(&(max_idx, _, _)) = matched_secrets.last() {
                     let counter_value = start_counter + max_idx as u32;
                     highest_counter =
                         Some(highest_counter.map_or(counter_value, |c| c.max(counter_value)));
                 }
 
-                let premint_secrets: Vec<_> = matched_secrets.into_iter().map(|(_, p)| p).collect();
-
                 // the response outputs and premint secrets should be the same after filtering
                 // blinded messages the mint did not have signatures for
-                if response.outputs.len() != premint_secrets.len() {
+                if response.outputs.len() != matched_secrets.len() {
                     return Err(Error::InvalidMintResponse(format!(
                         "restore response outputs ({}) does not match premint secrets ({})",
                         response.outputs.len(),
-                        premint_secrets.len()
+                        matched_secrets.len()
                     )));
                 }
 
+                // Extract signatures, rs, and secrets in matching order
+                // Each tuple (idx, premint, signature) ensures correct pairing
                 let proofs = construct_proofs(
-                    response.signatures,
-                    premint_secrets.iter().map(|p| p.r.clone()).collect(),
-                    premint_secrets.iter().map(|p| p.secret.clone()).collect(),
+                    matched_secrets
+                        .iter()
+                        .map(|(_, _, sig)| sig.clone())
+                        .collect(),
+                    matched_secrets
+                        .iter()
+                        .map(|(_, p, _)| p.r.clone())
+                        .collect(),
+                    matched_secrets
+                        .iter()
+                        .map(|(_, p, _)| p.secret.clone())
+                        .collect(),
                     &keys,
                 )?;
 
@@ -774,3 +797,294 @@ impl Drop for Wallet {
         self.seed.zeroize();
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::nuts::{BlindSignature, BlindedMessage, PreMint, PreMintSecrets};
+    use crate::secret::Secret;
+
+    /// Test that restore signature matching works correctly when response is in order
+    #[test]
+    fn test_restore_signature_matching_in_order() {
+        // Create test data with 3 premint secrets
+        let keyset_id = Id::from_bytes(&[0u8; 8]).unwrap();
+
+        // Generate deterministic keys for testing
+        let secret1 = Secret::generate();
+        let secret2 = Secret::generate();
+        let secret3 = Secret::generate();
+
+        let (blinded1, r1) = crate::dhke::blind_message(&secret1.to_bytes(), None).unwrap();
+        let (blinded2, r2) = crate::dhke::blind_message(&secret2.to_bytes(), None).unwrap();
+        let (blinded3, r3) = crate::dhke::blind_message(&secret3.to_bytes(), None).unwrap();
+
+        let premint1 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(1), keyset_id, blinded1),
+            secret: secret1.clone(),
+            r: r1.clone(),
+            amount: Amount::from(1),
+        };
+        let premint2 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(2), keyset_id, blinded2),
+            secret: secret2.clone(),
+            r: r2.clone(),
+            amount: Amount::from(2),
+        };
+        let premint3 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(4), keyset_id, blinded3),
+            secret: secret3.clone(),
+            r: r3.clone(),
+            amount: Amount::from(4),
+        };
+
+        let premint_secrets = PreMintSecrets {
+            secrets: vec![premint1.clone(), premint2.clone(), premint3.clone()],
+            keyset_id,
+        };
+
+        // Create mock signatures (just need the structure, not real signatures)
+        let sig1 = BlindSignature {
+            amount: Amount::from(1),
+            keyset_id,
+            c: blinded1, // Using blinded as placeholder for signature
+            dleq: None,
+        };
+        let sig2 = BlindSignature {
+            amount: Amount::from(2),
+            keyset_id,
+            c: blinded2,
+            dleq: None,
+        };
+        let sig3 = BlindSignature {
+            amount: Amount::from(4),
+            keyset_id,
+            c: blinded3,
+            dleq: None,
+        };
+
+        // Response in same order as request
+        let response_outputs = vec![
+            premint1.blinded_message.clone(),
+            premint2.blinded_message.clone(),
+            premint3.blinded_message.clone(),
+        ];
+        let response_signatures = vec![sig1.clone(), sig2.clone(), sig3.clone()];
+
+        // Apply the matching logic (same as in restore)
+        let signature_map: HashMap<_, _> = response_outputs
+            .iter()
+            .zip(response_signatures.iter())
+            .map(|(output, sig)| (output.blinded_secret, sig.clone()))
+            .collect();
+
+        let matched_secrets: Vec<_> = premint_secrets
+            .secrets
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, p)| {
+                signature_map
+                    .get(&p.blinded_message.blinded_secret)
+                    .map(|sig| (idx, p, sig.clone()))
+            })
+            .collect();
+
+        // Verify all 3 matched
+        assert_eq!(matched_secrets.len(), 3);
+
+        // Verify correct pairing by checking amounts match
+        assert_eq!(matched_secrets[0].2.amount, Amount::from(1));
+        assert_eq!(matched_secrets[1].2.amount, Amount::from(2));
+        assert_eq!(matched_secrets[2].2.amount, Amount::from(4));
+
+        // Verify indices are preserved
+        assert_eq!(matched_secrets[0].0, 0);
+        assert_eq!(matched_secrets[1].0, 1);
+        assert_eq!(matched_secrets[2].0, 2);
+    }
+
+    /// Test that restore signature matching works correctly when response is OUT of order
+    /// This is the critical test that verifies the fix for TokenNotVerified
+    #[test]
+    fn test_restore_signature_matching_out_of_order() {
+        let keyset_id = Id::from_bytes(&[0u8; 8]).unwrap();
+
+        let secret1 = Secret::generate();
+        let secret2 = Secret::generate();
+        let secret3 = Secret::generate();
+
+        let (blinded1, r1) = crate::dhke::blind_message(&secret1.to_bytes(), None).unwrap();
+        let (blinded2, r2) = crate::dhke::blind_message(&secret2.to_bytes(), None).unwrap();
+        let (blinded3, r3) = crate::dhke::blind_message(&secret3.to_bytes(), None).unwrap();
+
+        let premint1 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(1), keyset_id, blinded1),
+            secret: secret1.clone(),
+            r: r1.clone(),
+            amount: Amount::from(1),
+        };
+        let premint2 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(2), keyset_id, blinded2),
+            secret: secret2.clone(),
+            r: r2.clone(),
+            amount: Amount::from(2),
+        };
+        let premint3 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(4), keyset_id, blinded3),
+            secret: secret3.clone(),
+            r: r3.clone(),
+            amount: Amount::from(4),
+        };
+
+        let premint_secrets = PreMintSecrets {
+            secrets: vec![premint1.clone(), premint2.clone(), premint3.clone()],
+            keyset_id,
+        };
+
+        let sig1 = BlindSignature {
+            amount: Amount::from(1),
+            keyset_id,
+            c: blinded1,
+            dleq: None,
+        };
+        let sig2 = BlindSignature {
+            amount: Amount::from(2),
+            keyset_id,
+            c: blinded2,
+            dleq: None,
+        };
+        let sig3 = BlindSignature {
+            amount: Amount::from(4),
+            keyset_id,
+            c: blinded3,
+            dleq: None,
+        };
+
+        // Response in REVERSED order (simulating out-of-order response from mint)
+        let response_outputs = vec![
+            premint3.blinded_message.clone(), // index 2 first
+            premint1.blinded_message.clone(), // index 0 second
+            premint2.blinded_message.clone(), // index 1 third
+        ];
+        let response_signatures = vec![sig3.clone(), sig1.clone(), sig2.clone()];
+
+        // Apply the matching logic (same as in restore)
+        let signature_map: HashMap<_, _> = response_outputs
+            .iter()
+            .zip(response_signatures.iter())
+            .map(|(output, sig)| (output.blinded_secret, sig.clone()))
+            .collect();
+
+        let matched_secrets: Vec<_> = premint_secrets
+            .secrets
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, p)| {
+                signature_map
+                    .get(&p.blinded_message.blinded_secret)
+                    .map(|sig| (idx, p, sig.clone()))
+            })
+            .collect();
+
+        // Verify all 3 matched
+        assert_eq!(matched_secrets.len(), 3);
+
+        // Critical: Even though response was out of order, signatures should be
+        // correctly paired with their corresponding premint secrets
+        // matched_secrets should be in premint order (0, 1, 2) with correct signatures
+        assert_eq!(matched_secrets[0].0, 0); // First premint (amount 1)
+        assert_eq!(matched_secrets[0].2.amount, Amount::from(1)); // Correct signature
+
+        assert_eq!(matched_secrets[1].0, 1); // Second premint (amount 2)
+        assert_eq!(matched_secrets[1].2.amount, Amount::from(2)); // Correct signature
+
+        assert_eq!(matched_secrets[2].0, 2); // Third premint (amount 4)
+        assert_eq!(matched_secrets[2].2.amount, Amount::from(4)); // Correct signature
+    }
+
+    /// Test that restore handles partial responses correctly
+    #[test]
+    fn test_restore_signature_matching_partial_response() {
+        let keyset_id = Id::from_bytes(&[0u8; 8]).unwrap();
+
+        let secret1 = Secret::generate();
+        let secret2 = Secret::generate();
+        let secret3 = Secret::generate();
+
+        let (blinded1, r1) = crate::dhke::blind_message(&secret1.to_bytes(), None).unwrap();
+        let (blinded2, r2) = crate::dhke::blind_message(&secret2.to_bytes(), None).unwrap();
+        let (blinded3, r3) = crate::dhke::blind_message(&secret3.to_bytes(), None).unwrap();
+
+        let premint1 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(1), keyset_id, blinded1),
+            secret: secret1.clone(),
+            r: r1.clone(),
+            amount: Amount::from(1),
+        };
+        let premint2 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(2), keyset_id, blinded2),
+            secret: secret2.clone(),
+            r: r2.clone(),
+            amount: Amount::from(2),
+        };
+        let premint3 = PreMint {
+            blinded_message: BlindedMessage::new(Amount::from(4), keyset_id, blinded3),
+            secret: secret3.clone(),
+            r: r3.clone(),
+            amount: Amount::from(4),
+        };
+
+        let premint_secrets = PreMintSecrets {
+            secrets: vec![premint1.clone(), premint2.clone(), premint3.clone()],
+            keyset_id,
+        };
+
+        let sig1 = BlindSignature {
+            amount: Amount::from(1),
+            keyset_id,
+            c: blinded1,
+            dleq: None,
+        };
+        let sig3 = BlindSignature {
+            amount: Amount::from(4),
+            keyset_id,
+            c: blinded3,
+            dleq: None,
+        };
+
+        // Response only has signatures for premint1 and premint3 (gap at premint2)
+        // Also out of order
+        let response_outputs = vec![
+            premint3.blinded_message.clone(),
+            premint1.blinded_message.clone(),
+        ];
+        let response_signatures = vec![sig3.clone(), sig1.clone()];
+
+        let signature_map: HashMap<_, _> = response_outputs
+            .iter()
+            .zip(response_signatures.iter())
+            .map(|(output, sig)| (output.blinded_secret, sig.clone()))
+            .collect();
+
+        let matched_secrets: Vec<_> = premint_secrets
+            .secrets
+            .iter()
+            .enumerate()
+            .filter_map(|(idx, p)| {
+                signature_map
+                    .get(&p.blinded_message.blinded_secret)
+                    .map(|sig| (idx, p, sig.clone()))
+            })
+            .collect();
+
+        // Only 2 should match
+        assert_eq!(matched_secrets.len(), 2);
+
+        // Verify correct pairing despite gap and out-of-order response
+        assert_eq!(matched_secrets[0].0, 0); // First premint (amount 1)
+        assert_eq!(matched_secrets[0].2.amount, Amount::from(1));
+
+        assert_eq!(matched_secrets[1].0, 2); // Third premint (amount 4), index 1 skipped
+        assert_eq!(matched_secrets[1].2.amount, Amount::from(4));
+    }
+}