mod.rs 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. use crate::parser::{parse, ConfigValue, Error as ParsingError};
  2. use args::ArgsDeserializer;
  3. use serde::de::{self, IntoDeserializer};
  4. use std::str;
  5. use thiserror::Error as ThisError;
  6. mod args;
  7. mod value;
  8. /// Errors that can occur when deserializing a type.
  9. #[derive(Debug, PartialEq, Eq, Clone, ThisError)]
  10. pub enum Error {
  11. /// EOF was reached when looking for a value
  12. #[error("Unexpected end of file")]
  13. UnexpectedEof(ErrorInfo),
  14. #[error("End of stream")]
  15. EndOfStream,
  16. /// Custom errors
  17. #[error("Custom error")]
  18. Custom(ErrorInfo),
  19. }
  20. #[derive(Debug, PartialEq, Eq, Clone)]
  21. pub struct ErrorInfo {
  22. line: Option<usize>,
  23. col: usize,
  24. at: Option<usize>,
  25. message: String,
  26. }
  27. pub fn from_str<'de, T>(s: &'de str) -> Result<T, Error>
  28. where
  29. T: de::Deserialize<'de>,
  30. {
  31. from_slice(s.as_bytes())
  32. }
  33. pub fn from_slice<'de, T>(bytes: &'de [u8]) -> Result<T, Error>
  34. where
  35. T: de::Deserialize<'de>,
  36. {
  37. let mut d = Deserializer::new(bytes);
  38. let ret = T::deserialize(&mut d)?;
  39. d.end()?;
  40. Ok(ret)
  41. }
  42. /// Deserialization implementation for Config protocol
  43. pub struct Deserializer<'a> {
  44. input: &'a [u8],
  45. }
  46. impl<'a> Deserializer<'a> {
  47. pub fn new(input: &'a [u8]) -> Self {
  48. Self { input }
  49. }
  50. pub fn end(&mut self) -> Result<(), Error> {
  51. Ok(())
  52. }
  53. /// Return the next value
  54. #[inline]
  55. pub fn parse_next(&mut self) -> Result<ConfigValue<'a>, Error> {
  56. match parse(self.input) {
  57. Ok((new_stream, value)) => {
  58. self.input = new_stream;
  59. Ok(value)
  60. }
  61. Err(ParsingError::Partial) => Err(Error::EndOfStream),
  62. }
  63. }
  64. }
  65. impl<'de, 'b> de::Deserializer<'de> for &'b mut Deserializer<'de> {
  66. type Error = Error;
  67. fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
  68. where
  69. V: de::Visitor<'de>,
  70. {
  71. visitor.visit_map(MapVisitor {
  72. de: self,
  73. last_value: None,
  74. })
  75. }
  76. serde::forward_to_deserialize_any! {
  77. bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
  78. bytes byte_buf map unit newtype_struct
  79. ignored_any unit_struct tuple_struct tuple option identifier
  80. enum struct
  81. }
  82. }
  83. struct MapVisitor<'de, 'b> {
  84. de: &'b mut Deserializer<'de>,
  85. last_value: Option<ConfigValue<'de>>,
  86. }
  87. impl<'de, 'b> de::MapAccess<'de> for MapVisitor<'de, 'b> {
  88. type Error = Error;
  89. fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Error>
  90. where
  91. K: de::DeserializeSeed<'de>,
  92. {
  93. match self.de.parse_next() {
  94. Ok(v) => {
  95. let name = v.name.clone();
  96. self.last_value = Some(v);
  97. seed.deserialize(name.into_deserializer()).map(Some)
  98. }
  99. _ => Ok(None),
  100. }
  101. }
  102. #[inline]
  103. fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
  104. where
  105. V: de::DeserializeSeed<'de>,
  106. {
  107. seed.deserialize(ArgsDeserializer {
  108. input: self.last_value.as_ref().unwrap().args.clone(),
  109. })
  110. }
  111. }
  112. impl Error {
  113. pub fn custom(at: Option<usize>, s: String) -> Self {
  114. Self::Custom(ErrorInfo {
  115. line: None,
  116. col: 0,
  117. at,
  118. message: s,
  119. })
  120. }
  121. }
  122. impl de::Error for Error {
  123. fn custom<T: std::fmt::Display>(msg: T) -> Error {
  124. Error::custom(None, msg.to_string())
  125. }
  126. }
  127. #[cfg(test)]
  128. mod test {
  129. use super::*;
  130. use serde::Deserialize;
  131. use serde_enum_str::Deserialize_enum_str;
  132. #[derive(Deserialize, Debug)]
  133. pub struct Foo {
  134. foo: Vec<i32>,
  135. bar: u8,
  136. xxx: Option<String>,
  137. }
  138. #[derive(Deserialize, Debug, Default)]
  139. pub struct SaveInfo(pub u64, pub u64);
  140. #[derive(Deserialize_enum_str, Debug, PartialEq)]
  141. pub enum AppendFsync {
  142. #[serde(rename = "always")]
  143. Always,
  144. #[serde(rename = "everysec")]
  145. EverySecond,
  146. #[serde(other, rename = "no")]
  147. No,
  148. }
  149. #[derive(Deserialize_enum_str, Debug, PartialEq)]
  150. pub enum LogLevel {
  151. #[serde(rename = "debug")]
  152. Debug,
  153. #[serde(rename = "verbose")]
  154. Verbose,
  155. #[serde(rename = "notice")]
  156. Notice,
  157. #[serde(rename = "warning")]
  158. Warning,
  159. }
  160. impl Default for LogLevel {
  161. fn default() -> Self {
  162. Self::Warning
  163. }
  164. }
  165. impl Default for AppendFsync {
  166. fn default() -> Self {
  167. Self::No
  168. }
  169. }
  170. #[derive(Deserialize, Debug, Default)]
  171. pub struct Config {
  172. #[serde(rename = "always-show-logo")]
  173. always_show_logo: bool,
  174. #[serde(rename = "notify-keyspace-events")]
  175. notify_keyspace_events: String,
  176. daemonize: bool,
  177. port: u32,
  178. save: SaveInfo,
  179. #[serde(rename = "appendfsync")]
  180. append_fsync: AppendFsync,
  181. #[serde(flatten)]
  182. log: Log,
  183. databases: u8,
  184. }
  185. #[derive(Deserialize, Debug, Default)]
  186. pub struct Log {
  187. #[serde(rename = "loglevel")]
  188. level: LogLevel,
  189. #[serde(rename = "logfile")]
  190. file: String,
  191. }
  192. #[test]
  193. fn de() {
  194. let x: Foo = from_str("foo 32 44 12\r\nbar 32\r\n").unwrap();
  195. assert_eq!(32, x.bar);
  196. assert_eq!(None, x.xxx);
  197. assert_eq!(3, x.foo.len());
  198. }
  199. #[test]
  200. fn real_config() {
  201. let x: Config = from_str(
  202. "always-show-logo yes
  203. notify-keyspace-events KEA
  204. daemonize no
  205. pidfile /var/run/redis.pid
  206. port 24611
  207. timeout 0
  208. bind 127.0.0.1
  209. loglevel verbose
  210. logfile ''
  211. databases 16
  212. latency-monitor-threshold 1
  213. save 60 10000
  214. rdbcompression yes
  215. dbfilename dump.rdb
  216. dir ./tests/tmp/server.64463.1
  217. slave-serve-stale-data yes
  218. appendonly no
  219. appendfsync everysec
  220. no-appendfsync-on-rewrite no
  221. activerehashing yes
  222. unixsocket /home/crodas/redis/tests/tmp/server.64463.1/socket
  223. ",
  224. )
  225. .unwrap();
  226. assert!(x.always_show_logo);
  227. assert_eq!(60, x.save.0);
  228. assert_eq!(10_000, x.save.1);
  229. assert_eq!(24_611, x.port);
  230. assert_eq!("KEA", x.notify_keyspace_events);
  231. assert_eq!(AppendFsync::EverySecond, x.append_fsync);
  232. assert!(!x.daemonize);
  233. assert_eq!(LogLevel::Verbose, x.log.level);
  234. assert_eq!("", x.log.file);
  235. assert_eq!(16, x.databases);
  236. }
  237. }