diff --git a/tuic-server/Cargo.toml b/tuic-server/Cargo.toml index 23e9215..cb33da6 100644 --- a/tuic-server/Cargo.toml +++ b/tuic-server/Cargo.toml @@ -3,6 +3,25 @@ name = "tuic-server" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +bytes = { version = "1.3.0", default-features = false, features = ["std"] } +crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } +lexopt = { version = "0.3.0", default-features = false } +once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } +parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } +quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } +register-count = { version = "0.1.0", default-features = false, features = ["std"] } +rustls = { version = "0.20.8", default-features = false, features = ["quic"] } +rustls-native-certs = { version = "0.6.2", default-features = false } +rustls-pemfile = { version = "1.0.2", default-features = false } +serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } +serde_json = { version = "1.0.91", default-features = false, features = ["std"] } +socket2 = { version = "0.4.7", default-features = false } +socks5-proto = { version = "0.3.3", default-features = false } +socks5-server = { version = "0.8.3", default-features = false } +thiserror = { version = "1.0.38", default-features = false } +tokio = { version = "1.24.2", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } +tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } +tuic = { path = "../tuic", default-features = false } +tuic-quinn = { path = "../tuic-quinn", default-features = false } +webpki = { version = "0.22.0", default-features = false } \ No newline at end of file diff --git a/tuic-server/src/config.rs b/tuic-server/src/config.rs new file mode 100644 index 0000000..c08805d --- /dev/null +++ b/tuic-server/src/config.rs @@ -0,0 +1,77 @@ +use lexopt::{Arg, Error as ArgumentError, Parser}; +use serde::{de::Error as DeError, Deserialize, Deserializer}; +use serde_json::Error as SerdeError; +use std::{env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, str::FromStr}; +use thiserror::Error; + +const HELP_MSG: &str = r#" +Usage tuic-server [arguments] + +Arguments: + -c, --config Path to the config file (required) + -v, --version Print the version + -h, --help Print this help message +"#; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Config {} + +impl Config { + pub fn parse(args: ArgsOs) -> Result { + let mut parser = Parser::from_iter(args); + let mut path = None; + + while let Some(arg) = parser.next()? { + match arg { + Arg::Short('c') | Arg::Long("config") => { + if path.is_none() { + path = Some(parser.value()?); + } else { + return Err(ConfigError::Argument(arg.unexpected())); + } + } + Arg::Short('v') | Arg::Long("version") => { + return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))) + } + Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)), + _ => return Err(ConfigError::Argument(arg.unexpected())), + } + } + + if path.is_none() { + return Err(ConfigError::NoConfig); + } + + let file = File::open(path.unwrap())?; + Ok(serde_json::from_reader(file)?) + } +} + +mod default {} + +pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + ::Err: Display, + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + T::from_str(&s).map_err(DeError::custom) +} + +#[derive(Debug, Error)] +pub enum ConfigError { + #[error(transparent)] + Argument(#[from] ArgumentError), + #[error("no config file specified")] + NoConfig, + #[error("{0}")] + Version(&'static str), + #[error("{0}")] + Help(&'static str), + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Serde(#[from] SerdeError), +} diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index e7a11a9..d52e01a 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -1,3 +1,53 @@ -fn main() { - println!("Hello, world!"); +use self::{ + config::{Config, ConfigError}, + server::Server, +}; +use quinn::{crypto::ExportKeyingMaterialError, ConnectionError}; +use std::{env, io::Error as IoError, process}; +use thiserror::Error; +use tuic_quinn::Error as ModelError; + +mod config; +mod server; +mod utils; + +#[tokio::main] +async fn main() { + let cfg = match Config::parse(env::args_os()) { + Ok(cfg) => cfg, + Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { + println!("{msg}"); + process::exit(0); + } + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + }; + + match Server::init(cfg) { + Ok(server) => server.start().await, + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + } +} + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + Model(#[from] ModelError), + #[error("duplicated authentication")] + DuplicatedAuth, + #[error("token length too short")] + ExportKeyingMaterial, + #[error("authentication failed")] + AuthFailed, + #[error("received packet from unexpected source")] + UnexpectedPacketSource, } diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs new file mode 100644 index 0000000..9101202 --- /dev/null +++ b/tuic-server/src/server.rs @@ -0,0 +1,269 @@ +use crate::{config::Config, utils::UdpRelayMode, Error}; +use bytes::Bytes; +use crossbeam_utils::atomic::AtomicCell; +use parking_lot::Mutex; +use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll, Waker}, +}; +use tuic_quinn::{side, Connection as Model, Task}; + +pub struct Server { + ep: Endpoint, + token: Arc<[u8]>, + zero_rtt_handshake: bool, +} + +impl Server { + pub fn init(cfg: Config) -> Result { + todo!() + } + + pub async fn start(&self) { + loop { + let conn = self.ep.accept().await.unwrap(); + tokio::spawn(Connection::init( + conn, + self.token.clone(), + self.zero_rtt_handshake, + )); + } + } +} + +#[derive(Clone)] +struct Connection { + inner: QuinnConnection, + model: Model, + token: Arc<[u8]>, + is_authed: IsAuthed, + udp_relay_mode: Arc>>, +} + +impl Connection { + pub async fn init(conn: Connecting, token: Arc<[u8]>, zero_rtt_handshake: bool) { + match Self::handshake(conn, token, zero_rtt_handshake).await { + Ok(conn) => loop { + if conn.is_closed() { + break; + } + + match conn.accept().await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + }, + Err(err) => eprintln!("{err}"), + } + } + + async fn handshake( + conn: Connecting, + token: Arc<[u8]>, + zero_rtt_handshake: bool, + ) -> Result { + let conn = if zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => { + eprintln!("0-RTT handshake failed, fallback to 1-RTT handshake"); + conn.await? + } + } + } else { + conn.await? + }; + + Ok(Self { + inner: conn.clone(), + model: Model::::new(conn), + token, + is_authed: IsAuthed::new(), + udp_relay_mode: Arc::new(AtomicCell::new(None)), + }) + } + + async fn accept(&self) -> Result<(), Error> { + tokio::select! { + res = self.inner.accept_uni() => tokio::spawn(self.clone().handle_uni_stream(res?)), + res = self.inner.accept_bi() => tokio::spawn(self.clone().handle_bi_stream(res?)), + res = self.inner.read_datagram() => tokio::spawn(self.clone().handle_datagram(res?)), + }; + + Ok(()) + } + + async fn handle_uni_stream(self, recv: RecvStream) { + async fn pre_process(conn: &Connection, recv: RecvStream) -> Result { + let task = conn.model.accept_uni_stream(recv).await?; + + if let Task::Authenticate(token) = &task { + if conn.is_authed() { + return Err(Error::DuplicatedAuth); + } else { + let mut buf = [0; 32]; + conn.inner + .export_keying_material(&mut buf, &conn.token, &conn.token) + .map_err(|_| Error::ExportKeyingMaterial)?; + + if token == &buf { + conn.set_authed(); + } else { + return Err(Error::AuthFailed); + } + } + } + + tokio::select! { + () = conn.authed() => {} + err = conn.inner.closed() => Err(err)?, + }; + + let same_pkt_src = matches!(task, Task::Packet(_)) + && matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Native)); + if same_pkt_src { + return Err(Error::UnexpectedPacketSource); + } + + Ok(task) + } + + match pre_process(&self, recv).await { + Ok(Task::Packet(pkt)) => todo!(), + Ok(Task::Dissociate(assoc_id)) => todo!(), + Ok(_) => unreachable!(), + Err(err) => { + eprintln!("{err}"); + self.inner.close(VarInt::from_u32(0), b""); + return; + } + } + } + + async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream)) { + async fn pre_process( + conn: &Connection, + send: SendStream, + recv: RecvStream, + ) -> Result { + let task = conn.model.accept_bi_stream(send, recv).await?; + + tokio::select! { + () = conn.authed() => {} + err = conn.inner.closed() => Err(err)?, + }; + + Ok(task) + } + + match pre_process(&self, send, recv).await { + Ok(Task::Connect(conn)) => todo!(), + Ok(_) => unreachable!(), + Err(err) => { + eprintln!("{err}"); + self.inner.close(VarInt::from_u32(0), b""); + return; + } + } + } + + async fn handle_datagram(self, dg: Bytes) { + async fn pre_process(conn: &Connection, dg: Bytes) -> Result { + let task = conn.model.accept_datagram(dg)?; + + tokio::select! { + () = conn.authed() => {} + err = conn.inner.closed() => Err(err)?, + }; + + let same_pkt_src = matches!(task, Task::Packet(_)) + && matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Quic)); + if same_pkt_src { + return Err(Error::UnexpectedPacketSource); + } + + Ok(task) + } + + match pre_process(&self, dg).await { + Ok(Task::Packet(pkt)) => todo!(), + Ok(Task::Heartbeat) => todo!(), + Ok(_) => unreachable!(), + Err(err) => { + eprintln!("{err}"); + self.inner.close(VarInt::from_u32(0), b""); + return; + } + } + } + + fn set_authed(&self) { + self.is_authed.set_authed(); + } + + fn is_authed(&self) -> bool { + self.is_authed.is_authed() + } + + fn authed(&self) -> IsAuthed { + self.is_authed.clone() + } + + fn set_udp_relay_mode(&self, mode: UdpRelayMode) { + self.udp_relay_mode.store(Some(mode)); + } + + fn get_udp_relay_mode(&self) -> Option { + self.udp_relay_mode.load() + } + + fn is_closed(&self) -> bool { + self.inner.close_reason().is_some() + } +} + +#[derive(Clone)] +struct IsAuthed { + is_authed: Arc, + broadcast: Arc>>, +} + +impl IsAuthed { + fn new() -> Self { + Self { + is_authed: Arc::new(AtomicBool::new(false)), + broadcast: Arc::new(Mutex::new(Vec::new())), + } + } + + fn set_authed(&self) { + self.is_authed.store(true, Ordering::Release); + + for waker in self.broadcast.lock().drain(..) { + waker.wake(); + } + } + + fn is_authed(&self) -> bool { + self.is_authed.load(Ordering::Relaxed) + } +} + +impl Future for IsAuthed { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.is_authed.load(Ordering::Relaxed) { + Poll::Ready(()) + } else { + self.broadcast.lock().push(cx.waker().clone()); + Poll::Pending + } + } +} diff --git a/tuic-server/src/utils.rs b/tuic-server/src/utils.rs new file mode 100644 index 0000000..d6849a2 --- /dev/null +++ b/tuic-server/src/utils.rs @@ -0,0 +1,29 @@ +use std::str::FromStr; + +#[derive(Clone, Copy)] +pub enum UdpRelayMode { + Native, + Quic, +} + +pub enum CongestionControl { + Cubic, + NewReno, + Bbr, +} + +impl FromStr for CongestionControl { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if s.eq_ignore_ascii_case("cubic") { + Ok(Self::Cubic) + } else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") { + Ok(Self::NewReno) + } else if s.eq_ignore_ascii_case("bbr") { + Ok(Self::Bbr) + } else { + Err("invalid congestion control") + } + } +}