From 71bc8e9b2bd2e47f98bcf4da0673f9730503cb6d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 18:00:20 +0900 Subject: [PATCH] implement server config reading --- tuic-server/src/config.rs | 76 +++++++++++++++++++++++- tuic-server/src/main.rs | 5 ++ tuic-server/src/server.rs | 122 ++++++++++++++++++++++++++++++++++---- tuic-server/src/utils.rs | 42 ++++++++++++- 4 files changed, 230 insertions(+), 15 deletions(-) diff --git a/tuic-server/src/config.rs b/tuic-server/src/config.rs index c08805d..89ff4e1 100644 --- a/tuic-server/src/config.rs +++ b/tuic-server/src/config.rs @@ -1,7 +1,11 @@ +use crate::utils::CongestionControl; use lexopt::{Arg, Error as ArgumentError, Parser}; use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as SerdeError; -use std::{env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, str::FromStr}; +use std::{ + env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, net::SocketAddr, path::PathBuf, + str::FromStr, time::Duration, +}; use thiserror::Error; const HELP_MSG: &str = r#" @@ -15,7 +19,34 @@ Arguments: #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub struct Config {} +pub struct Config { + pub server: SocketAddr, + pub token: String, + pub certificate: PathBuf, + pub private_key: PathBuf, + #[serde( + default = "default::congestion_control", + deserialize_with = "deserialize_from_str" + )] + pub congestion_control: CongestionControl, + #[serde(default = "default::alpn")] + pub alpn: Vec, + #[serde(default = "default::udp_relay_ipv6")] + pub udp_relay_ipv6: bool, + #[serde(default = "default::zero_rtt_handshake")] + pub zero_rtt_handshake: bool, + pub dual_stack: Option, + #[serde(default = "default::auth_timeout")] + pub auth_timeout: Duration, + #[serde(default = "default::max_idle_time")] + pub max_idle_time: Duration, + #[serde(default = "default::max_external_packet_size")] + pub max_external_packet_size: usize, + #[serde(default = "default::gc_interval")] + pub gc_interval: Duration, + #[serde(default = "default::gc_lifetime")] + pub gc_lifetime: Duration, +} impl Config { pub fn parse(args: ArgsOs) -> Result { @@ -48,7 +79,46 @@ impl Config { } } -mod default {} +mod default { + use crate::utils::CongestionControl; + use std::time::Duration; + + pub fn congestion_control() -> CongestionControl { + CongestionControl::Cubic + } + + pub fn alpn() -> Vec { + Vec::new() + } + + pub fn udp_relay_ipv6() -> bool { + true + } + + pub fn zero_rtt_handshake() -> bool { + false + } + + pub fn auth_timeout() -> Duration { + Duration::from_secs(3) + } + + pub fn max_idle_time() -> Duration { + Duration::from_secs(15) + } + + pub fn max_external_packet_size() -> usize { + 1500 + } + + pub fn gc_interval() -> Duration { + Duration::from_secs(3) + } + + pub fn gc_lifetime() -> Duration { + Duration::from_secs(15) + } +} pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result where diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index 26c7dc5..31745c2 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -3,6 +3,7 @@ use self::{ server::Server, }; use quinn::ConnectionError; +use rustls::Error as RustlsError; use std::{env, io::Error as IoError, net::SocketAddr, process}; use thiserror::Error; use tuic::Address; @@ -40,6 +41,10 @@ pub enum Error { #[error(transparent)] Io(#[from] IoError), #[error(transparent)] + Rustls(#[from] RustlsError), + #[error("invalid max idle time")] + InvalidMaxIdleTime, + #[error(transparent)] Connection(#[from] ConnectionError), #[error(transparent)] Model(#[from] ModelError), diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 88537e1..892693f 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -1,15 +1,24 @@ -use crate::{config::Config, utils::UdpRelayMode, Error}; +use crate::{ + config::Config, + utils::{self, CongestionControl, UdpRelayMode}, + Error, +}; use bytes::Bytes; use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; -use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; +use quinn::{ + congestion::{BbrConfig, CubicConfig, NewRenoConfig}, + Connecting, Connection as QuinnConnection, Endpoint, EndpointConfig, IdleTimeout, RecvStream, + SendStream, ServerConfig, TokioRuntime, TransportConfig, VarInt, +}; use register_count::{Counter, Register}; +use rustls::{version, ServerConfig as RustlsServerConfig}; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::{ collections::{hash_map::Entry, HashMap}, future::Future, io::{Error as IoError, ErrorKind}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, pin::Pin, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -39,13 +48,83 @@ pub struct Server { udp_relay_ipv6: bool, zero_rtt_handshake: bool, auth_timeout: Duration, + max_external_pkt_size: usize, gc_interval: Duration, gc_lifetime: Duration, } impl Server { pub fn init(cfg: Config) -> Result { - todo!() + let certs = utils::load_certs(cfg.certificate)?; + let priv_key = utils::load_priv_key(cfg.private_key)?; + + let mut crypto = RustlsServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, priv_key)?; + + crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect(); + crypto.max_early_data_size = u32::MAX; + crypto.send_half_rtt_data = cfg.zero_rtt_handshake; + + let mut config = ServerConfig::with_crypto(Arc::new(crypto)); + let mut tp_cfg = TransportConfig::default(); + + tp_cfg + .max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_idle_timeout(Some( + IdleTimeout::try_from(cfg.max_idle_time).map_err(|_| Error::InvalidMaxIdleTime)?, + )); + + match cfg.congestion_control { + CongestionControl::Cubic => { + tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default())) + } + CongestionControl::NewReno => { + tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default())) + } + CongestionControl::Bbr => { + tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default())) + } + }; + + config.transport_config(Arc::new(tp_cfg)); + + let domain = match cfg.server.ip() { + IpAddr::V4(_) => Domain::IPV4, + IpAddr::V6(_) => Domain::IPV6, + }; + + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + if let Some(dual_stack) = cfg.dual_stack { + socket.set_only_v6(!dual_stack)?; + } + + socket.bind(&SockAddr::from(cfg.server))?; + let socket = StdUdpSocket::from(socket); + + let ep = Endpoint::new( + EndpointConfig::default(), + Some(config), + socket, + TokioRuntime, + )?; + + Ok(Self { + ep, + token: Arc::from(cfg.token.into_bytes().into_boxed_slice()), + udp_relay_ipv6: cfg.udp_relay_ipv6, + zero_rtt_handshake: cfg.zero_rtt_handshake, + auth_timeout: cfg.auth_timeout, + max_external_pkt_size: cfg.max_external_packet_size, + gc_interval: cfg.gc_interval, + gc_lifetime: cfg.gc_lifetime, + }) } pub async fn start(&self) { @@ -58,6 +137,7 @@ impl Server { self.udp_relay_ipv6, self.zero_rtt_handshake, self.auth_timeout, + self.max_external_pkt_size, self.gc_interval, self.gc_lifetime, )); @@ -74,6 +154,7 @@ struct Connection { is_authed: IsAuthed, udp_sessions: Arc>>, udp_relay_mode: Arc>>, + max_external_pkt_size: usize, remote_uni_stream_cnt: Counter, remote_bi_stream_cnt: Counter, max_concurrent_uni_streams: Arc, @@ -87,10 +168,19 @@ impl Connection { udp_relay_ipv6: bool, zero_rtt_handshake: bool, auth_timeout: Duration, + max_external_pkt_size: usize, gc_interval: Duration, gc_lifetime: Duration, ) { - match Self::init(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { + match Self::init( + conn, + token, + udp_relay_ipv6, + zero_rtt_handshake, + max_external_pkt_size, + ) + .await + { Ok(conn) => { tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout)); tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); @@ -115,6 +205,7 @@ impl Connection { token: Arc<[u8]>, udp_relay_ipv6: bool, zero_rtt_handshake: bool, + max_external_pkt_size: usize, ) -> Result { let conn = if zero_rtt_handshake { match conn.into_0rtt() { @@ -136,6 +227,7 @@ impl Connection { is_authed: IsAuthed::new(), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), + max_external_pkt_size, remote_uni_stream_cnt: Counter::new(), remote_bi_stream_cnt: Counter::new(), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), @@ -494,7 +586,11 @@ impl UdpSession { _ = cancel => {} () = async { loop { - match Self::accept(&socket_v4, socket_v6.as_deref()).await { + match Self::accept( + &socket_v4, + socket_v6.as_deref(), + conn.max_external_pkt_size, + ).await { Ok((pkt, addr)) => { tokio::spawn(send_pkt(conn.clone(), pkt, addr, assoc_id)); } @@ -508,9 +604,13 @@ impl UdpSession { async fn accept( socket_v4: &UdpSocket, socket_v6: Option<&UdpSocket>, + max_pkt_size: usize, ) -> Result<(Bytes, SocketAddr), IoError> { - async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { - let mut buf = vec![0u8; 65535]; + async fn read_pkt( + socket: &UdpSocket, + max_pkt_size: usize, + ) -> Result<(Bytes, SocketAddr), IoError> { + let mut buf = vec![0u8; max_pkt_size]; let (n, addr) = socket.recv_from(&mut buf).await?; buf.truncate(n); Ok((Bytes::from(buf), addr)) @@ -518,11 +618,11 @@ impl UdpSession { if let Some(socket_v6) = socket_v6 { tokio::select! { - res = read_pkt(socket_v4) => res, - res = read_pkt(socket_v6) => res, + res = read_pkt(socket_v4, max_pkt_size) => res, + res = read_pkt(socket_v6, max_pkt_size) => res, } } else { - read_pkt(socket_v4).await + read_pkt(socket_v4, max_pkt_size).await } } } diff --git a/tuic-server/src/utils.rs b/tuic-server/src/utils.rs index d6849a2..31061c7 100644 --- a/tuic-server/src/utils.rs +++ b/tuic-server/src/utils.rs @@ -1,4 +1,44 @@ -use std::str::FromStr; +use rustls::{Certificate, PrivateKey}; +use rustls_pemfile::Item; +use std::{ + fs::{self, File}, + io::{BufReader, Error as IoError}, + path::PathBuf, + str::FromStr, +}; + +pub fn load_certs(path: PathBuf) -> Result, IoError> { + let mut file = BufReader::new(File::open(&path)?); + let mut certs = Vec::new(); + + while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { + if let Item::X509Certificate(cert) = item { + certs.push(Certificate(cert)); + } + } + + if certs.is_empty() { + certs = vec![Certificate(fs::read(&path)?)]; + } + + Ok(certs) +} + +pub fn load_priv_key(path: PathBuf) -> Result { + let mut file = BufReader::new(File::open(&path)?); + let mut priv_key = None; + + while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { + if let Item::RSAKey(key) | Item::PKCS8Key(key) | Item::ECKey(key) = item { + priv_key = Some(key); + } + } + + priv_key + .map(Ok) + .unwrap_or_else(|| fs::read(&path)) + .map(PrivateKey) +} #[derive(Clone, Copy)] pub enum UdpRelayMode {