amount.rs 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  //! CDK Amount
  //!
  //! An amount may be denominated in any unit; it is treated as being in the unit of the wallet.
  4. use std::cmp::Ordering;
  5. use std::fmt;
  6. use std::str::FromStr;
  7. use serde::{Deserialize, Serialize};
  8. use thiserror::Error;
  9. use crate::nuts::CurrencyUnit;
  10. /// Amount Error
  11. #[derive(Debug, Error)]
  12. pub enum Error {
  13. /// Split Values must be less then or equal to amount
  14. #[error("Split Values must be less then or equal to amount")]
  15. SplitValuesGreater,
  16. /// Amount overflow
  17. #[error("Amount Overflow")]
  18. AmountOverflow,
  19. /// Cannot convert units
  20. #[error("Cannot convert units")]
  21. CannotConvertUnits,
  22. }
  23. /// Amount can be any unit
  24. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
  25. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  26. #[serde(transparent)]
  27. pub struct Amount(u64);
  28. impl FromStr for Amount {
  29. type Err = Error;
  30. fn from_str(s: &str) -> Result<Self, Self::Err> {
  31. let value = s.parse::<u64>().map_err(|_| Error::AmountOverflow)?;
  32. Ok(Amount(value))
  33. }
  34. }
  35. impl Amount {
  36. /// Amount zero
  37. pub const ZERO: Amount = Amount(0);
  38. /// Split into parts that are powers of two
  39. pub fn split(&self) -> Vec<Self> {
  40. let sats = self.0;
  41. (0_u64..64)
  42. .rev()
  43. .filter_map(|bit| {
  44. let part = 1 << bit;
  45. ((sats & part) == part).then_some(Self::from(part))
  46. })
  47. .collect()
  48. }
  49. /// Split into parts that are powers of two by target
  50. pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
  51. let mut parts = match target {
  52. SplitTarget::None => self.split(),
  53. SplitTarget::Value(amount) => {
  54. if self.le(amount) {
  55. return Ok(self.split());
  56. }
  57. let mut parts_total = Amount::ZERO;
  58. let mut parts = Vec::new();
  59. // The powers of two that are need to create target value
  60. let parts_of_value = amount.split();
  61. while parts_total.lt(self) {
  62. for part in parts_of_value.iter().copied() {
  63. if (part + parts_total).le(self) {
  64. parts.push(part);
  65. } else {
  66. let amount_left = *self - parts_total;
  67. parts.extend(amount_left.split());
  68. }
  69. parts_total = Amount::try_sum(parts.clone().iter().copied())?;
  70. if parts_total.eq(self) {
  71. break;
  72. }
  73. }
  74. }
  75. parts
  76. }
  77. SplitTarget::Values(values) => {
  78. let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
  79. match self.cmp(&values_total) {
  80. Ordering::Equal => values.clone(),
  81. Ordering::Less => {
  82. return Err(Error::SplitValuesGreater);
  83. }
  84. Ordering::Greater => {
  85. let extra = *self - values_total;
  86. let mut extra_amount = extra.split();
  87. let mut values = values.clone();
  88. values.append(&mut extra_amount);
  89. values
  90. }
  91. }
  92. }
  93. };
  94. parts.sort();
  95. Ok(parts)
  96. }
  97. /// Checked addition for Amount. Returns None if overflow occurs.
  98. pub fn checked_add(self, other: Amount) -> Option<Amount> {
  99. self.0.checked_add(other.0).map(Amount)
  100. }
  101. /// Checked subtraction for Amount. Returns None if overflow occurs.
  102. pub fn checked_sub(self, other: Amount) -> Option<Amount> {
  103. self.0.checked_sub(other.0).map(Amount)
  104. }
  105. /// Try sum to check for overflow
  106. pub fn try_sum<I>(iter: I) -> Result<Self, Error>
  107. where
  108. I: IntoIterator<Item = Self>,
  109. {
  110. iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
  111. acc.checked_add(x).ok_or(Error::AmountOverflow)
  112. })
  113. }
  114. }
  115. impl Default for Amount {
  116. fn default() -> Self {
  117. Amount::ZERO
  118. }
  119. }
  120. impl Default for &Amount {
  121. fn default() -> Self {
  122. &Amount::ZERO
  123. }
  124. }
  125. impl fmt::Display for Amount {
  126. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  127. if let Some(width) = f.width() {
  128. write!(f, "{:width$}", self.0, width = width)
  129. } else {
  130. write!(f, "{}", self.0)
  131. }
  132. }
  133. }
  134. impl From<u64> for Amount {
  135. fn from(value: u64) -> Self {
  136. Self(value)
  137. }
  138. }
  139. impl From<&u64> for Amount {
  140. fn from(value: &u64) -> Self {
  141. Self(*value)
  142. }
  143. }
  144. impl From<Amount> for u64 {
  145. fn from(value: Amount) -> Self {
  146. value.0
  147. }
  148. }
  149. impl AsRef<u64> for Amount {
  150. fn as_ref(&self) -> &u64 {
  151. &self.0
  152. }
  153. }
  154. impl std::ops::Add for Amount {
  155. type Output = Amount;
  156. fn add(self, rhs: Amount) -> Self::Output {
  157. Amount(self.0.checked_add(rhs.0).expect("Addition error"))
  158. }
  159. }
  160. impl std::ops::AddAssign for Amount {
  161. fn add_assign(&mut self, rhs: Self) {
  162. self.0 = self.0.checked_add(rhs.0).expect("Addition error");
  163. }
  164. }
  165. impl std::ops::Sub for Amount {
  166. type Output = Amount;
  167. fn sub(self, rhs: Amount) -> Self::Output {
  168. Amount(self.0 - rhs.0)
  169. }
  170. }
  171. impl std::ops::SubAssign for Amount {
  172. fn sub_assign(&mut self, other: Self) {
  173. self.0 -= other.0;
  174. }
  175. }
  176. impl std::ops::Mul for Amount {
  177. type Output = Self;
  178. fn mul(self, other: Self) -> Self::Output {
  179. Amount(self.0 * other.0)
  180. }
  181. }
  182. impl std::ops::Div for Amount {
  183. type Output = Self;
  184. fn div(self, other: Self) -> Self::Output {
  185. Amount(self.0 / other.0)
  186. }
  187. }
  188. /// Kinds of targeting that are supported
  189. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
  190. pub enum SplitTarget {
  191. /// Default target; least amount of proofs
  192. #[default]
  193. None,
  194. /// Target amount for wallet to have most proofs that add up to value
  195. Value(Amount),
  196. /// Specific amounts to split into **MUST** equal amount being split
  197. Values(Vec<Amount>),
  198. }
  /// Number of millisatoshis (msat) in one satoshi (sat).
  pub const MSAT_IN_SAT: u64 = 1_000;
  201. /// Helper function to convert units
  202. pub fn to_unit<T>(
  203. amount: T,
  204. current_unit: &CurrencyUnit,
  205. target_unit: &CurrencyUnit,
  206. ) -> Result<Amount, Error>
  207. where
  208. T: Into<u64>,
  209. {
  210. let amount = amount.into();
  211. match (current_unit, target_unit) {
  212. (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
  213. (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
  214. (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
  215. (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
  216. (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
  217. (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
  218. _ => Err(Error::CannotConvertUnits),
  219. }
  220. }
  #[cfg(test)]
  mod tests {
      use super::*;

      /// Build a `Vec<Amount>` from raw values, to keep expectations compact.
      fn amounts(raw: &[u64]) -> Vec<Amount> {
          raw.iter().copied().map(Amount::from).collect()
      }

      #[test]
      fn test_split_amount() {
          assert_eq!(Amount::from(1).split(), amounts(&[1]));
          assert_eq!(Amount::from(2).split(), amounts(&[2]));
          assert_eq!(Amount::from(3).split(), amounts(&[2, 1]));
          assert_eq!(Amount::from(11).split(), amounts(&[8, 2, 1]));
          assert_eq!(
              Amount::from(255).split(),
              amounts(&[128, 64, 32, 16, 8, 4, 2, 1])
          );
      }

      #[test]
      fn test_split_target_amount() {
          let split = Amount(65)
              .split_targeted(&SplitTarget::Value(Amount(32)))
              .unwrap();
          assert_eq!(amounts(&[1, 32, 32]), split);

          let split = Amount(150)
              .split_targeted(&SplitTarget::Value(Amount::from(50)))
              .unwrap();
          assert_eq!(amounts(&[2, 2, 2, 16, 16, 16, 32, 32, 32]), split);

          let split = Amount::from(63)
              .split_targeted(&SplitTarget::Value(Amount::from(32)))
              .unwrap();
          assert_eq!(amounts(&[1, 2, 4, 8, 16, 32]), split);
      }

      #[test]
      fn test_split_values() {
          let amount = Amount(10);

          // Values that already sum to the amount come back as given.
          let exact = amounts(&[2, 4, 4]);
          let values = amount
              .split_targeted(&SplitTarget::Values(exact.clone()))
              .unwrap();
          assert_eq!(exact, values);

          // A shortfall is topped up with the minimal split of the difference.
          let values = amount
              .split_targeted(&SplitTarget::Values(amounts(&[2, 4])))
              .unwrap();
          assert_eq!(amounts(&[2, 4, 4]), values);

          // Values summing to more than the amount are rejected.
          let values = amount.split_targeted(&SplitTarget::Values(amounts(&[2, 10])));
          assert!(values.is_err())
      }

      #[test]
      #[should_panic]
      fn test_amount_addition() {
          let overflowing = vec![Amount::from(u64::MAX), Amount::from(1)];
          let _total: Amount = Amount::try_sum(overflowing).unwrap();
      }

      #[test]
      fn test_try_amount_addition() {
          let overflowing = vec![Amount::from(u64::MAX), Amount::from(1)];
          assert!(Amount::try_sum(overflowing).is_err());

          let total = Amount::try_sum(vec![Amount::from(10000), Amount::from(1)]).unwrap();
          assert_eq!(total, Amount::from(10001));
      }

      #[test]
      fn test_amount_to_unit() {
          let cases: [(u64, CurrencyUnit, CurrencyUnit, u64); 4] = [
              (1000, CurrencyUnit::Sat, CurrencyUnit::Msat, 1_000_000),
              (1000, CurrencyUnit::Msat, CurrencyUnit::Sat, 1),
              (1, CurrencyUnit::Usd, CurrencyUnit::Usd, 1),
              (1, CurrencyUnit::Eur, CurrencyUnit::Eur, 1),
          ];
          for (value, from, to, expected) in cases {
              let converted = to_unit(Amount::from(value), &from, &to).unwrap();
              assert_eq!(converted, Amount::from(expected));
          }

          let converted = to_unit(Amount::from(1), &CurrencyUnit::Sat, &CurrencyUnit::Eur);
          assert!(converted.is_err());
      }
  }