Răsfoiți Sursa

fix: check for addition overflow

thesimplekid 7 luni în urmă
părinte
comite
eb0ba7da43

+ 2 - 1
crates/cdk-integration-tests/tests/mint.rs

@@ -46,7 +46,8 @@ pub async fn test_mint_double_receive() -> Result<()> {
             &SendKind::default(),
             false,
         )
-        .await?;
+        .await
+        .unwrap();
 
     let mnemonic = Mnemonic::generate(12)?;
 

+ 51 - 4
crates/cdk/src/amount.rs

@@ -87,6 +87,21 @@ impl Amount {
         parts.sort();
         Ok(parts)
     }
+
+    /// Checked addition for Amount. Returns None if overflow occurs.
+    pub fn checked_add(self, other: Amount) -> Option<Amount> {
+        self.0.checked_add(other.0).map(Amount)
+    }
+
+    /// Try sum to check for overflow
+    pub fn try_sum<I>(iter: I) -> Result<Self, Error>
+    where
+        I: IntoIterator<Item = Self>,
+    {
+        iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
+            acc.checked_add(x).ok_or(Error::AmountOverflow)
+        })
+    }
 }
 
 impl Default for Amount {
@@ -135,13 +150,13 @@ impl std::ops::Add for Amount {
     type Output = Amount;
 
     fn add(self, rhs: Amount) -> Self::Output {
-        Amount(self.0 + rhs.0)
+        Amount(self.0.checked_add(rhs.0).expect("Addition error"))
     }
 }
 
 impl std::ops::AddAssign for Amount {
     fn add_assign(&mut self, rhs: Self) {
-        self.0 += rhs.0;
+        self.0 = self.0.checked_add(rhs.0).expect("Addition error");
     }
 }
 
@@ -177,8 +192,10 @@ impl std::ops::Div for Amount {
 
 impl core::iter::Sum for Amount {
     fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
-        let sats: u64 = iter.map(|amt| amt.0).sum();
-        Amount::from(sats)
+        iter.fold(Amount::ZERO, |acc, x| {
+            acc.checked_add(x)
+                .unwrap_or_else(|| panic!("Addition overflow"))
+        })
     }
 }
 
@@ -288,4 +305,34 @@ mod tests {
 
         assert!(values.is_err())
     }
+
+    #[test]
+    #[should_panic]
+    fn test_amount_addition() {
+        let amount_one: Amount = u64::MAX.into();
+        let amount_two: Amount = 1.into();
+
+        let amounts = vec![amount_one, amount_two];
+
+        let _total: Amount = amounts.into_iter().sum();
+    }
+
+    #[test]
+    fn test_try_amount_addition() {
+        let amount_one: Amount = u64::MAX.into();
+        let amount_two: Amount = 1.into();
+
+        let amounts = vec![amount_one, amount_two];
+
+        let total = Amount::try_sum(amounts);
+
+        assert!(total.is_err());
+        let amount_one: Amount = 10000.into();
+        let amount_two: Amount = 1.into();
+
+        let amounts = vec![amount_one, amount_two];
+        let total = Amount::try_sum(amounts).unwrap();
+
+        assert_eq!(total, 10001.into());
+    }
 }

+ 3 - 0
crates/cdk/src/error.rs

@@ -57,6 +57,9 @@ pub enum Error {
     /// Split Values must be less then or equal to amount
     #[error("Split Values must be less then or equal to amount")]
     SplitValuesGreater,
+    /// Amount overflow
+    #[error("Amount Overflow")]
+    AmountOverflow,
     /// Secp256k1 error
     #[error(transparent)]
     Secp256k1(#[from] bitcoin::secp256k1::Error),