1
0

implement server config reading

This commit is contained in:
EAimTY 2023-02-04 18:00:20 +09:00
parent 53c95d9860
commit 71bc8e9b2b
4 changed files with 230 additions and 15 deletions

View File

@ -1,7 +1,11 @@
use crate::utils::CongestionControl;
use lexopt::{Arg, Error as ArgumentError, Parser}; use lexopt::{Arg, Error as ArgumentError, Parser};
use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde::{de::Error as DeError, Deserialize, Deserializer};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use std::{env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, str::FromStr}; use std::{
env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, net::SocketAddr, path::PathBuf,
str::FromStr, time::Duration,
};
use thiserror::Error; use thiserror::Error;
const HELP_MSG: &str = r#" const HELP_MSG: &str = r#"
@ -15,7 +19,34 @@ Arguments:
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct Config {} pub struct Config {
pub server: SocketAddr,
pub token: String,
pub certificate: PathBuf,
pub private_key: PathBuf,
#[serde(
default = "default::congestion_control",
deserialize_with = "deserialize_from_str"
)]
pub congestion_control: CongestionControl,
#[serde(default = "default::alpn")]
pub alpn: Vec<String>,
#[serde(default = "default::udp_relay_ipv6")]
pub udp_relay_ipv6: bool,
#[serde(default = "default::zero_rtt_handshake")]
pub zero_rtt_handshake: bool,
pub dual_stack: Option<bool>,
#[serde(default = "default::auth_timeout")]
pub auth_timeout: Duration,
#[serde(default = "default::max_idle_time")]
pub max_idle_time: Duration,
#[serde(default = "default::max_external_packet_size")]
pub max_external_packet_size: usize,
#[serde(default = "default::gc_interval")]
pub gc_interval: Duration,
#[serde(default = "default::gc_lifetime")]
pub gc_lifetime: Duration,
}
impl Config { impl Config {
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> { pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
@ -48,7 +79,46 @@ impl Config {
} }
} }
mod default {} mod default {
use crate::utils::CongestionControl;
use std::time::Duration;
pub fn congestion_control() -> CongestionControl {
CongestionControl::Cubic
}
pub fn alpn() -> Vec<String> {
Vec::new()
}
pub fn udp_relay_ipv6() -> bool {
true
}
pub fn zero_rtt_handshake() -> bool {
false
}
pub fn auth_timeout() -> Duration {
Duration::from_secs(3)
}
pub fn max_idle_time() -> Duration {
Duration::from_secs(15)
}
pub fn max_external_packet_size() -> usize {
1500
}
pub fn gc_interval() -> Duration {
Duration::from_secs(3)
}
pub fn gc_lifetime() -> Duration {
Duration::from_secs(15)
}
}
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error> pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where where

View File

