diff --git a/server/src/config.rs b/server/src/config.rs index 0459f4b..297d605 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -9,14 +9,22 @@ use rustls::{version::TLS13, Error as RustlsError, ServerConfig as RustlsServerC use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as JsonError; use std::{ - collections::HashSet, env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, - num::ParseIntError, str::FromStr, sync::Arc, time::Duration, + collections::HashSet, + env::ArgsOs, + fmt::Display, + fs::File, + io::Error as IoError, + net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr}, + num::ParseIntError, + str::FromStr, + sync::Arc, + time::Duration, }; use thiserror::Error; pub struct Config { pub server_config: ServerConfig, - pub port: u16, + pub listen_addr: SocketAddr, pub token: HashSet<[u8; 32]>, pub authentication_timeout: Duration, pub max_udp_relay_packet_size: usize, @@ -68,7 +76,7 @@ impl Config { config }; - let port = raw.port.unwrap(); + let listen_addr = SocketAddr::from((raw.ip, raw.port.unwrap())); let token = raw .token @@ -82,7 +90,7 @@ impl Config { Ok(Self { server_config, - port, + listen_addr, token, authentication_timeout, max_udp_relay_packet_size, @@ -99,6 +107,9 @@ struct RawConfig { certificate: Option, private_key: Option, + #[serde(default = "default::ip")] + ip: IpAddr, + #[serde( default = "default::congestion_controller", deserialize_with = "deserialize_from_str" @@ -128,6 +139,7 @@ impl Default for RawConfig { token: Vec::new(), certificate: None, private_key: None, + ip: default::ip(), congestion_controller: default::congestion_controller(), max_idle_time: default::max_idle_time(), authentication_timeout: default::authentication_timeout(), @@ -172,6 +184,13 @@ impl RawConfig { "PRIVATE_KEY", ); + opts.optopt( + "", + "ip", + "Set the server listening IP. Default: 0.0.0.0", + "IP", + ); + opts.optopt( "", "congestion-controller", @@ -276,6 +295,10 @@ impl RawConfig { } }; + if let Some(ip) = matches.opt_str("ip") { + raw.ip = ip.parse()?; + }; + if let Some(cgstn_ctrl) = matches.opt_str("congestion-controller") { raw.congestion_controller = cgstn_ctrl.parse()?; }; @@ -347,6 +370,10 @@ where mod default { use super::*; + pub(super) fn ip() -> IpAddr { + IpAddr::V4(Ipv4Addr::UNSPECIFIED) + } + pub(super) const fn congestion_controller() -> CongestionController { CongestionController::Cubic } @@ -390,6 +417,8 @@ pub enum ConfigError { MissingOption(&'static str), #[error(transparent)] ParseInt(#[from] ParseIntError), + #[error(transparent)] + ParseAddr(#[from] AddrParseError), #[error("Invalid congestion controller")] InvalidCongestionController, #[error(transparent)] diff --git a/server/src/main.rs b/server/src/main.rs index a3d0c31..682ecac 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -34,7 +34,7 @@ async fn main() { let server = match Server::init( config.server_config, - config.port, + config.listen_addr, config.token, config.authentication_timeout, config.max_udp_relay_packet_size, diff --git a/server/src/server.rs b/server/src/server.rs index 1834ce0..fd54edb 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -4,15 +4,15 @@ use quinn::{Endpoint, EndpointConfig, Incoming, ServerConfig}; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::{ collections::HashSet, - io::Error as IoError, - net::{Ipv6Addr, SocketAddr, UdpSocket}, + io::Result, + net::{SocketAddr, UdpSocket}, sync::Arc, time::Duration, }; pub struct Server { incoming: Incoming, - port: u16, + listen_addr: SocketAddr, token: Arc>, authentication_timeout: Duration, max_pkt_size: usize, @@ -21,24 +21,26 @@ pub struct Server { impl Server { pub fn init( config: ServerConfig, - port: u16, + listen_addr: SocketAddr, token: HashSet<[u8; 32]>, auth_timeout: Duration, max_pkt_size: usize, - ) -> Result { - let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; - socket.set_only_v6(false)?; - socket.bind(&SockAddr::from(SocketAddr::from(( - Ipv6Addr::UNSPECIFIED, - port, - ))))?; - let socket = UdpSocket::from(socket); + ) -> Result { + let socket = match listen_addr { + SocketAddr::V4(_) => UdpSocket::bind(listen_addr)?, + SocketAddr::V6(_) => { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(false)?; + socket.bind(&SockAddr::from(listen_addr))?; + UdpSocket::from(socket) + } + }; let (_, incoming) = Endpoint::new(EndpointConfig::default(), Some(config), socket)?; Ok(Self { incoming, - port, + listen_addr, token: Arc::new(token), authentication_timeout: auth_timeout, max_pkt_size, @@ -46,7 +48,7 @@ impl Server { } pub async fn run(mut self) { - log::info!("Server started. Listening port: {}", self.port); + log::info!("Server started. Listening: {}", self.listen_addr); while let Some(conn) = self.incoming.next().await { let token = self.token.clone();