mod.rs 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. //! HTTP cache.
  2. //!
//! This module defines a common trait for implementing custom storage backends for the HTTP cache.
  4. //!
  5. //! The HTTP cache is a layer to cache responses from HTTP requests, to avoid hitting
  6. //! the same endpoint multiple times, which can be expensive and slow, or to provide
  7. //! idempotent operations.
  8. //!
//! This module also provides common backend implementations, such as an in-memory
//! backend (the default) and Redis (behind the `redis` feature).
  11. use std::ops::Deref;
  12. use std::sync::Arc;
  13. use std::time::Duration;
  14. use serde::de::DeserializeOwned;
  15. use serde::Serialize;
  16. use sha2::{Digest, Sha256};
  17. mod backend;
  18. mod config;
  19. pub use self::backend::*;
  20. pub use self::config::Config;
  21. #[async_trait::async_trait]
  22. /// Cache storage for the HTTP cache.
  23. pub trait HttpCacheStorage {
  24. /// Sets the expiration times for the cache.
  25. fn set_expiration_times(&mut self, cache_ttl: Duration, cache_tti: Duration);
  26. /// Get a value from the cache.
  27. async fn get(&self, key: &HttpCacheKey) -> Option<Vec<u8>>;
  28. /// Set a value in the cache.
  29. async fn set(&self, key: HttpCacheKey, value: Vec<u8>);
  30. }
  31. /// Http cache with a pluggable storage backend.
  32. pub struct HttpCache {
  33. /// Time to live for the cache.
  34. pub ttl: Duration,
  35. /// Time to idle for the cache.
  36. pub tti: Duration,
  37. /// Storage backend for the cache.
  38. storage: Arc<Box<dyn HttpCacheStorage + Send + Sync>>,
  39. }
  40. impl Default for HttpCache {
  41. fn default() -> Self {
  42. Self::new(
  43. Duration::from_secs(DEFAULT_TTL_SECS),
  44. Duration::from_secs(DEFAULT_TTI_SECS),
  45. None,
  46. )
  47. }
  48. }
  49. /// Max payload size for the cache key.
  50. ///
  51. /// This is a trade-off between security and performance. A large payload can be used to
  52. /// perform a CPU attack.
  53. const MAX_PAYLOAD_SIZE: usize = 10 * 1024 * 1024;
  54. /// Default TTL for the cache.
  55. const DEFAULT_TTL_SECS: u64 = 60;
  56. /// Default TTI for the cache.
  57. const DEFAULT_TTI_SECS: u64 = 60;
  58. /// Http cache key.
  59. ///
  60. /// This type ensures no Vec<u8> is used as a key, which is error-prone.
  61. #[derive(Clone, Debug, PartialEq, Eq, Hash)]
  62. pub struct HttpCacheKey([u8; 32]);
  63. impl Deref for HttpCacheKey {
  64. type Target = [u8; 32];
  65. fn deref(&self) -> &Self::Target {
  66. &self.0
  67. }
  68. }
  69. impl From<config::Config> for HttpCache {
  70. fn from(config: config::Config) -> Self {
  71. match config.backend {
  72. config::Backend::Memory => Self::new(
  73. Duration::from_secs(config.ttl.unwrap_or(DEFAULT_TTL_SECS)),
  74. Duration::from_secs(config.tti.unwrap_or(DEFAULT_TTI_SECS)),
  75. None,
  76. ),
  77. #[cfg(feature = "redis")]
  78. config::Backend::Redis(redis_config) => {
  79. let client = redis::Client::open(redis_config.connection_string)
  80. .expect("Failed to create Redis client");
  81. let storage = HttpCacheRedis::new(client).set_prefix(
  82. redis_config
  83. .key_prefix
  84. .unwrap_or_default()
  85. .as_bytes()
  86. .to_vec(),
  87. );
  88. Self::new(
  89. Duration::from_secs(config.ttl.unwrap_or(DEFAULT_TTL_SECS)),
  90. Duration::from_secs(config.tti.unwrap_or(DEFAULT_TTI_SECS)),
  91. Some(Box::new(storage)),
  92. )
  93. }
  94. }
  95. }
  96. }
  97. impl HttpCache {
  98. /// Create a new HTTP cache.
  99. pub fn new(
  100. ttl: Duration,
  101. tti: Duration,
  102. storage: Option<Box<dyn HttpCacheStorage + Send + Sync + 'static>>,
  103. ) -> Self {
  104. let mut storage = storage.unwrap_or_else(|| Box::new(InMemoryHttpCache::default()));
  105. storage.set_expiration_times(ttl, tti);
  106. Self {
  107. ttl,
  108. tti,
  109. storage: Arc::new(storage),
  110. }
  111. }
  112. /// Calculate a cache key from a serializable value.
  113. ///
  114. /// Usually the input is the request body or query parameters.
  115. ///
  116. /// The result is an optional cache key. If the key cannot be calculated, it
  117. /// will be None, meaning the value cannot be cached, therefore the entire
  118. /// caching mechanism should be skipped.
  119. ///
  120. /// Instead of using the entire serialized input as the key, the key is a
  121. /// double hash to have a predictable key size, although it may open the
  122. /// window for CPU attacks with large payloads, but it is a trade-off.
  123. /// Perhaps upper layer have a protection against large payloads.
  124. pub fn calculate_key<K>(&self, key: &K) -> Option<HttpCacheKey>
  125. where
  126. K: Serialize,
  127. {
  128. let json_value = match serde_json::to_vec(key) {
  129. Ok(value) => value,
  130. Err(err) => {
  131. tracing::warn!("Failed to serialize key: {:?}", err);
  132. return None;
  133. }
  134. };
  135. if json_value.len() > MAX_PAYLOAD_SIZE {
  136. tracing::warn!("Key size is too large: {}", json_value.len());
  137. return None;
  138. }
  139. let first_hash = Sha256::digest(json_value);
  140. let second_hash = Sha256::digest(first_hash);
  141. Some(HttpCacheKey(second_hash.into()))
  142. }
  143. /// Get a value from the cache.
  144. pub async fn get<V>(self: &Arc<Self>, key: &HttpCacheKey) -> Option<V>
  145. where
  146. V: DeserializeOwned,
  147. {
  148. self.storage.get(key).await.and_then(|value| {
  149. serde_json::from_slice(&value)
  150. .map_err(|e| {
  151. tracing::warn!("Failed to deserialize value: {:?}", e);
  152. e
  153. })
  154. .ok()
  155. })
  156. }
  157. /// Set a value in the cache.
  158. pub async fn set<V: Serialize>(self: &Arc<Self>, key: HttpCacheKey, value: &V) {
  159. if let Ok(bytes) = serde_json::to_vec(value).map_err(|e| {
  160. tracing::warn!("Failed to serialize value: {:?}", e);
  161. e
  162. }) {
  163. self.storage.set(key, bytes).await;
  164. }
  165. }
  166. }