amount.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. //! CDK Amount
  2. //!
  3. //! Is any unit and will be treated as the unit of the wallet
  4. use std::cmp::Ordering;
  5. use std::fmt;
  6. use std::str::FromStr;
  7. use lightning::offers::offer::Offer;
  8. use serde::{Deserialize, Serialize};
  9. use thiserror::Error;
  10. use crate::nuts::CurrencyUnit;
  11. /// Amount Error
  12. #[derive(Debug, Error)]
  13. pub enum Error {
  14. /// Split Values must be less then or equal to amount
  15. #[error("Split Values must be less then or equal to amount")]
  16. SplitValuesGreater,
  17. /// Amount overflow
  18. #[error("Amount Overflow")]
  19. AmountOverflow,
  20. /// Cannot convert units
  21. #[error("Cannot convert units")]
  22. CannotConvertUnits,
  23. /// Invalid amount
  24. #[error("Invalid Amount: {0}")]
  25. InvalidAmount(String),
  26. /// Amount undefined
  27. #[error("Amount undefined")]
  28. AmountUndefined,
  29. /// Utf8 parse error
  30. #[error(transparent)]
  31. Utf8ParseError(#[from] std::string::FromUtf8Error),
  32. }
  33. /// Amount can be any unit
  34. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
  35. #[cfg_attr(feature = "swagger", derive(utoipa::ToSchema))]
  36. #[serde(transparent)]
  37. pub struct Amount(u64);
  38. impl FromStr for Amount {
  39. type Err = Error;
  40. fn from_str(s: &str) -> Result<Self, Self::Err> {
  41. let value = s
  42. .parse::<u64>()
  43. .map_err(|_| Error::InvalidAmount(s.to_owned()))?;
  44. Ok(Amount(value))
  45. }
  46. }
  47. impl Amount {
  48. /// Amount zero
  49. pub const ZERO: Amount = Amount(0);
  50. /// Amount one
  51. pub const ONE: Amount = Amount(1);
  52. /// Split into parts that are powers of two
  53. pub fn split(&self) -> Vec<Self> {
  54. let sats = self.0;
  55. (0_u64..64)
  56. .rev()
  57. .filter_map(|bit| {
  58. let part = 1 << bit;
  59. ((sats & part) == part).then_some(Self::from(part))
  60. })
  61. .collect()
  62. }
  63. /// Split into parts that are powers of two by target
  64. pub fn split_targeted(&self, target: &SplitTarget) -> Result<Vec<Self>, Error> {
  65. let mut parts = match target {
  66. SplitTarget::None => self.split(),
  67. SplitTarget::Value(amount) => {
  68. if self.le(amount) {
  69. return Ok(self.split());
  70. }
  71. let mut parts_total = Amount::ZERO;
  72. let mut parts = Vec::new();
  73. // The powers of two that are need to create target value
  74. let parts_of_value = amount.split();
  75. while parts_total.lt(self) {
  76. for part in parts_of_value.iter().copied() {
  77. if (part + parts_total).le(self) {
  78. parts.push(part);
  79. } else {
  80. let amount_left = *self - parts_total;
  81. parts.extend(amount_left.split());
  82. }
  83. parts_total = Amount::try_sum(parts.clone().iter().copied())?;
  84. if parts_total.eq(self) {
  85. break;
  86. }
  87. }
  88. }
  89. parts
  90. }
  91. SplitTarget::Values(values) => {
  92. let values_total: Amount = Amount::try_sum(values.clone().into_iter())?;
  93. match self.cmp(&values_total) {
  94. Ordering::Equal => values.clone(),
  95. Ordering::Less => {
  96. return Err(Error::SplitValuesGreater);
  97. }
  98. Ordering::Greater => {
  99. let extra = *self - values_total;
  100. let mut extra_amount = extra.split();
  101. let mut values = values.clone();
  102. values.append(&mut extra_amount);
  103. values
  104. }
  105. }
  106. }
  107. };
  108. parts.sort();
  109. Ok(parts)
  110. }
  111. /// Splits amount into powers of two while accounting for the swap fee
  112. pub fn split_with_fee(&self, fee_ppk: u64) -> Result<Vec<Self>, Error> {
  113. let without_fee_amounts = self.split();
  114. let fee_ppk = fee_ppk * without_fee_amounts.len() as u64;
  115. let fee = Amount::from(fee_ppk.div_ceil(1000));
  116. let new_amount = self.checked_add(fee).ok_or(Error::AmountOverflow)?;
  117. let split = new_amount.split();
  118. let split_fee_ppk = split.len() as u64 * fee_ppk;
  119. let split_fee = Amount::from(split_fee_ppk.div_ceil(1000));
  120. if let Some(net_amount) = new_amount.checked_sub(split_fee) {
  121. if net_amount >= *self {
  122. return Ok(split);
  123. }
  124. }
  125. self.checked_add(Amount::ONE)
  126. .ok_or(Error::AmountOverflow)?
  127. .split_with_fee(fee_ppk)
  128. }
  129. /// Checked addition for Amount. Returns None if overflow occurs.
  130. pub fn checked_add(self, other: Amount) -> Option<Amount> {
  131. self.0.checked_add(other.0).map(Amount)
  132. }
  133. /// Checked subtraction for Amount. Returns None if overflow occurs.
  134. pub fn checked_sub(self, other: Amount) -> Option<Amount> {
  135. self.0.checked_sub(other.0).map(Amount)
  136. }
  137. /// Checked multiplication for Amount. Returns None if overflow occurs.
  138. pub fn checked_mul(self, other: Amount) -> Option<Amount> {
  139. self.0.checked_mul(other.0).map(Amount)
  140. }
  141. /// Checked division for Amount. Returns None if overflow occurs.
  142. pub fn checked_div(self, other: Amount) -> Option<Amount> {
  143. self.0.checked_div(other.0).map(Amount)
  144. }
  145. /// Try sum to check for overflow
  146. pub fn try_sum<I>(iter: I) -> Result<Self, Error>
  147. where
  148. I: IntoIterator<Item = Self>,
  149. {
  150. iter.into_iter().try_fold(Amount::ZERO, |acc, x| {
  151. acc.checked_add(x).ok_or(Error::AmountOverflow)
  152. })
  153. }
  154. /// Convert unit
  155. pub fn convert_unit(
  156. &self,
  157. current_unit: &CurrencyUnit,
  158. target_unit: &CurrencyUnit,
  159. ) -> Result<Amount, Error> {
  160. to_unit(self.0, current_unit, target_unit)
  161. }
  162. /// Convert to i64
  163. pub fn to_i64(self) -> Option<i64> {
  164. if self.0 <= i64::MAX as u64 {
  165. Some(self.0 as i64)
  166. } else {
  167. None
  168. }
  169. }
  170. /// Create from i64, returning None if negative
  171. pub fn from_i64(value: i64) -> Option<Self> {
  172. if value >= 0 {
  173. Some(Amount(value as u64))
  174. } else {
  175. None
  176. }
  177. }
  178. }
  179. impl Default for Amount {
  180. fn default() -> Self {
  181. Amount::ZERO
  182. }
  183. }
  184. impl Default for &Amount {
  185. fn default() -> Self {
  186. &Amount::ZERO
  187. }
  188. }
  189. impl fmt::Display for Amount {
  190. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  191. if let Some(width) = f.width() {
  192. write!(f, "{:width$}", self.0, width = width)
  193. } else {
  194. write!(f, "{}", self.0)
  195. }
  196. }
  197. }
  198. impl From<u64> for Amount {
  199. fn from(value: u64) -> Self {
  200. Self(value)
  201. }
  202. }
  203. impl From<&u64> for Amount {
  204. fn from(value: &u64) -> Self {
  205. Self(*value)
  206. }
  207. }
  208. impl From<Amount> for u64 {
  209. fn from(value: Amount) -> Self {
  210. value.0
  211. }
  212. }
  213. impl AsRef<u64> for Amount {
  214. fn as_ref(&self) -> &u64 {
  215. &self.0
  216. }
  217. }
  218. impl std::ops::Add for Amount {
  219. type Output = Amount;
  220. fn add(self, rhs: Amount) -> Self::Output {
  221. Amount(self.0.checked_add(rhs.0).expect("Addition error"))
  222. }
  223. }
  224. impl std::ops::AddAssign for Amount {
  225. fn add_assign(&mut self, rhs: Self) {
  226. self.0 = self.0.checked_add(rhs.0).expect("Addition error");
  227. }
  228. }
  229. impl std::ops::Sub for Amount {
  230. type Output = Amount;
  231. fn sub(self, rhs: Amount) -> Self::Output {
  232. Amount(self.0 - rhs.0)
  233. }
  234. }
  235. impl std::ops::SubAssign for Amount {
  236. fn sub_assign(&mut self, other: Self) {
  237. self.0 -= other.0;
  238. }
  239. }
  240. impl std::ops::Mul for Amount {
  241. type Output = Self;
  242. fn mul(self, other: Self) -> Self::Output {
  243. Amount(self.0 * other.0)
  244. }
  245. }
  246. impl std::ops::Div for Amount {
  247. type Output = Self;
  248. fn div(self, other: Self) -> Self::Output {
  249. Amount(self.0 / other.0)
  250. }
  251. }
  252. /// Convert offer to amount in unit
  253. pub fn amount_for_offer(offer: &Offer, unit: &CurrencyUnit) -> Result<Amount, Error> {
  254. let offer_amount = offer.amount().ok_or(Error::AmountUndefined)?;
  255. let (amount, currency) = match offer_amount {
  256. lightning::offers::offer::Amount::Bitcoin { amount_msats } => {
  257. (amount_msats, CurrencyUnit::Msat)
  258. }
  259. lightning::offers::offer::Amount::Currency {
  260. iso4217_code,
  261. amount,
  262. } => (
  263. amount,
  264. CurrencyUnit::from_str(&String::from_utf8(iso4217_code.to_vec())?)
  265. .map_err(|_| Error::CannotConvertUnits)?,
  266. ),
  267. };
  268. to_unit(amount, &currency, unit).map_err(|_err| Error::CannotConvertUnits)
  269. }
  270. /// Kinds of targeting that are supported
  271. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)]
  272. pub enum SplitTarget {
  273. /// Default target; least amount of proofs
  274. #[default]
  275. None,
  276. /// Target amount for wallet to have most proofs that add up to value
  277. Value(Amount),
  278. /// Specific amounts to split into **MUST** equal amount being split
  279. Values(Vec<Amount>),
  280. }
  281. /// Msats in sat
  282. pub const MSAT_IN_SAT: u64 = 1000;
  283. /// Helper function to convert units
  284. pub fn to_unit<T>(
  285. amount: T,
  286. current_unit: &CurrencyUnit,
  287. target_unit: &CurrencyUnit,
  288. ) -> Result<Amount, Error>
  289. where
  290. T: Into<u64>,
  291. {
  292. let amount = amount.into();
  293. match (current_unit, target_unit) {
  294. (CurrencyUnit::Sat, CurrencyUnit::Sat) => Ok(amount.into()),
  295. (CurrencyUnit::Msat, CurrencyUnit::Msat) => Ok(amount.into()),
  296. (CurrencyUnit::Sat, CurrencyUnit::Msat) => Ok((amount * MSAT_IN_SAT).into()),
  297. (CurrencyUnit::Msat, CurrencyUnit::Sat) => Ok((amount / MSAT_IN_SAT).into()),
  298. (CurrencyUnit::Usd, CurrencyUnit::Usd) => Ok(amount.into()),
  299. (CurrencyUnit::Eur, CurrencyUnit::Eur) => Ok(amount.into()),
  300. _ => Err(Error::CannotConvertUnits),
  301. }
  302. }
  303. #[cfg(test)]
  304. mod tests {
  305. use super::*;
  306. #[test]
  307. fn test_split_amount() {
  308. assert_eq!(Amount::from(1).split(), vec![Amount::from(1)]);
  309. assert_eq!(Amount::from(2).split(), vec![Amount::from(2)]);
  310. assert_eq!(
  311. Amount::from(3).split(),
  312. vec![Amount::from(2), Amount::from(1)]
  313. );
  314. let amounts: Vec<Amount> = [8, 2, 1].iter().map(|a| Amount::from(*a)).collect();
  315. assert_eq!(Amount::from(11).split(), amounts);
  316. let amounts: Vec<Amount> = [128, 64, 32, 16, 8, 4, 2, 1]
  317. .iter()
  318. .map(|a| Amount::from(*a))
  319. .collect();
  320. assert_eq!(Amount::from(255).split(), amounts);
  321. }
  322. #[test]
  323. fn test_split_target_amount() {
  324. let amount = Amount(65);
  325. let split = amount
  326. .split_targeted(&SplitTarget::Value(Amount(32)))
  327. .unwrap();
  328. assert_eq!(vec![Amount(1), Amount(32), Amount(32)], split);
  329. let amount = Amount(150);
  330. let split = amount
  331. .split_targeted(&SplitTarget::Value(Amount::from(50)))
  332. .unwrap();
  333. assert_eq!(
  334. vec![
  335. Amount(2),
  336. Amount(2),
  337. Amount(2),
  338. Amount(16),
  339. Amount(16),
  340. Amount(16),
  341. Amount(32),
  342. Amount(32),
  343. Amount(32)
  344. ],
  345. split
  346. );
  347. let amount = Amount::from(63);
  348. let split = amount
  349. .split_targeted(&SplitTarget::Value(Amount::from(32)))
  350. .unwrap();
  351. assert_eq!(
  352. vec![
  353. Amount(1),
  354. Amount(2),
  355. Amount(4),
  356. Amount(8),
  357. Amount(16),
  358. Amount(32)
  359. ],
  360. split
  361. );
  362. }
  363. #[test]
  364. fn test_split_with_fee() {
  365. let amount = Amount(2);
  366. let fee_ppk = 1;
  367. let split = amount.split_with_fee(fee_ppk).unwrap();
  368. assert_eq!(split, vec![Amount(2), Amount(1)]);
  369. let amount = Amount(3);
  370. let fee_ppk = 1;
  371. let split = amount.split_with_fee(fee_ppk).unwrap();
  372. assert_eq!(split, vec![Amount(4)]);
  373. let amount = Amount(3);
  374. let fee_ppk = 1000;
  375. let split = amount.split_with_fee(fee_ppk).unwrap();
  376. assert_eq!(split, vec![Amount(32)]);
  377. }
  378. #[test]
  379. fn test_split_values() {
  380. let amount = Amount(10);
  381. let target = vec![Amount(2), Amount(4), Amount(4)];
  382. let split_target = SplitTarget::Values(target.clone());
  383. let values = amount.split_targeted(&split_target).unwrap();
  384. assert_eq!(target, values);
  385. let target = vec![Amount(2), Amount(4), Amount(4)];
  386. let split_target = SplitTarget::Values(vec![Amount(2), Amount(4)]);
  387. let values = amount.split_targeted(&split_target).unwrap();
  388. assert_eq!(target, values);
  389. let split_target = SplitTarget::Values(vec![Amount(2), Amount(10)]);
  390. let values = amount.split_targeted(&split_target);
  391. assert!(values.is_err())
  392. }
  393. #[test]
  394. #[should_panic]
  395. fn test_amount_addition() {
  396. let amount_one: Amount = u64::MAX.into();
  397. let amount_two: Amount = 1.into();
  398. let amounts = vec![amount_one, amount_two];
  399. let _total: Amount = Amount::try_sum(amounts).unwrap();
  400. }
  401. #[test]
  402. fn test_try_amount_addition() {
  403. let amount_one: Amount = u64::MAX.into();
  404. let amount_two: Amount = 1.into();
  405. let amounts = vec![amount_one, amount_two];
  406. let total = Amount::try_sum(amounts);
  407. assert!(total.is_err());
  408. let amount_one: Amount = 10000.into();
  409. let amount_two: Amount = 1.into();
  410. let amounts = vec![amount_one, amount_two];
  411. let total = Amount::try_sum(amounts).unwrap();
  412. assert_eq!(total, 10001.into());
  413. }
  414. #[test]
  415. fn test_amount_to_unit() {
  416. let amount = Amount::from(1000);
  417. let current_unit = CurrencyUnit::Sat;
  418. let target_unit = CurrencyUnit::Msat;
  419. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  420. assert_eq!(converted, 1000000.into());
  421. let amount = Amount::from(1000);
  422. let current_unit = CurrencyUnit::Msat;
  423. let target_unit = CurrencyUnit::Sat;
  424. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  425. assert_eq!(converted, 1.into());
  426. let amount = Amount::from(1);
  427. let current_unit = CurrencyUnit::Usd;
  428. let target_unit = CurrencyUnit::Usd;
  429. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  430. assert_eq!(converted, 1.into());
  431. let amount = Amount::from(1);
  432. let current_unit = CurrencyUnit::Eur;
  433. let target_unit = CurrencyUnit::Eur;
  434. let converted = to_unit(amount, &current_unit, &target_unit).unwrap();
  435. assert_eq!(converted, 1.into());
  436. let amount = Amount::from(1);
  437. let current_unit = CurrencyUnit::Sat;
  438. let target_unit = CurrencyUnit::Eur;
  439. let converted = to_unit(amount, &current_unit, &target_unit);
  440. assert!(converted.is_err());
  441. }
  442. }