implement server config reading
This commit is contained in:
parent
53c95d9860
commit
71bc8e9b2b
@ -1,7 +1,11 @@
|
||||
use crate::utils::CongestionControl;
|
||||
use lexopt::{Arg, Error as ArgumentError, Parser};
|
||||
use serde::{de::Error as DeError, Deserialize, Deserializer};
|
||||
use serde_json::Error as SerdeError;
|
||||
use std::{env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, str::FromStr};
|
||||
use std::{
|
||||
env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, net::SocketAddr, path::PathBuf,
|
||||
str::FromStr, time::Duration,
|
||||
};
|
||||
use thiserror::Error;
|
||||
|
||||
const HELP_MSG: &str = r#"
|
||||
@ -15,7 +19,34 @@ Arguments:
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct Config {}
|
||||
pub struct Config {
|
||||
pub server: SocketAddr,
|
||||
pub token: String,
|
||||
pub certificate: PathBuf,
|
||||
pub private_key: PathBuf,
|
||||
#[serde(
|
||||
default = "default::congestion_control",
|
||||
deserialize_with = "deserialize_from_str"
|
||||
)]
|
||||
pub congestion_control: CongestionControl,
|
||||
#[serde(default = "default::alpn")]
|
||||
pub alpn: Vec<String>,
|
||||
#[serde(default = "default::udp_relay_ipv6")]
|
||||
pub udp_relay_ipv6: bool,
|
||||
#[serde(default = "default::zero_rtt_handshake")]
|
||||
pub zero_rtt_handshake: bool,
|
||||
pub dual_stack: Option<bool>,
|
||||
#[serde(default = "default::auth_timeout")]
|
||||
pub auth_timeout: Duration,
|
||||
#[serde(default = "default::max_idle_time")]
|
||||
pub max_idle_time: Duration,
|
||||
#[serde(default = "default::max_external_packet_size")]
|
||||
pub max_external_packet_size: usize,
|
||||
#[serde(default = "default::gc_interval")]
|
||||
pub gc_interval: Duration,
|
||||
#[serde(default = "default::gc_lifetime")]
|
||||
pub gc_lifetime: Duration,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
|
||||
@ -48,7 +79,46 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
mod default {}
|
||||
mod default {
|
||||
use crate::utils::CongestionControl;
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn congestion_control() -> CongestionControl {
|
||||
CongestionControl::Cubic
|
||||
}
|
||||
|
||||
pub fn alpn() -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
pub fn udp_relay_ipv6() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn zero_rtt_handshake() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn auth_timeout() -> Duration {
|
||||
Duration::from_secs(3)
|
||||
}
|
||||
|
||||
pub fn max_idle_time() -> Duration {
|
||||
Duration::from_secs(15)
|
||||
}
|
||||
|
||||
pub fn max_external_packet_size() -> usize {
|
||||
1500
|
||||
}
|
||||
|
||||
pub fn gc_interval() -> Duration {
|
||||
Duration::from_secs(3)
|
||||
}
|
||||
|
||||
pub fn gc_lifetime() -> Duration {
|
||||
Duration::from_secs(15)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
|
@ -3,6 +3,7 @@ use self::{
|
||||
server::Server,
|
||||
};
|
||||
use quinn::ConnectionError;
|
||||
use rustls::Error as RustlsError;
|
||||
use std::{env, io::Error as IoError, net::SocketAddr, process};
|
||||
use thiserror::Error;
|
||||
use tuic::Address;
|
||||
@ -40,6 +41,10 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
Io(#[from] IoError),
|
||||
#[error(transparent)]
|
||||
Rustls(#[from] RustlsError),
|
||||
#[error("invalid max idle time")]
|
||||
InvalidMaxIdleTime,
|
||||
#[error(transparent)]
|
||||
Connection(#[from] ConnectionError),
|
||||
#[error(transparent)]
|
||||
Model(#[from] ModelError),
|
||||
|
@ -1,15 +1,24 @@
|
||||
use crate::{config::Config, utils::UdpRelayMode, Error};
|
||||
use crate::{
|
||||
config::Config,
|
||||
utils::{self, CongestionControl, UdpRelayMode},
|
||||
Error,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use parking_lot::Mutex;
|
||||
use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt};
|
||||
use quinn::{
|
||||
congestion::{BbrConfig, CubicConfig, NewRenoConfig},
|
||||
Connecting, Connection as QuinnConnection, Endpoint, EndpointConfig, IdleTimeout, RecvStream,
|
||||
SendStream, ServerConfig, TokioRuntime, TransportConfig, VarInt,
|
||||
};
|
||||
use register_count::{Counter, Register};
|
||||
use rustls::{version, ServerConfig as RustlsServerConfig};
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
future::Future,
|
||||
io::{Error as IoError, ErrorKind},
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering},
|
||||
@ -39,13 +48,83 @@ pub struct Server {
|
||||
udp_relay_ipv6: bool,
|
||||
zero_rtt_handshake: bool,
|
||||
auth_timeout: Duration,
|
||||
max_external_pkt_size: usize,
|
||||
gc_interval: Duration,
|
||||
gc_lifetime: Duration,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub fn init(cfg: Config) -> Result<Self, Error> {
|
||||
todo!()
|
||||
let certs = utils::load_certs(cfg.certificate)?;
|
||||
let priv_key = utils::load_priv_key(cfg.private_key)?;
|
||||
|
||||
let mut crypto = RustlsServerConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
.with_protocol_versions(&[&version::TLS13])
|
||||
.unwrap()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(certs, priv_key)?;
|
||||
|
||||
crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect();
|
||||
crypto.max_early_data_size = u32::MAX;
|
||||
crypto.send_half_rtt_data = cfg.zero_rtt_handshake;
|
||||
|
||||
let mut config = ServerConfig::with_crypto(Arc::new(crypto));
|
||||
let mut tp_cfg = TransportConfig::default();
|
||||
|
||||
tp_cfg
|
||||
.max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
|
||||
.max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
|
||||
.max_idle_timeout(Some(
|
||||
IdleTimeout::try_from(cfg.max_idle_time).map_err(|_| Error::InvalidMaxIdleTime)?,
|
||||
));
|
||||
|
||||
match cfg.congestion_control {
|
||||
CongestionControl::Cubic => {
|
||||
tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default()))
|
||||
}
|
||||
CongestionControl::NewReno => {
|
||||
tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default()))
|
||||
}
|
||||
CongestionControl::Bbr => {
|
||||
tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default()))
|
||||
}
|
||||
};
|
||||
|
||||
config.transport_config(Arc::new(tp_cfg));
|
||||
|
||||
let domain = match cfg.server.ip() {
|
||||
IpAddr::V4(_) => Domain::IPV4,
|
||||
IpAddr::V6(_) => Domain::IPV6,
|
||||
};
|
||||
|
||||
let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
|
||||
|
||||
if let Some(dual_stack) = cfg.dual_stack {
|
||||
socket.set_only_v6(!dual_stack)?;
|
||||
}
|
||||
|
||||
socket.bind(&SockAddr::from(cfg.server))?;
|
||||
let socket = StdUdpSocket::from(socket);
|
||||
|
||||
let ep = Endpoint::new(
|
||||
EndpointConfig::default(),
|
||||
Some(config),
|
||||
socket,
|
||||
TokioRuntime,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
ep,
|
||||
token: Arc::from(cfg.token.into_bytes().into_boxed_slice()),
|
||||
udp_relay_ipv6: cfg.udp_relay_ipv6,
|
||||
zero_rtt_handshake: cfg.zero_rtt_handshake,
|
||||
auth_timeout: cfg.auth_timeout,
|
||||
max_external_pkt_size: cfg.max_external_packet_size,
|
||||
gc_interval: cfg.gc_interval,
|
||||
gc_lifetime: cfg.gc_lifetime,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn start(&self) {
|
||||
@ -58,6 +137,7 @@ impl Server {
|
||||
self.udp_relay_ipv6,
|
||||
self.zero_rtt_handshake,
|
||||
self.auth_timeout,
|
||||
self.max_external_pkt_size,
|
||||
self.gc_interval,
|
||||
self.gc_lifetime,
|
||||
));
|
||||
@ -74,6 +154,7 @@ struct Connection {
|
||||
is_authed: IsAuthed,
|
||||
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
|
||||
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
|
||||
max_external_pkt_size: usize,
|
||||
remote_uni_stream_cnt: Counter,
|
||||
remote_bi_stream_cnt: Counter,
|
||||
max_concurrent_uni_streams: Arc<AtomicUsize>,
|
||||
@ -87,10 +168,19 @@ impl Connection {
|
||||
udp_relay_ipv6: bool,
|
||||
zero_rtt_handshake: bool,
|
||||
auth_timeout: Duration,
|
||||
max_external_pkt_size: usize,
|
||||
gc_interval: Duration,
|
||||
gc_lifetime: Duration,
|
||||
) {
|
||||
match Self::init(conn, token, udp_relay_ipv6, zero_rtt_handshake).await {
|
||||
match Self::init(
|
||||
conn,
|
||||
token,
|
||||
udp_relay_ipv6,
|
||||
zero_rtt_handshake,
|
||||
max_external_pkt_size,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(conn) => {
|
||||
tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout));
|
||||
tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime));
|
||||
@ -115,6 +205,7 @@ impl Connection {
|
||||
token: Arc<[u8]>,
|
||||
udp_relay_ipv6: bool,
|
||||
zero_rtt_handshake: bool,
|
||||
max_external_pkt_size: usize,
|
||||
) -> Result<Self, Error> {
|
||||
let conn = if zero_rtt_handshake {
|
||||
match conn.into_0rtt() {
|
||||
@ -136,6 +227,7 @@ impl Connection {
|
||||
is_authed: IsAuthed::new(),
|
||||
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
udp_relay_mode: Arc::new(AtomicCell::new(None)),
|
||||
max_external_pkt_size,
|
||||
remote_uni_stream_cnt: Counter::new(),
|
||||
remote_bi_stream_cnt: Counter::new(),
|
||||
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||
@ -494,7 +586,11 @@ impl UdpSession {
|
||||
_ = cancel => {}
|
||||
() = async {
|
||||
loop {
|
||||
match Self::accept(&socket_v4, socket_v6.as_deref()).await {
|
||||
match Self::accept(
|
||||
&socket_v4,
|
||||
socket_v6.as_deref(),
|
||||
conn.max_external_pkt_size,
|
||||
).await {
|
||||
Ok((pkt, addr)) => {
|
||||
tokio::spawn(send_pkt(conn.clone(), pkt, addr, assoc_id));
|
||||
}
|
||||
@ -508,9 +604,13 @@ impl UdpSession {
|
||||
async fn accept(
|
||||
socket_v4: &UdpSocket,
|
||||
socket_v6: Option<&UdpSocket>,
|
||||
max_pkt_size: usize,
|
||||
) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
async fn read_pkt(
|
||||
socket: &UdpSocket,
|
||||
max_pkt_size: usize,
|
||||
) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
let mut buf = vec![0u8; max_pkt_size];
|
||||
let (n, addr) = socket.recv_from(&mut buf).await?;
|
||||
buf.truncate(n);
|
||||
Ok((Bytes::from(buf), addr))
|
||||
@ -518,11 +618,11 @@ impl UdpSession {
|
||||
|
||||
if let Some(socket_v6) = socket_v6 {
|
||||
tokio::select! {
|
||||
res = read_pkt(socket_v4) => res,
|
||||
res = read_pkt(socket_v6) => res,
|
||||
res = read_pkt(socket_v4, max_pkt_size) => res,
|
||||
res = read_pkt(socket_v6, max_pkt_size) => res,
|
||||
}
|
||||
} else {
|
||||
read_pkt(socket_v4).await
|
||||
read_pkt(socket_v4, max_pkt_size).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,44 @@
|
||||
use std::str::FromStr;
|
||||
use rustls::{Certificate, PrivateKey};
|
||||
use rustls_pemfile::Item;
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::{BufReader, Error as IoError},
|
||||
path::PathBuf,
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
pub fn load_certs(path: PathBuf) -> Result<Vec<Certificate>, IoError> {
|
||||
let mut file = BufReader::new(File::open(&path)?);
|
||||
let mut certs = Vec::new();
|
||||
|
||||
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
|
||||
if let Item::X509Certificate(cert) = item {
|
||||
certs.push(Certificate(cert));
|
||||
}
|
||||
}
|
||||
|
||||
if certs.is_empty() {
|
||||
certs = vec![Certificate(fs::read(&path)?)];
|
||||
}
|
||||
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
pub fn load_priv_key(path: PathBuf) -> Result<PrivateKey, IoError> {
|
||||
let mut file = BufReader::new(File::open(&path)?);
|
||||
let mut priv_key = None;
|
||||
|
||||
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
|
||||
if let Item::RSAKey(key) | Item::PKCS8Key(key) | Item::ECKey(key) = item {
|
||||
priv_key = Some(key);
|
||||
}
|
||||
}
|
||||
|
||||
priv_key
|
||||
.map(Ok)
|
||||
.unwrap_or_else(|| fs::read(&path))
|
||||
.map(PrivateKey)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum UdpRelayMode {
|
||||
|
Loading…
x
Reference in New Issue
Block a user