Переглянути джерело

Add type for replay protection

From Option<String> to Option<MaxLengthString<64>>
Cesar Rodas 8 місяців тому
батько
коміт
7ab7b48c31

+ 4 - 2
utxo/src/id/mod.rs

@@ -1,3 +1,4 @@
+use crate::{BaseTx, Transaction};
 use serde::{de, Deserialize, Deserializer, Serialize};
 use std::{fmt::Display, ops::Deref, str::FromStr};
 
@@ -16,6 +17,9 @@ use std::{fmt::Display, ops::Deref, str::FromStr};
 /// A string with a max-length checked at compiled time
 pub struct MaxLengthString<const MAX_LENGTH: usize>(String);
 
+/// Replay protection ID
+pub type ReplayProtection = MaxLengthString<64>;
+
 impl<const MAX_LENGTH: usize> PartialEq<str> for MaxLengthString<MAX_LENGTH> {
     fn eq(&self, other: &str) -> bool {
         self.0.eq(other)
@@ -157,6 +161,4 @@ mod binary;
 mod error;
 mod payment;
 
-use crate::{BaseTx, Transaction};
-
 pub use self::{error::Error, payment::PaymentId};

+ 2 - 2
utxo/src/storage/cache/batch.rs

@@ -2,7 +2,7 @@ use super::Storage;
 use crate::{
     payment::PaymentTo,
     storage::{AccountTransactionType, Batch, Error, ReceivedPaymentStatus},
-    AccountId, BaseTx, PaymentId, RevId, Revision, Tag, TxId,
+    AccountId, BaseTx, PaymentId, ReplayProtection, RevId, Revision, Tag, TxId,
 };
 use std::{collections::HashMap, marker::PhantomData, sync::Arc};
 
@@ -68,7 +68,7 @@ where
 
     async fn store_replay_protection(
         &mut self,
-        protection: &str,
+        protection: &ReplayProtection,
         transaction_id: &TxId,
     ) -> Result<(), Error> {
         self.inner

+ 4 - 4
utxo/src/storage/mod.rs

@@ -1,7 +1,7 @@
 //! Storage layer trait
 use crate::{
     amount::AmountCents, payment::PaymentTo, transaction::Type, AccountId, Amount, Asset, BaseTx,
-    Filter, PaymentFrom, PaymentId, RevId, Revision, Tag, Transaction, TxId,
+    Filter, PaymentFrom, PaymentId, ReplayProtection, RevId, Revision, Tag, Transaction, TxId,
 };
 //use chrono::{DateTime, Utc};
 use serde::Serialize;
@@ -167,7 +167,7 @@ pub trait Batch<'a> {
     /// Stores the replay protection. It fails if the protection is already stored.
     async fn store_replay_protection(
         &mut self,
-        protection: &str,
+        protection: &ReplayProtection,
         transaction_id: &TxId,
     ) -> Result<(), Error>;
 
@@ -413,7 +413,7 @@ pub mod test {
             )],
         )
         .expect("valid tx")
-        .set_replay_protection("test".to_owned())
+        .set_replay_protection("test".into())
         .expect("valid tx");
 
         assert!(ledger.store(deposit).await.is_ok());
@@ -428,7 +428,7 @@ pub mod test {
             )],
         )
         .expect("valid tx")
-        .set_replay_protection("test".to_owned())
+        .set_replay_protection("test".into())
         .expect("valid tx");
 
         let result = ledger.store(deposit).await;

+ 4 - 4
utxo/src/storage/sqlite/batch.rs

@@ -1,7 +1,7 @@
 use crate::{
     payment::PaymentTo,
     storage::{self, to_bytes, AccountTransactionType, Error, ReceivedPaymentStatus},
-    AccountId, BaseTx, PaymentId, RevId, Revision, Tag, TxId, Type,
+    AccountId, BaseTx, PaymentId, ReplayProtection, RevId, Revision, Tag, TxId, Type,
 };
 use sqlx::{Row, Sqlite, Transaction as SqlxTransaction};
 use std::{marker::PhantomData, num::TryFromIntError};
@@ -66,18 +66,18 @@ impl<'a> storage::Batch<'a> for Batch<'a> {
 
     async fn store_replay_protection(
         &mut self,
-        protection: &str,
+        protection: &ReplayProtection,
         transaction_id: &TxId,
     ) -> Result<(), Error> {
         let query =
             sqlx::query(r#"INSERT INTO "transactions_replay_protection"("protection_id", "transaction_id") VALUES(?, ?) "#,
-            ).bind(protection).bind(transaction_id.to_string())
+            ).bind(protection.as_str()).bind(transaction_id.to_string())
             .execute(&mut *self.inner)
             .await;
 
         if let Err(e) = query {
             let default_err = e.to_string();
-            let query =  sqlx::query(r#"SELECT "transaction_id" FROM "transactions_replay_protection" WHERE "protection_id" = ? "#).bind(protection)
+            let query =  sqlx::query(r#"SELECT "transaction_id" FROM "transactions_replay_protection" WHERE "protection_id" = ? "#).bind(protection.as_str())
                 .fetch_optional(&mut *self.inner).await;
             Err(if let Ok(Some(row)) = query {
                 Error::AlreadyExists(row.get::<String, usize>(0).parse()?)

+ 2 - 2
utxo/src/transaction/base_tx.rs

@@ -3,7 +3,7 @@ use crate::{
     payment::PaymentTo,
     storage::to_bytes,
     transaction::{Error, Revision, Type},
-    Accounts, Asset, PaymentFrom, Status, TxId,
+    Accounts, Asset, PaymentFrom, ReplayProtection, Status, TxId,
 };
 use chrono::{serde::ts_milliseconds, DateTime, Utc};
 use serde::{Deserialize, Serialize};
@@ -19,7 +19,7 @@ use std::collections::HashMap;
 /// Withdrawal transaction.
 pub struct BaseTx {
     /// A unique identifier generated by the client to make sure the transaction was not created before. If provided the storage layer will make sure it is unique.
-    pub replay_protection: Option<String>,
+    pub replay_protection: Option<ReplayProtection>,
     /// List of spend payments to create this transaction
     #[serde(skip_serializing_if = "Vec::is_empty")]
     pub spends: Vec<PaymentFrom>,

+ 6 - 2
utxo/src/transaction/mod.rs

@@ -3,7 +3,8 @@ use crate::{
     payment::PaymentTo,
     storage::Storage,
     token::{TokenManager, TokenPayload},
-    AccountId, Amount, FilterableValue, MaxLengthString, PaymentFrom, RevId, Status, TxId,
+    AccountId, Amount, FilterableValue, MaxLengthString, PaymentFrom, ReplayProtection, RevId,
+    Status, TxId,
 };
 use chrono::{DateTime, Duration, TimeZone, Utc};
 use serde::{Deserialize, Serialize};
@@ -124,7 +125,10 @@ impl Transaction {
     }
 
     /// Consumes the current transaction and replaces it with a similar transaction with replay protection
-    pub fn set_replay_protection(mut self, replay_protection: String) -> Result<Self, Error> {
+    pub fn set_replay_protection(
+        mut self,
+        replay_protection: ReplayProtection,
+    ) -> Result<Self, Error> {
         self.transaction.replay_protection = Some(replay_protection);
 
         let new_tx_id = self.transaction.id()?;