@ -3,6 +3,7 @@ use self::{
server::Server, server::Server,
}; };
use quinn::ConnectionError; use quinn::ConnectionError;
use rustls::Error as RustlsError;
use std::{env, io::Error as IoError, net::SocketAddr, process}; use std::{env, io::Error as IoError, net::SocketAddr, process};
use thiserror::Error; use thiserror::Error;
use tuic::Address; use tuic::Address;
@ -40,6 +41,10 @@ pub enum Error {
#[error(transparent)] #[error(transparent)]
Io(#[from] IoError), Io(#[from] IoError),
#[error(transparent)] #[error(transparent)]
Rustls(#[from] RustlsError),
#[error("invalid max idle time")]
InvalidMaxIdleTime,
#[error(transparent)]
Connection(#[from] ConnectionError), Connection(#[from] ConnectionError),
#[error(transparent)] #[error(transparent)]
Model(#[from] ModelError), Model(#[from] ModelError),

View File

@ -1,15 +1,24 @@
use crate::{config::Config, utils::UdpRelayMode, Error}; use crate::{
config::Config,
utils::{self, CongestionControl, UdpRelayMode},
Error,
};
use bytes::Bytes; use bytes::Bytes;
use crossbeam_utils::atomic::AtomicCell; use crossbeam_utils::atomic::AtomicCell;
use parking_lot::Mutex; use parking_lot::Mutex;
use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; use quinn::{
congestion::{BbrConfig, CubicConfig, NewRenoConfig},
Connecting, Connection as QuinnConnection, Endpoint, EndpointConfig, IdleTimeout, RecvStream,
SendStream, ServerConfig, TokioRuntime, TransportConfig, VarInt,
};
use register_count::{Counter, Register}; use register_count::{Counter, Register};
use rustls::{version, ServerConfig as RustlsServerConfig};
use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::{ use std::{
collections::{hash_map::Entry, HashMap}, collections::{hash_map::Entry, HashMap},
future::Future, future::Future,
io::{Error as IoError, ErrorKind}, io::{Error as IoError, ErrorKind},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
pin::Pin, pin::Pin,
sync::{ sync::{
atomic::{AtomicBool, AtomicUsize, Ordering}, atomic::{AtomicBool, AtomicUsize, Ordering},
@ -39,13 +48,83 @@ pub struct Server {
udp_relay_ipv6: bool, udp_relay_ipv6: bool,
zero_rtt_handshake: bool, zero_rtt_handshake: bool,
auth_timeout: Duration, auth_timeout: Duration,
max_external_pkt_size: usize,
gc_interval: Duration, gc_interval: Duration,
gc_lifetime: Duration, gc_lifetime: Duration,
} }
impl Server { impl Server {
pub fn init(cfg: Config) -> Result<Self, Error> { pub fn init(cfg: Config) -> Result<Self, Error> {
todo!() let certs = utils::load_certs(cfg.certificate)?;
let priv_key = utils::load_priv_key(cfg.private_key)?;
let mut crypto = RustlsServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&version::TLS13])
.unwrap()
.with_no_client_auth()
.with_single_cert(certs, priv_key)?;
crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect();
crypto.max_early_data_size = u32::MAX;
crypto.send_half_rtt_data = cfg.zero_rtt_handshake;
let mut config = ServerConfig::with_crypto(Arc::new(crypto));
let mut tp_cfg = TransportConfig::default();
tp_cfg
.max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
.max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
.max_idle_timeout(Some(
IdleTimeout::try_from(cfg.max_idle_time).map_err(|_| Error::InvalidMaxIdleTime)?,
));
match cfg.congestion_control {
CongestionControl::Cubic => {
tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default()))
}
CongestionControl::NewReno => {
tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default()))
}
CongestionControl::Bbr => {
tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default()))
}
};
config.transport_config(Arc::new(tp_cfg));
let domain = match cfg.server.ip() {
IpAddr::V4(_) => Domain::IPV4,
IpAddr::V6(_) => Domain::IPV6,
};
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
if let Some(dual_stack) = cfg.dual_stack {
socket.set_only_v6(!dual_stack)?;
}
socket.bind(&SockAddr::from(cfg.server))?;
let socket = StdUdpSocket::from(socket);
let ep = Endpoint::new(
EndpointConfig::default(),
Some(config),
socket,
TokioRuntime,
)?;
Ok(Self {
ep,
token: Arc::from(cfg.token.into_bytes().into_boxed_slice()),
udp_relay_ipv6: cfg.udp_relay_ipv6,
zero_rtt_handshake: cfg.zero_rtt_handshake,
auth_timeout: cfg.auth_timeout,
max_external_pkt_size: cfg.max_external_packet_size,
gc_interval: cfg.gc_interval,
gc_lifetime: cfg.gc_lifetime,
})
} }
pub async fn start(&self) { pub async fn start(&self) {
@ -58,6 +137,7 @@ impl Server {
self.udp_relay_ipv6, self.udp_relay_ipv6,
self.zero_rtt_handshake, self.zero_rtt_handshake,
self.auth_timeout, self.auth_timeout,
self.max_external_pkt_size,
self.gc_interval, self.gc_interval,
self.gc_lifetime, self.gc_lifetime,
)); ));
@ -74,6 +154,7 @@ struct Connection {
is_authed: IsAuthed, is_authed: IsAuthed,
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>, udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>, udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
max_external_pkt_size: usize,
remote_uni_stream_cnt: Counter, remote_uni_stream_cnt: Counter,
remote_bi_stream_cnt: Counter, remote_bi_stream_cnt: Counter,
max_concurrent_uni_streams: Arc<AtomicUsize>, max_concurrent_uni_streams: Arc<AtomicUsize>,
@ -87,10 +168,19 @@ impl Connection {
udp_relay_ipv6: bool, udp_relay_ipv6: bool,
zero_rtt_handshake: bool, zero_rtt_handshake: bool,
auth_timeout: Duration, auth_timeout: Duration,
max_external_pkt_size: usize,
gc_interval: Duration, gc_interval: Duration,
gc_lifetime: Duration, gc_lifetime: Duration,
) { ) {
match Self::init(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { match Self::init(
conn,
token,
udp_relay_ipv6,
zero_rtt_handshake,
max_external_pkt_size,
)
.await
{
Ok(conn) => { Ok(conn) => {
tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout)); tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout));
tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime));
@ -115,6 +205,7 @@ impl Connection {
token: Arc<[u8]>, token: Arc<[u8]>,
udp_relay_ipv6: bool, udp_relay_ipv6: bool,
zero_rtt_handshake: bool, zero_rtt_handshake: bool,
max_external_pkt_size: usize,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let conn = if zero_rtt_handshake { let conn = if zero_rtt_handshake {
match conn.into_0rtt() { match conn.into_0rtt() {
@ -136,6 +227,7 @@ impl Connection {
is_authed: IsAuthed::new(), is_authed: IsAuthed::new(),
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
udp_relay_mode: Arc::new(AtomicCell::new(None)), udp_relay_mode: Arc::new(AtomicCell::new(None)),
max_external_pkt_size,
remote_uni_stream_cnt: Counter::new(), remote_uni_stream_cnt: Counter::new(),
remote_bi_stream_cnt: Counter::new(), remote_bi_stream_cnt: Counter::new(),
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
@ -494,7 +586,11 @@ impl UdpSession {
_ = cancel => {} _ = cancel => {}
() = async { () = async {
loop { loop {
match Self::accept(&socket_v4, socket_v6.as_deref()).await { match Self::accept(
&socket_v4,
socket_v6.as_deref(),
conn.max_external_pkt_size,
).await {
Ok((pkt, addr)) => { Ok((pkt, addr)) => {
tokio::spawn(send_pkt(conn.clone(), pkt, addr, assoc_id)); tokio::spawn(send_pkt(conn.clone(), pkt, addr, assoc_id));
} }
@ -508,9 +604,13 @@ impl UdpSession {
async fn accept( async fn accept(
socket_v4: &UdpSocket, socket_v4: &UdpSocket,
socket_v6: Option<&UdpSocket>, socket_v6: Option<&UdpSocket>,
max_pkt_size: usize,
) -> Result<(Bytes, SocketAddr), IoError> { ) -> Result<(Bytes, SocketAddr), IoError> {
async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { async fn read_pkt(
let mut buf = vec![0u8; 65535]; socket: &UdpSocket,
max_pkt_size: usize,
) -> Result<(Bytes, SocketAddr), IoError> {
let mut buf = vec![0u8; max_pkt_size];
let (n, addr) = socket.recv_from(&mut buf).await?; let (n, addr) = socket.recv_from(&mut buf).await?;
buf.truncate(n); buf.truncate(n);
Ok((Bytes::from(buf), addr)) Ok((Bytes::from(buf), addr))
@ -518,11 +618,11 @@ impl UdpSession {
if let Some(socket_v6) = socket_v6 { if let Some(socket_v6) = socket_v6 {
tokio::select! { tokio::select! {
res = read_pkt(socket_v4) => res, res = read_pkt(socket_v4, max_pkt_size) => res,
res = read_pkt(socket_v6) => res, res = read_pkt(socket_v6, max_pkt_size) => res,
} }
} else { } else {
read_pkt(socket_v4).await read_pkt(socket_v4, max_pkt_size).await
} }
} }
} }

View File

@ -1,4 +1,44 @@
use std::str::FromStr; use rustls::{Certificate, PrivateKey};
use rustls_pemfile::Item;
use std::{
fs::{self, File},
io::{BufReader, Error as IoError},
path::PathBuf,
str::FromStr,
};
pub fn load_certs(path: PathBuf) -> Result<Vec<Certificate>, IoError> {
let mut file = BufReader::new(File::open(&path)?);
let mut certs = Vec::new();
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
if let Item::X509Certificate(cert) = item {
certs.push(Certificate(cert));
}
}
if certs.is_empty() {
certs = vec![Certificate(fs::read(&path)?)];
}
Ok(certs)
}
pub fn load_priv_key(path: PathBuf) -> Result<PrivateKey, IoError> {
let mut file = BufReader::new(File::open(&path)?);
let mut priv_key = None;
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
if let Item::RSAKey(key) | Item::PKCS8Key(key) | Item::ECKey(key) = item {
priv_key = Some(key);
}
}
priv_key
.map(Ok)
.unwrap_or_else(|| fs::read(&path))
.map(PrivateKey)
}
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
pub enum UdpRelayMode { pub enum UdpRelayMode {