diff --git a/tuic-server/src/config.rs b/tuic-server/src/config.rs index de01b6c..9d0deea 100644 --- a/tuic-server/src/config.rs +++ b/tuic-server/src/config.rs @@ -39,15 +39,27 @@ pub struct Config { #[serde(default = "default::zero_rtt_handshake")] pub zero_rtt_handshake: bool, pub dual_stack: Option, - #[serde(default = "default::auth_timeout")] + #[serde( + default = "default::auth_timeout", + deserialize_with = "deserialize_duration" + )] pub auth_timeout: Duration, - #[serde(default = "default::max_idle_time")] + #[serde( + default = "default::max_idle_time", + deserialize_with = "deserialize_duration" + )] pub max_idle_time: Duration, #[serde(default = "default::max_external_packet_size")] pub max_external_packet_size: usize, - #[serde(default = "default::gc_interval")] + #[serde( + default = "default::gc_interval", + deserialize_with = "deserialize_duration" + )] pub gc_interval: Duration, - #[serde(default = "default::gc_lifetime")] + #[serde( + default = "default::gc_lifetime", + deserialize_with = "deserialize_duration" + )] pub gc_lifetime: Duration, #[serde(default = "default::log_level")] pub log_level: LevelFilter, @@ -153,6 +165,59 @@ where Ok(map) } +fn parse_duration(s: &str) -> Result { + let mut num = Vec::with_capacity(8); + let mut chars = Vec::with_capacity(2); + let mut expected_unit = false; + for c in s.chars() { + if !expected_unit && c.is_numeric() { + num.push(c); + continue; + } + + if !expected_unit && c.is_ascii() { + expected_unit = true; + } + chars.push(c); + } + + let n: u64 = num + .into_iter() + .collect::() + .parse() + .map_err(|e| format!("invalid value: {}, reason {}", &s, e))?; + + match chars.into_iter().collect::().as_str() { + "" => Ok(Duration::from_millis(n)), + "s" => Ok(Duration::from_secs(n)), + "ms" => Ok(Duration::from_millis(n)), + _ => Err(format!("invalid value: {}, expected 10s or 10ms", &s)), + } +} + +#[test] +fn test_parseduration() { + // test parse + assert_eq!(parse_duration("100s"), Ok(Duration::from_secs(100))); + assert_eq!(parse_duration("100ms"), Ok(Duration::from_millis(100))); + + // test default unit + assert_eq!(parse_duration("10000"), Ok(Duration::from_millis(10000))); + + // test invalid data + assert!(parse_duration("").is_err()); + assert!(parse_duration("1ms100").is_err()); + assert!(parse_duration("ms").is_err()); +} + +pub fn deserialize_duration<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + let s: String = String::deserialize(deserializer)?; + parse_duration(&s).map_err(DeError::custom) +} + #[derive(Debug, Error)] pub enum ConfigError { #[error(transparent)]