amount.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. //! CDK Amount
  2. //!
  3. //! Is any unit and will be treated as the unit of the wallet
  4. use std::cmp::Ordering;
  5. use std::fmt;
  6. use std::str::FromStr;
  7. use serde::{Deserialize, Deserializer, Serialize, Serializer};
  8. use thiserror::Error;
  9. use crate::nuts::CurrencyUnit;
  10. /// Amount Error
  11. #[derive(Debug, Error)]
  12. pub enum Error {
  13. /// Split Values must be less then or equal to amount
  14. #[error("Split Values must be less then or equal to amount")]
  15. SplitValuesGreater,
  16. /// Amount overflow
  17. #[error("Amount Overflow")]
  18. AmountOverflow,
  19. /// Cannot convert units
  20. #[error("Cannot convert units")]
  21. CannotConvertUnits,
  22. }
  23. /// Amount can be any unit
  24. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
  25. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  26. #[serde(transparent)]
  27. pub struct Amount(u64);
  28. impl Amount {
  29. /// Amount zero
  30. pub const ZERO: Amount = Amount(0);
  31. /// Split into parts that are powers of two
  32. pub fn split(&self) -> Vec<Self> {
  33. let sats = self.0;
  34. (0_u64..64)
  35. .rev()
  36. .filter_map(|bit| {
  37. let part = 1 << bit;
  38. ((sats & part) == part).then_some(Self::from(part))
  39. })
  40. .collect()
  41. }
  42. /// Split into parts that are powers of two by target
  43. pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
  44. let mut parts = match target {
  45. SplitTarget::None => self.split(),
  46. SplitTarget::Value(amount) => {
  47. if self.le(amount) {
  48. return Ok(self.split());
  49. }
  50. let mut parts_total = Amount::ZERO;
  51. let mut parts = Vec::new();
  52. // The powers of two that are need to create target value
  53. let parts_of_value = amount.split();
  54. while parts_total.lt(self) {
  55. for part in parts_of_value.iter().copied() {
  56. if (part + parts_total).le(self) {
  57. parts.push(part);
  58. } else {
  59. let amount_left = *self - parts_total;
  60. parts.extend(amount_left.split());
  61. }
  62. parts_total = Amount::try_sum(parts.clone().iter().copied())?;
  63. if parts_total.eq(self) {
  64. break;
  65. }
  66. }
  67. }
  68. parts
  69. }
  70. SplitTarget::Values(values) => {
  71. let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
  72. match self.cmp(&values_total) {
  73. Ordering::Equal => values.clone(),
  74. Ordering::Less => {
  75. return Err(Error::SplitValuesGreater);
  76. }
  77. Ordering::Greater => {
  78. let extra = *self - values_total;
  79. let mut extra_amount = extra.split();
  80. let mut values = values.clone();
  81. values.append(&mut extra_amount);
  82. values
  83. }
  84. }
  85. }
  86. };
  87. parts.sort();
  88. Ok(parts)
  89. }
  90. /// Checked addition for Amount. Returns None if overflow occurs.
  91. pub fn checked_add(self, other: Amount) -> Option<Amount> {
  92. self.0.checked_add(other.0).map(Amount)
  93. }
  94. /// Checked subtraction for Amount. Returns None if overflow occurs.
  95. pub fn checked_sub(self, other: Amount) -> Option<Amount> {
  96. self.0.checked_sub(other.0).map(Amount)
  97. }
  98. /// Try sum to check for overflow
  99. pub fn try_sum<I>(iter: I) -> Result<Self, Error>
  100. where
  101. I: IntoIterator<Item = Self>,
  102. {
  103. iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
  104. acc.checked_add(x).ok_or(Error::AmountOverflow)
  105. })
  106. }
  107. }
  108. impl Default for Amount {
  109. fn default() -> Self {
  110. Amount::ZERO
  111. }
  112. }
  113. impl Default for &Amount {
  114. fn default() -> Self {
  115. &Amount::ZERO
  116. }
  117. }
  118. impl fmt::Display for Amount {
  119. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  120. if let Some(width) = f.width() {
  121. write!(f, "{:width$}", self.0, width = width)
  122. } else {
  123. write!(f, "{}", self.0)
  124. }
  125. }
  126. }
  127. impl From<u64> for Amount {
  128. fn from(value: u64) -> Self {
  129. Self(value)
  130. }
  131. }
  132. impl From<&u64> for Amount {
  133. fn from(value: &u64) -> Self {
  134. Self(*value)
  135. }
  136. }
  137. impl From<Amount> for u64 {
  138. fn from(value: Amount) -> Self {
  139. value.0
  140. }
  141. }
  142. impl AsRef<u64> for Amount {
  143. fn as_ref(&self) -> &u64 {
  144. &self.0
  145. }
  146. }
  147. impl std::ops::Add for Amount {
  148. type Output = Amount;
  149. fn add(self, rhs: Amount) -> Self::Output {
  150. Amount(self.0.checked_add(rhs.0).expect("Addition error"))
  151. }
  152. }
  153. impl std::ops::AddAssign for Amount {
  154. fn add_assign(&mut self, rhs: Self) {
  155. self.0 = self.0.checked_add(rhs.0).expect("Addition error");
  156. }
  157. }
  158. impl std::ops::Sub for Amount {
  159. type Output = Amount;
  160. fn sub(self, rhs: Amount) -> Self::Output {
  161. Amount(self.0 - rhs.0)
  162. }
  163. }
  164. impl std::ops::SubAssign for Amount {
  165. fn sub_assign(&mut self, other: Self) {
  166. self.0 -= other.0;
  167. }
  168. }
  169. impl std::ops::Mul for Amount {
  170. type Output = Self;
  171. fn mul(self, other: Self) -> Self::Output {
  172. Amount(self.0 * other.0)
  173. }
  174. }
  175. impl std::ops::Div for Amount {
  176. type Output = Self;
  177. fn div(self, other: Self) -> Self::Output {
  178. Amount(self.0 / other.0)
  179. }
  180. }
  181. /// String wrapper for an [Amount].
  182. ///
  183. /// It ser-/deserializes the inner [Amount] to a string, while at the same time using the [u64]
  184. /// value of the [Amount] for comparison and ordering. This helps automatically sort the keys of
  185. /// a [BTreeMap] when [AmountStr] is used as key.
  186. #[derive(Debug, Clone, PartialEq, Eq)]
  187. pub struct AmountStr(Amount);
  188. impl AmountStr {
  189. pub(crate) fn from(amt: Amount) -> Self {
  190. Self(amt)
  191. }
  192. }
  193. impl PartialOrd<Self> for AmountStr {
  194. fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
  195. Some(self.cmp(other))
  196. }
  197. }
  198. impl Ord for AmountStr {
  199. fn cmp(&self, other: &Self) -> Ordering {
  200. self.0.cmp(&other.0)
  201. }
  202. }
  203. impl<'de> Deserialize<'de> for AmountStr {
  204. fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  205. where
  206. D: Deserializer<'de>,
  207. {
  208. let s = String::deserialize(deserializer)?;
  209. u64::from_str(&s)
  210. .map(Amount)
  211. .map(Self)
  212. .map_err(serde::de::Error::custom)
  213. }
  214. }
  215. impl Serialize for AmountStr {
  216. fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
  217. where
  218. S: Serializer,
  219. {
  220. serializer.serialize_str(&self.0.to_string())
  221. }
  222. }
  223. /// Kinds of targeting that are supported
  224. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
  225. pub enum SplitTarget {
  226. /// Default target; least amount of proofs
  227. #[default]
  228. None,
  229. /// Target amount for wallet to have most proofs that add up to value
  230. Value(Amount),
  231. /// Specific amounts to split into **MUST** equal amount being split
  232. Values(Vec<Amount>),
  233. }
  234. /// Msats in sat
  235. pub const MSAT_IN_SAT: u64 = 1000;
  236. /// Helper function to convert units
  237. pub fn to_unit<T>(
  238. amount: T,
  239. current_unit: &CurrencyUnit,
  240. target_unit: &CurrencyUnit,
  241. ) -> Result<Amount, Error>
  242. where
  243. T: Into<u64>,
  244. {
  245. let amount = amount.into();
  246. match (current_unit, target_unit) {
  247. (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
  248. (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
  249. (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
  250. (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
  251. (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
  252. (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
  253. _ => Err(Error::CannotConvertUnits),
  254. }
  255. }
  256. #[cfg(test)]
  257. mod tests {
  258. use super::*;
  259. #[test]
  260. fn test_split_amount() {
  261. assert_eq!(Amount::from(1).split(), vec![Amount::from(1)]);
  262. assert_eq!(Amount::from(2).split(), vec![Amount::from(2)]);
  263. assert_eq!(
  264. Amount::from(3).split(),
  265. vec![Amount::from(2), Amount::from(1)]
  266. );
  267. let amounts: Vec<Amount> = [8, 2, 1].iter().map(|a| Amount::from(*a)).collect();
  268. assert_eq!(Amount::from(11).split(), amounts);
  269. let amounts: Vec<Amount> = [128, 64, 32, 16, 8, 4, 2, 1]
  270. .iter()
  271. .map(|a| Amount::from(*a))
  272. .collect();
  273. assert_eq!(Amount::from(255).split(), amounts);
  274. }
  275. #[test]
  276. fn test_split_target_amount() {
  277. let amount = Amount(65);
  278. let split = amount
  279. .split_targeted(&SplitTarget::Value(Amount(32)))
  280. .unwrap();
  281. assert_eq!(vec![Amount(1), Amount(32), Amount(32)], split);
  282. let amount = Amount(150);
  283. let split = amount
  284. .split_targeted(&SplitTarget::Value(Amount::from(50)))
  285. .unwrap();
  286. assert_eq!(
  287. vec![
  288. Amount(2),
  289. Amount(2),
  290. Amount(2),
  291. Amount(16),
  292. Amount(16),
  293. Amount(16),
  294. Amount(32),
  295. Amount(32),
  296. Amount(32)
  297. ],
  298. split
  299. );
  300. let amount = Amount::from(63);
  301. let split = amount
  302. .split_targeted(&SplitTarget::Value(Amount::from(32)))
  303. .unwrap();
  304. assert_eq!(
  305. vec![
  306. Amount(1),
  307. Amount(2),
  308. Amount(4),
  309. Amount(8),
  310. Amount(16),
  311. Amount(32)
  312. ],
  313. split
  314. );
  315. }
  316. #[test]
  317. fn test_split_values() {
  318. let amount = Amount(10);
  319. let target = vec![Amount(2), Amount(4), Amount(4)];
  320. let split_target = SplitTarget::Values(target.clone());
  321. let values = amount.split_targeted(&split_target).unwrap();
  322. assert_eq!(target, values);
  323. let target = vec![Amount(2), Amount(4), Amount(4)];
  324. let split_target = SplitTarget::Values(vec![Amount(2), Amount(4)]);
  325. let values = amount.split_targeted(&split_target).unwrap();
  326. assert_eq!(target, values);
  327. let split_target = SplitTarget::Values(vec![Amount(2), Amount(10)]);
  328. let values = amount.split_targeted(&split_target);
  329. assert!(values.is_err())
  330. }
  331. #[test]
  332. #[should_panic]
  333. fn test_amount_addition() {
  334. let amount_one: Amount = u64::MAX.into();
  335. let amount_two: Amount = 1.into();
  336. let amounts = vec![amount_one, amount_two];
  337. let _total: Amount = Amount::try_sum(amounts).unwrap();
  338. }
  339. #[test]
  340. fn test_try_amount_addition() {
  341. let amount_one: Amount = u64::MAX.into();
  342. let amount_two: Amount = 1.into();
  343. let amounts = vec![amount_one, amount_two];
  344. let total = Amount::try_sum(amounts);
  345. assert!(total.is_err());
  346. let amount_one: Amount = 10000.into();
  347. let amount_two: Amount = 1.into();
  348. let amounts = vec![amount_one, amount_two];
  349. let total = Amount::try_sum(amounts).unwrap();
  350. assert_eq!(total, 10001.into());
  351. }
  352. #[test]
  353. fn test_amount_to_unit() {
  354. let amount = Amount::from(1000);
  355. let current_unit = CurrencyUnit::Sat;
  356. let target_unit = CurrencyUnit::Msat;
  357. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  358. assert_eq!(converted, 1000000.into());
  359. let amount = Amount::from(1000);
  360. let current_unit = CurrencyUnit::Msat;
  361. let target_unit = CurrencyUnit::Sat;
  362. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  363. assert_eq!(converted, 1.into());
  364. let amount = Amount::from(1);
  365. let current_unit = CurrencyUnit::Usd;
  366. let target_unit = CurrencyUnit::Usd;
  367. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  368. assert_eq!(converted, 1.into());
  369. let amount = Amount::from(1);
  370. let current_unit = CurrencyUnit::Eur;
  371. let target_unit = CurrencyUnit::Eur;
  372. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  373. assert_eq!(converted, 1.into());
  374. let amount = Amount::from(1);
  375. let current_unit = CurrencyUnit::Sat;
  376. let target_unit = CurrencyUnit::Eur;
  377. let converted = to_unit(amount, &current_unit, &target_unit);
  378. assert!(converted.is_err());
  379. }
  380. }