amount.rs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
//! CDK Amount
//!
//! An amount can be of any unit; it is treated as being in the unit of the wallet using it.
  4. use std::cmp::Ordering;
  5. use std::fmt;
  6. use std::str::FromStr;
  7. use serde::{Deserialize, Serialize};
  8. use thiserror::Error;
  9. use crate::nuts::CurrencyUnit;
  10. /// Amount Error
  11. #[derive(Debug, Error)]
  12. pub enum Error {
  13. /// Split Values must be less then or equal to amount
  14. #[error("Split Values must be less then or equal to amount")]
  15. SplitValuesGreater,
  16. /// Amount overflow
  17. #[error("Amount Overflow")]
  18. AmountOverflow,
  19. /// Cannot convert units
  20. #[error("Cannot convert units")]
  21. CannotConvertUnits,
  22. /// Invalid amount
  23. #[error("Invalid Amount: {0}")]
  24. InvalidAmount(String),
  25. }
  26. /// Amount can be any unit
  27. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
  28. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  29. #[serde(transparent)]
  30. pub struct Amount(u64);
  31. impl FromStr for Amount {
  32. type Err = Error;
  33. fn from_str(s: &str) -> Result<Self, Self::Err> {
  34. let value = s
  35. .parse::<u64>()
  36. .map_err(|_| Error::InvalidAmount(s.to_owned()))?;
  37. Ok(Amount(value))
  38. }
  39. }
  40. impl Amount {
  41. /// Amount zero
  42. pub const ZERO: Amount = Amount(0);
  43. // Amount one
  44. pub const ONE: Amount = Amount(1);
  45. /// Split into parts that are powers of two
  46. pub fn split(&self) -> Vec<Self> {
  47. let sats = self.0;
  48. (0_u64..64)
  49. .rev()
  50. .filter_map(|bit| {
  51. let part = 1 << bit;
  52. ((sats & part) == part).then_some(Self::from(part))
  53. })
  54. .collect()
  55. }
  56. /// Split into parts that are powers of two by target
  57. pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
  58. let mut parts = match target {
  59. SplitTarget::None => self.split(),
  60. SplitTarget::Value(amount) => {
  61. if self.le(amount) {
  62. return Ok(self.split());
  63. }
  64. let mut parts_total = Amount::ZERO;
  65. let mut parts = Vec::new();
  66. // The powers of two that are need to create target value
  67. let parts_of_value = amount.split();
  68. while parts_total.lt(self) {
  69. for part in parts_of_value.iter().copied() {
  70. if (part + parts_total).le(self) {
  71. parts.push(part);
  72. } else {
  73. let amount_left = *self - parts_total;
  74. parts.extend(amount_left.split());
  75. }
  76. parts_total = Amount::try_sum(parts.clone().iter().copied())?;
  77. if parts_total.eq(self) {
  78. break;
  79. }
  80. }
  81. }
  82. parts
  83. }
  84. SplitTarget::Values(values) => {
  85. let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
  86. match self.cmp(&values_total) {
  87. Ordering::Equal => values.clone(),
  88. Ordering::Less => {
  89. return Err(Error::SplitValuesGreater);
  90. }
  91. Ordering::Greater => {
  92. let extra = *self - values_total;
  93. let mut extra_amount = extra.split();
  94. let mut values = values.clone();
  95. values.append(&mut extra_amount);
  96. values
  97. }
  98. }
  99. }
  100. };
  101. parts.sort();
  102. Ok(parts)
  103. }
  104. /// Splits amount into powers of two while accounting for the swap fee
  105. pub fn split_with_fee(&self, fee_ppk: u64) -> Result<Vec<Self>, Error> {
  106. let without_fee_amounts = self.split();
  107. let fee_ppk = fee_ppk * without_fee_amounts.len() as u64;
  108. let fee = Amount::from((fee_ppk + 999) / 1000);
  109. let new_amount = self.checked_add(fee).ok_or(Error::AmountOverflow)?;
  110. let split = new_amount.split();
  111. let split_fee_ppk = split.len() as u64 * fee_ppk;
  112. let split_fee = Amount::from((split_fee_ppk + 999) / 1000);
  113. if let Some(net_amount) = new_amount.checked_sub(split_fee) {
  114. if net_amount >= *self {
  115. return Ok(split);
  116. }
  117. }
  118. self.checked_add(Amount::ONE)
  119. .ok_or(Error::AmountOverflow)?
  120. .split_with_fee(fee_ppk)
  121. }
  122. /// Checked addition for Amount. Returns None if overflow occurs.
  123. pub fn checked_add(self, other: Amount) -> Option<Amount> {
  124. self.0.checked_add(other.0).map(Amount)
  125. }
  126. /// Checked subtraction for Amount. Returns None if overflow occurs.
  127. pub fn checked_sub(self, other: Amount) -> Option<Amount> {
  128. self.0.checked_sub(other.0).map(Amount)
  129. }
  130. /// Checked multiplication for Amount. Returns None if overflow occurs.
  131. pub fn checked_mul(self, other: Amount) -> Option<Amount> {
  132. self.0.checked_mul(other.0).map(Amount)
  133. }
  134. /// Checked division for Amount. Returns None if overflow occurs.
  135. pub fn checked_div(self, other: Amount) -> Option<Amount> {
  136. self.0.checked_div(other.0).map(Amount)
  137. }
  138. /// Try sum to check for overflow
  139. pub fn try_sum<I>(iter: I) -> Result<Self, Error>
  140. where
  141. I: IntoIterator<Item = Self>,
  142. {
  143. iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
  144. acc.checked_add(x).ok_or(Error::AmountOverflow)
  145. })
  146. }
  147. }
  148. impl Default for Amount {
  149. fn default() -> Self {
  150. Amount::ZERO
  151. }
  152. }
  153. impl Default for &Amount {
  154. fn default() -> Self {
  155. &Amount::ZERO
  156. }
  157. }
  158. impl fmt::Display for Amount {
  159. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  160. if let Some(width) = f.width() {
  161. write!(f, "{:width$}", self.0, width = width)
  162. } else {
  163. write!(f, "{}", self.0)
  164. }
  165. }
  166. }
  167. impl From<u64> for Amount {
  168. fn from(value: u64) -> Self {
  169. Self(value)
  170. }
  171. }
  172. impl From<&u64> for Amount {
  173. fn from(value: &u64) -> Self {
  174. Self(*value)
  175. }
  176. }
  177. impl From<Amount> for u64 {
  178. fn from(value: Amount) -> Self {
  179. value.0
  180. }
  181. }
  182. impl AsRef<u64> for Amount {
  183. fn as_ref(&self) -> &u64 {
  184. &self.0
  185. }
  186. }
  187. impl std::ops::Add for Amount {
  188. type Output = Amount;
  189. fn add(self, rhs: Amount) -> Self::Output {
  190. Amount(self.0.checked_add(rhs.0).expect("Addition error"))
  191. }
  192. }
  193. impl std::ops::AddAssign for Amount {
  194. fn add_assign(&mut self, rhs: Self) {
  195. self.0 = self.0.checked_add(rhs.0).expect("Addition error");
  196. }
  197. }
  198. impl std::ops::Sub for Amount {
  199. type Output = Amount;
  200. fn sub(self, rhs: Amount) -> Self::Output {
  201. Amount(self.0 - rhs.0)
  202. }
  203. }
  204. impl std::ops::SubAssign for Amount {
  205. fn sub_assign(&mut self, other: Self) {
  206. self.0 -= other.0;
  207. }
  208. }
  209. impl std::ops::Mul for Amount {
  210. type Output = Self;
  211. fn mul(self, other: Self) -> Self::Output {
  212. Amount(self.0 * other.0)
  213. }
  214. }
  215. impl std::ops::Div for Amount {
  216. type Output = Self;
  217. fn div(self, other: Self) -> Self::Output {
  218. Amount(self.0 / other.0)
  219. }
  220. }
  221. /// Kinds of targeting that are supported
  222. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
  223. pub enum SplitTarget {
  224. /// Default target; least amount of proofs
  225. #[default]
  226. None,
  227. /// Target amount for wallet to have most proofs that add up to value
  228. Value(Amount),
  229. /// Specific amounts to split into **MUST** equal amount being split
  230. Values(Vec<Amount>),
  231. }
  232. /// Msats in sat
  233. pub const MSAT_IN_SAT: u64 = 1000;
  234. /// Helper function to convert units
  235. pub fn to_unit<T>(
  236. amount: T,
  237. current_unit: &CurrencyUnit,
  238. target_unit: &CurrencyUnit,
  239. ) -> Result<Amount, Error>
  240. where
  241. T: Into<u64>,
  242. {
  243. let amount = amount.into();
  244. match (current_unit, target_unit) {
  245. (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
  246. (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
  247. (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
  248. (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
  249. (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
  250. (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
  251. _ => Err(Error::CannotConvertUnits),
  252. }
  253. }
  254. #[cfg(test)]
  255. mod tests {
  256. use super::*;
  257. #[test]
  258. fn test_split_amount() {
  259. assert_eq!(Amount::from(1).split(), vec![Amount::from(1)]);
  260. assert_eq!(Amount::from(2).split(), vec![Amount::from(2)]);
  261. assert_eq!(
  262. Amount::from(3).split(),
  263. vec![Amount::from(2), Amount::from(1)]
  264. );
  265. let amounts: Vec<Amount> = [8, 2, 1].iter().map(|a| Amount::from(*a)).collect();
  266. assert_eq!(Amount::from(11).split(), amounts);
  267. let amounts: Vec<Amount> = [128, 64, 32, 16, 8, 4, 2, 1]
  268. .iter()
  269. .map(|a| Amount::from(*a))
  270. .collect();
  271. assert_eq!(Amount::from(255).split(), amounts);
  272. }
  273. #[test]
  274. fn test_split_target_amount() {
  275. let amount = Amount(65);
  276. let split = amount
  277. .split_targeted(&SplitTarget::Value(Amount(32)))
  278. .unwrap();
  279. assert_eq!(vec![Amount(1), Amount(32), Amount(32)], split);
  280. let amount = Amount(150);
  281. let split = amount
  282. .split_targeted(&SplitTarget::Value(Amount::from(50)))
  283. .unwrap();
  284. assert_eq!(
  285. vec![
  286. Amount(2),
  287. Amount(2),
  288. Amount(2),
  289. Amount(16),
  290. Amount(16),
  291. Amount(16),
  292. Amount(32),
  293. Amount(32),
  294. Amount(32)
  295. ],
  296. split
  297. );
  298. let amount = Amount::from(63);
  299. let split = amount
  300. .split_targeted(&SplitTarget::Value(Amount::from(32)))
  301. .unwrap();
  302. assert_eq!(
  303. vec![
  304. Amount(1),
  305. Amount(2),
  306. Amount(4),
  307. Amount(8),
  308. Amount(16),
  309. Amount(32)
  310. ],
  311. split
  312. );
  313. }
  314. #[test]
  315. fn test_split_with_fee() {
  316. let amount = Amount(2);
  317. let fee_ppk = 1;
  318. let split = amount.split_with_fee(fee_ppk).unwrap();
  319. assert_eq!(split, vec![Amount(2), Amount(1)]);
  320. let amount = Amount(3);
  321. let fee_ppk = 1;
  322. let split = amount.split_with_fee(fee_ppk).unwrap();
  323. assert_eq!(split, vec![Amount(4)]);
  324. let amount = Amount(3);
  325. let fee_ppk = 1000;
  326. let split = amount.split_with_fee(fee_ppk).unwrap();
  327. assert_eq!(split, vec![Amount(32)]);
  328. }
  329. #[test]
  330. fn test_split_values() {
  331. let amount = Amount(10);
  332. let target = vec![Amount(2), Amount(4), Amount(4)];
  333. let split_target = SplitTarget::Values(target.clone());
  334. let values = amount.split_targeted(&split_target).unwrap();
  335. assert_eq!(target, values);
  336. let target = vec![Amount(2), Amount(4), Amount(4)];
  337. let split_target = SplitTarget::Values(vec![Amount(2), Amount(4)]);
  338. let values = amount.split_targeted(&split_target).unwrap();
  339. assert_eq!(target, values);
  340. let split_target = SplitTarget::Values(vec![Amount(2), Amount(10)]);
  341. let values = amount.split_targeted(&split_target);
  342. assert!(values.is_err())
  343. }
  344. #[test]
  345. #[should_panic]
  346. fn test_amount_addition() {
  347. let amount_one: Amount = u64::MAX.into();
  348. let amount_two: Amount = 1.into();
  349. let amounts = vec![amount_one, amount_two];
  350. let _total: Amount = Amount::try_sum(amounts).unwrap();
  351. }
  352. #[test]
  353. fn test_try_amount_addition() {
  354. let amount_one: Amount = u64::MAX.into();
  355. let amount_two: Amount = 1.into();
  356. let amounts = vec![amount_one, amount_two];
  357. let total = Amount::try_sum(amounts);
  358. assert!(total.is_err());
  359. let amount_one: Amount = 10000.into();
  360. let amount_two: Amount = 1.into();
  361. let amounts = vec![amount_one, amount_two];
  362. let total = Amount::try_sum(amounts).unwrap();
  363. assert_eq!(total, 10001.into());
  364. }
  365. #[test]
  366. fn test_amount_to_unit() {
  367. let amount = Amount::from(1000);
  368. let current_unit = CurrencyUnit::Sat;
  369. let target_unit = CurrencyUnit::Msat;
  370. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  371. assert_eq!(converted, 1000000.into());
  372. let amount = Amount::from(1000);
  373. let current_unit = CurrencyUnit::Msat;
  374. let target_unit = CurrencyUnit::Sat;
  375. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  376. assert_eq!(converted, 1.into());
  377. let amount = Amount::from(1);
  378. let current_unit = CurrencyUnit::Usd;
  379. let target_unit = CurrencyUnit::Usd;
  380. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  381. assert_eq!(converted, 1.into());
  382. let amount = Amount::from(1);
  383. let current_unit = CurrencyUnit::Eur;
  384. let target_unit = CurrencyUnit::Eur;
  385. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  386. assert_eq!(converted, 1.into());
  387. let amount = Amount::from(1);
  388. let current_unit = CurrencyUnit::Sat;
  389. let target_unit = CurrencyUnit::Eur;
  390. let converted = to_unit(amount, &current_unit, &target_unit);
  391. assert!(converted.is_err());
  392. }
  393. }