1
0

parse duration in config using humantime

This commit is contained in:
EAimTY 2023-05-04 18:40:53 +09:00
parent c1f6a2ca33
commit c776c53d33
4 changed files with 72 additions and 53 deletions

View File

@ -7,6 +7,7 @@ edition = "2021"
bytes = { version = "1.4.0", default-features = false, features = ["std"] } bytes = { version = "1.4.0", default-features = false, features = ["std"] }
crossbeam-utils = { version = "0.8.15", default-features = false, features = ["std"] } crossbeam-utils = { version = "0.8.15", default-features = false, features = ["std"] }
env_logger = { version = "0.10.0", default-features = false, features = ["humantime"] } env_logger = { version = "0.10.0", default-features = false, features = ["humantime"] }
humantime = { version = "2.1.0", default-features = false }
lexopt = { version = "0.3.0", default-features = false } lexopt = { version = "0.3.0", default-features = false }
log = { version = "0.4.17", default-features = false, features = ["serde", "std"] } log = { version = "0.4.17", default-features = false, features = ["serde", "std"] }
once_cell = { version = "1.17.1", default-features = false, features = ["parking_lot", "std"] } once_cell = { version = "1.17.1", default-features = false, features = ["parking_lot", "std"] }

View File

@ -1,4 +1,5 @@
use crate::utils::{CongestionControl, UdpRelayMode}; use crate::utils::{CongestionControl, UdpRelayMode};
use humantime::Duration as HumanDuration;
use lexopt::{Arg, Error as ArgumentError, Parser}; use lexopt::{Arg, Error as ArgumentError, Parser};
use log::LevelFilter; use log::LevelFilter;
use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde::{de::Error as DeError, Deserialize, Deserializer};
@ -29,7 +30,9 @@ Arguments:
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct Config { pub struct Config {
pub relay: Relay, pub relay: Relay,
pub local: Local, pub local: Local,
#[serde(default = "default::log_level")] #[serde(default = "default::log_level")]
pub log_level: LevelFilter, pub log_level: LevelFilter,
} }
@ -39,36 +42,62 @@ pub struct Config {
pub struct Relay { pub struct Relay {
#[serde(deserialize_with = "deserialize_server")] #[serde(deserialize_with = "deserialize_server")]
pub server: (String, u16), pub server: (String, u16),
pub uuid: Uuid, pub uuid: Uuid,
pub password: String, pub password: String,
pub ip: Option<IpAddr>, pub ip: Option<IpAddr>,
#[serde(default = "default::relay::certificates")] #[serde(default = "default::relay::certificates")]
pub certificates: Vec<PathBuf>, pub certificates: Vec<PathBuf>,
#[serde( #[serde(
default = "default::relay::udp_relay_mode", default = "default::relay::udp_relay_mode",
deserialize_with = "deserialize_from_str" deserialize_with = "deserialize_from_str"
)] )]
pub udp_relay_mode: UdpRelayMode, pub udp_relay_mode: UdpRelayMode,
#[serde( #[serde(
default = "default::relay::congestion_control", default = "default::relay::congestion_control",
deserialize_with = "deserialize_from_str" deserialize_with = "deserialize_from_str"
)] )]
pub congestion_control: CongestionControl, pub congestion_control: CongestionControl,
#[serde(default = "default::relay::alpn")] #[serde(default = "default::relay::alpn")]
pub alpn: Vec<String>, pub alpn: Vec<String>,
#[serde(default = "default::relay::zero_rtt_handshake")] #[serde(default = "default::relay::zero_rtt_handshake")]
pub zero_rtt_handshake: bool, pub zero_rtt_handshake: bool,
#[serde(default = "default::relay::disable_sni")] #[serde(default = "default::relay::disable_sni")]
pub disable_sni: bool, pub disable_sni: bool,
#[serde(default = "default::relay::timeout")]
#[serde(
default = "default::relay::timeout",
deserialize_with = "deserialize_duration"
)]
pub timeout: Duration, pub timeout: Duration,
#[serde(default = "default::relay::heartbeat")]
#[serde(
default = "default::relay::heartbeat",
deserialize_with = "deserialize_duration"
)]
pub heartbeat: Duration, pub heartbeat: Duration,
#[serde(default = "default::relay::disable_native_certs")] #[serde(default = "default::relay::disable_native_certs")]
pub disable_native_certs: bool, pub disable_native_certs: bool,
#[serde(default = "default::relay::gc_interval")]
#[serde(
default = "default::relay::gc_interval",
deserialize_with = "deserialize_duration"
)]
pub gc_interval: Duration, pub gc_interval: Duration,
#[serde(default = "default::relay::gc_lifetime")]
#[serde(
default = "default::relay::gc_lifetime",
deserialize_with = "deserialize_duration"
)]
pub gc_lifetime: Duration, pub gc_lifetime: Duration,
} }
@ -76,9 +105,13 @@ pub struct Relay {
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct Local { pub struct Local {
pub server: SocketAddr, pub server: SocketAddr,
pub username: Option<String>, pub username: Option<String>,
pub password: Option<String>, pub password: Option<String>,
pub dual_stack: Option<bool>, pub dual_stack: Option<bool>,
#[serde(default = "default::local::max_packet_size")] #[serde(default = "default::local::max_packet_size")]
pub max_packet_size: usize, pub max_packet_size: usize,
} }
@ -203,6 +236,17 @@ where
} }
} }
pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse::<HumanDuration>()
.map(|d| *d)
.map_err(DeError::custom)
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ConfigError { pub enum ConfigError {
#[error(transparent)] #[error(transparent)]

View File

@ -7,6 +7,7 @@ edition = "2021"
bytes = { version = "1.4.0", default-features = false, features = ["std"] } bytes = { version = "1.4.0", default-features = false, features = ["std"] }
crossbeam-utils = { version = "0.8.15", default-features = false, features = ["std"] } crossbeam-utils = { version = "0.8.15", default-features = false, features = ["std"] }
env_logger = { version = "0.10.0", default-features = false, features = ["humantime"] } env_logger = { version = "0.10.0", default-features = false, features = ["humantime"] }
humantime = { version = "2.1.0", default-features = false }
lexopt = { version = "0.3.0", default-features = false } lexopt = { version = "0.3.0", default-features = false }
log = { version = "0.4.17", default-features = false, features = ["serde", "std"] } log = { version = "0.4.17", default-features = false, features = ["serde", "std"] }
parking_lot = { version = "0.12.1", default-features = false } parking_lot = { version = "0.12.1", default-features = false }

View File

@ -1,4 +1,5 @@
use crate::utils::CongestionControl; use crate::utils::CongestionControl;
use humantime::Duration as HumanDuration;
use lexopt::{Arg, Error as ArgumentError, Parser}; use lexopt::{Arg, Error as ArgumentError, Parser};
use log::LevelFilter; use log::LevelFilter;
use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde::{de::Error as DeError, Deserialize, Deserializer};
@ -23,44 +24,58 @@ Arguments:
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct Config { pub struct Config {
pub server: SocketAddr, pub server: SocketAddr,
#[serde(deserialize_with = "deserialize_users")] #[serde(deserialize_with = "deserialize_users")]
pub users: HashMap<Uuid, String>, pub users: HashMap<Uuid, String>,
pub certificate: PathBuf, pub certificate: PathBuf,
pub private_key: PathBuf, pub private_key: PathBuf,
#[serde( #[serde(
default = "default::congestion_control", default = "default::congestion_control",
deserialize_with = "deserialize_from_str" deserialize_with = "deserialize_from_str"
)] )]
pub congestion_control: CongestionControl, pub congestion_control: CongestionControl,
#[serde(default = "default::alpn")] #[serde(default = "default::alpn")]
pub alpn: Vec<String>, pub alpn: Vec<String>,
#[serde(default = "default::udp_relay_ipv6")] #[serde(default = "default::udp_relay_ipv6")]
pub udp_relay_ipv6: bool, pub udp_relay_ipv6: bool,
#[serde(default = "default::zero_rtt_handshake")] #[serde(default = "default::zero_rtt_handshake")]
pub zero_rtt_handshake: bool, pub zero_rtt_handshake: bool,
pub dual_stack: Option<bool>, pub dual_stack: Option<bool>,
#[serde( #[serde(
default = "default::auth_timeout", default = "default::auth_timeout",
deserialize_with = "deserialize_duration" deserialize_with = "deserialize_duration"
)] )]
pub auth_timeout: Duration, pub auth_timeout: Duration,
#[serde( #[serde(
default = "default::max_idle_time", default = "default::max_idle_time",
deserialize_with = "deserialize_duration" deserialize_with = "deserialize_duration"
)] )]
pub max_idle_time: Duration, pub max_idle_time: Duration,
#[serde(default = "default::max_external_packet_size")] #[serde(default = "default::max_external_packet_size")]
pub max_external_packet_size: usize, pub max_external_packet_size: usize,
#[serde( #[serde(
default = "default::gc_interval", default = "default::gc_interval",
deserialize_with = "deserialize_duration" deserialize_with = "deserialize_duration"
)] )]
pub gc_interval: Duration, pub gc_interval: Duration,
#[serde( #[serde(
default = "default::gc_lifetime", default = "default::gc_lifetime",
deserialize_with = "deserialize_duration" deserialize_with = "deserialize_duration"
)] )]
pub gc_lifetime: Duration, pub gc_lifetime: Duration,
#[serde(default = "default::log_level")] #[serde(default = "default::log_level")]
pub log_level: LevelFilter, pub log_level: LevelFilter,
} }
@ -118,11 +133,11 @@ mod default {
} }
pub fn auth_timeout() -> Duration { pub fn auth_timeout() -> Duration {
Duration::from_secs(10) Duration::from_secs(3)
} }
pub fn max_idle_time() -> Duration { pub fn max_idle_time() -> Duration {
Duration::from_secs(15) Duration::from_secs(10)
} }
pub fn max_external_packet_size() -> usize { pub fn max_external_packet_size() -> usize {
@ -165,57 +180,15 @@ where
Ok(map) Ok(map)
} }
fn parse_duration(s: &str) -> Result<Duration, String> {
let mut num = Vec::with_capacity(8);
let mut chars = Vec::with_capacity(2);
let mut expected_unit = false;
for c in s.chars() {
if !expected_unit && c.is_numeric() {
num.push(c);
continue;
}
if !expected_unit && c.is_ascii() {
expected_unit = true;
}
chars.push(c);
}
let n: u64 = num
.into_iter()
.collect::<String>()
.parse()
.map_err(|e| format!("invalid value: {}, reason {}", &s, e))?;
match chars.into_iter().collect::<String>().as_str() {
"" => Ok(Duration::from_millis(n)),
"s" => Ok(Duration::from_secs(n)),
"ms" => Ok(Duration::from_millis(n)),
_ => Err(format!("invalid value: {}, expected 10s or 10ms", &s)),
}
}
#[test]
fn test_parseduration() {
// test parse
assert_eq!(parse_duration("100s"), Ok(Duration::from_secs(100)));
assert_eq!(parse_duration("100ms"), Ok(Duration::from_millis(100)));
// test default unit
assert_eq!(parse_duration("10000"), Ok(Duration::from_millis(10000)));
// test invalid data
assert!(parse_duration("").is_err());
assert!(parse_duration("1ms100").is_err());
assert!(parse_duration("ms").is_err());
}
pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error> pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let s: String = String::deserialize(deserializer)?; let s = String::deserialize(deserializer)?;
parse_duration(&s).map_err(DeError::custom)
s.parse::<HumanDuration>()
.map(|d| *d)
.map_err(DeError::custom)
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]