From 0707e29a75d7cbfbf08aa0d31c903eaf6ce0e39e Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 31 Jul 2022 12:47:41 +0900 Subject: [PATCH 001/103] preparing for refactor --- Cargo.toml | 52 ++- client/Cargo.toml | 37 -- client/main.rs | 3 + client/src/certificate.rs | 36 -- client/src/config.rs | 629 -------------------------- client/src/main.rs | 58 --- client/src/relay/address.rs | 37 -- client/src/relay/connection.rs | 389 ---------------- client/src/relay/incoming.rs | 157 ------- client/src/relay/mod.rs | 100 ---- client/src/relay/request.rs | 177 -------- client/src/relay/stream.rs | 208 --------- client/src/relay/task.rs | 121 ----- client/src/socks5/associate.rs | 140 ------ client/src/socks5/bind.rs | 24 - client/src/socks5/connect.rs | 34 -- client/src/socks5/mod.rs | 93 ---- protocol/Cargo.toml | 17 - protocol/README.md | 202 --------- protocol/src/lib.rs | 355 --------------- server/Cargo.toml | 32 -- server/main.rs | 3 + server/src/certificate.rs | 39 -- server/src/config.rs | 428 ------------------ server/src/connection/authenticate.rs | 53 --- server/src/connection/dispatch.rs | 226 --------- server/src/connection/mod.rs | 258 ----------- server/src/connection/task.rs | 186 -------- server/src/connection/udp.rs | 178 -------- server/src/main.rs | 51 --- server/src/server.rs | 62 --- src/lib.rs | 0 32 files changed, 51 insertions(+), 4334 deletions(-) delete mode 100644 client/Cargo.toml create mode 100644 client/main.rs delete mode 100644 client/src/certificate.rs delete mode 100644 client/src/config.rs delete mode 100644 client/src/main.rs delete mode 100644 client/src/relay/address.rs delete mode 100644 client/src/relay/connection.rs delete mode 100644 client/src/relay/incoming.rs delete mode 100644 client/src/relay/mod.rs delete mode 100644 client/src/relay/request.rs delete mode 100644 client/src/relay/stream.rs delete mode 100644 client/src/relay/task.rs delete mode 100644 client/src/socks5/associate.rs delete mode 100644 client/src/socks5/bind.rs delete mode 100644 client/src/socks5/connect.rs delete mode 100644 client/src/socks5/mod.rs delete mode 100644 protocol/Cargo.toml delete mode 100644 protocol/README.md delete mode 100644 protocol/src/lib.rs delete mode 100644 server/Cargo.toml create mode 100644 server/main.rs delete mode 100644 server/src/certificate.rs delete mode 100644 server/src/config.rs delete mode 100644 server/src/connection/authenticate.rs delete mode 100644 server/src/connection/dispatch.rs delete mode 100644 server/src/connection/mod.rs delete mode 100644 server/src/connection/task.rs delete mode 100644 server/src/connection/udp.rs delete mode 100644 server/src/main.rs delete mode 100644 server/src/server.rs create mode 100644 src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index f8d0061..e78e84e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,50 @@ -[workspace] -members = [ - "client", - "server", - "protocol", -] +[package] +name = "tuic" +version = "0.9.0" +authors = ["EAimTY "] +description = "Delicately-TUICed high-performance proxy built on top of the QUIC protocol" +categories = ["network-programming", "command-line-utilities"] +keywords = ["tuic", "proxy", "quic"] +edition = "2021" +rust-version = "1.59" +readme = "README.md" +license = "GPL-3.0-or-later" +repository = "https://github.com/EAimTY/tuic" + +[[bin]] +name = "client" +path = "client/main.rs" + +[[bin]] +name = "server" +path = "server/main.rs" + +[dependencies] +blake3 = "1.3.*" +bytes = "1.2.*" +crossbeam-utils = { version = "0.8.*", default-features = false } +env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } +futures-util = { version = "0.3.*", default-features = false } +getopts = "0.2.*" +log = { version = "0.4.*", features = ["serde", "std"] } +once_cell = { version = "1.13.*", features = ["parking_lot"] } +parking_lot = "0.12.*" +quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false } +rand = "0.8.*" +rustls = { version = "0.20.*", features = ["quic"], default-features = false } +rustls-native-certs = "0.6.*" +rustls-pemfile = "1.0.*" +serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } +serde_json = { version = "1.0.*", features = ["std"], default-features = false } +socket2 = "0.4.*" +socks5-proto = "0.3.*" +socks5-server = "0.8.*" +thiserror = "1.0.*" +tokio = { version = "1.20.*", features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "sync", "time"] } +webpki = { version = "0.22.*", default-features = false } [profile.release] lto = true strip = true codegen-units = 1 -panic = "abort" +panic = "abort" \ No newline at end of file diff --git a/client/Cargo.toml b/client/Cargo.toml deleted file mode 100644 index 1e11d5a..0000000 --- a/client/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -[package] -name = "tuic-client" -version = "0.8.5" -authors = ["EAimTY "] -description = "Delicately-TUICed high-performance proxy built on top of the QUIC protocol" -categories = ["network-programming", "command-line-utilities"] -keywords = ["tuic", "proxy", "quic"] -edition = "2021" -rust-version = "1.59" -readme = "../README.md" -license = "GPL-3.0-or-later" -repository = "https://github.com/EAimTY/tuic" - -[dependencies] -tuic-protocol = { path="../protocol" } - -blake3 = "1.3.*" -bytes = "1.2.*" -env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } -futures-util = { version = "0.3.*", default-features = false } -getopts = "0.2.*" -log = { version = "0.4.*", features = ["serde", "std"] } -once_cell = "1.13.*" -parking_lot = "0.12.*" -quinn = "0.8.*" -rand = "0.8.*" -rustls = { version = "0.20.*", features = ["quic"], default-features = false } -rustls-native-certs = "0.6.*" -rustls-pemfile = "1.0.*" -serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } -serde_json = { version = "1.0.*", features = ["std"], default-features = false } -socket2 = "0.4.*" -socks5-proto = "0.3.*" -socks5-server = "0.8.*" -thiserror = "1.0.*" -tokio = { version = "1.20.*", features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "sync", "time"] } -webpki = { version = "0.22.*", default-features = false } diff --git a/client/main.rs b/client/main.rs new file mode 100644 index 0000000..f084790 --- /dev/null +++ b/client/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello World!"); +} \ No newline at end of file diff --git a/client/src/certificate.rs b/client/src/certificate.rs deleted file mode 100644 index 2c7a7b1..0000000 --- a/client/src/certificate.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::config::ConfigError; -use rustls::{Certificate, RootCertStore}; -use rustls_pemfile::Item; -use std::{ - fs::{self, File}, - io::BufReader, -}; - -pub fn load_certificates(files: Vec) -> Result { - let mut certs = RootCertStore::empty(); - - for file in &files { - let mut file = - BufReader::new(File::open(file).map_err(|err| ConfigError::Io(file.to_owned(), err))?); - - while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { - if let Item::X509Certificate(cert) = item { - certs.add(&Certificate(cert))?; - } - } - } - - if certs.is_empty() { - for file in &files { - certs.add(&Certificate( - fs::read(file).map_err(|err| ConfigError::Io(file.to_owned(), err))?, - ))?; - } - } - - for cert in rustls_native_certs::load_native_certs().map_err(ConfigError::NativeCertificate)? { - certs.add(&Certificate(cert.0))?; - } - - Ok(certs) -} diff --git a/client/src/config.rs b/client/src/config.rs deleted file mode 100644 index 92767e9..0000000 --- a/client/src/config.rs +++ /dev/null @@ -1,629 +0,0 @@ -use crate::{ - certificate, - relay::{ServerAddr, UdpRelayMode}, -}; -use getopts::{Fail, Options}; -use log::{LevelFilter, ParseLevelError}; -use quinn::{ - congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - ClientConfig, -}; -use rustls::{version::TLS13, ClientConfig as RustlsClientConfig}; -use serde::{de::Error as DeError, Deserialize, Deserializer}; -use serde_json::Error as JsonError; -use socks5_server::{ - auth::{NoAuth, Password}, - Auth, -}; -use std::{ - env::ArgsOs, - fmt::Display, - fs::File, - io::Error as IoError, - net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr}, - num::ParseIntError, - str::FromStr, - sync::Arc, -}; -use thiserror::Error; -use webpki::Error as WebpkiError; - -pub struct Config { - pub client_config: ClientConfig, - pub server_addr: ServerAddr, - pub token_digest: [u8; 32], - pub udp_relay_mode: UdpRelayMode<(), ()>, - pub heartbeat_interval: u64, - pub reduce_rtt: bool, - pub request_timeout: u64, - pub max_udp_relay_packet_size: usize, - pub local_addr: SocketAddr, - pub socks5_auth: Arc, - pub log_level: LevelFilter, -} - -impl Config { - pub fn parse(args: ArgsOs) -> Result { - let raw = RawConfig::parse(args)?; - - let client_config = { - let certs = certificate::load_certificates(raw.relay.certificates)?; - - let mut crypto = RustlsClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&TLS13]) - .unwrap() - .with_root_certificates(certs) - .with_no_client_auth(); - - crypto.alpn_protocols = raw - .relay - .alpn - .into_iter() - .map(|alpn| alpn.into_bytes()) - .collect(); - - crypto.enable_early_data = true; - crypto.enable_sni = !raw.relay.disable_sni; - - let mut config = ClientConfig::new(Arc::new(crypto)); - let transport = Arc::get_mut(&mut config.transport).unwrap(); - - match raw.relay.congestion_controller { - CongestionController::Bbr => { - transport.congestion_controller_factory(Arc::new(BbrConfig::default())); - } - CongestionController::Cubic => { - transport.congestion_controller_factory(Arc::new(CubicConfig::default())); - } - CongestionController::NewReno => { - transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); - } - } - - transport.max_idle_timeout(None); - - config - }; - - let server_addr = { - let name = raw.relay.server.unwrap(); - let port = raw.relay.port.unwrap(); - - if let Some(ip) = raw.relay.ip { - ServerAddr::SocketAddr { - addr: SocketAddr::new(ip, port), - name, - } - } else { - ServerAddr::DomainAddr { domain: name, port } - } - }; - - let token_digest = *blake3::hash(&raw.relay.token.unwrap().into_bytes()).as_bytes(); - let udp_relay_mode = raw.relay.udp_relay_mode; - let heartbeat_interval = raw.relay.heartbeat_interval; - let reduce_rtt = raw.relay.reduce_rtt; - let request_timeout = raw.relay.request_timeout; - let max_udp_relay_packet_size = raw.relay.max_udp_relay_packet_size; - - let local_addr = SocketAddr::from((raw.local.ip, raw.local.port.unwrap())); - - let socks5_auth = match (raw.local.username, raw.local.password) { - (None, None) => Arc::new(NoAuth) as Arc, - (Some(username), Some(password)) => { - Arc::new(Password::new(username.into_bytes(), password.into_bytes())) - as Arc - } - _ => return Err(ConfigError::LocalAuthentication), - }; - - let log_level = raw.log_level; - - Ok(Self { - client_config, - server_addr, - token_digest, - udp_relay_mode, - heartbeat_interval, - reduce_rtt, - request_timeout, - max_udp_relay_packet_size, - local_addr, - socks5_auth, - log_level, - }) - } -} - -#[derive(Deserialize)] -#[serde(deny_unknown_fields)] -struct RawConfig { - relay: RawRelayConfig, - local: RawLocalConfig, - - #[serde(default = "default::log_level")] - log_level: LevelFilter, -} - -#[derive(Deserialize)] -#[serde(deny_unknown_fields)] -struct RawRelayConfig { - server: Option, - port: Option, - token: Option, - ip: Option, - - #[serde(default = "default::certificates")] - certificates: Vec, - - #[serde( - default = "default::udp_relay_mode", - deserialize_with = "deserialize_from_str" - )] - udp_relay_mode: UdpRelayMode<(), ()>, - - #[serde( - default = "default::congestion_controller", - deserialize_with = "deserialize_from_str" - )] - congestion_controller: CongestionController, - - #[serde(default = "default::heartbeat_interval")] - heartbeat_interval: u64, - - #[serde(default = "default::alpn")] - alpn: Vec, - - #[serde(default = "default::disable_sni")] - disable_sni: bool, - - #[serde(default = "default::reduce_rtt")] - reduce_rtt: bool, - - #[serde(default = "default::request_timeout")] - request_timeout: u64, - - #[serde(default = "default::max_udp_relay_packet_size")] - max_udp_relay_packet_size: usize, -} - -#[derive(Deserialize)] -#[serde(deny_unknown_fields)] -struct RawLocalConfig { - port: Option, - - #[serde(default = "default::local_ip")] - ip: IpAddr, - - username: Option, - password: Option, -} - -impl Default for RawConfig { - fn default() -> Self { - Self { - relay: RawRelayConfig::default(), - local: RawLocalConfig::default(), - log_level: default::log_level(), - } - } -} - -impl Default for RawRelayConfig { - fn default() -> Self { - Self { - server: None, - port: None, - ip: None, - token: None, - certificates: default::certificates(), - udp_relay_mode: default::udp_relay_mode(), - congestion_controller: default::congestion_controller(), - heartbeat_interval: default::heartbeat_interval(), - alpn: default::alpn(), - disable_sni: default::disable_sni(), - reduce_rtt: default::reduce_rtt(), - request_timeout: default::request_timeout(), - max_udp_relay_packet_size: default::max_udp_relay_packet_size(), - } - } -} - -impl Default for RawLocalConfig { - fn default() -> Self { - Self { - port: None, - ip: default::local_ip(), - username: None, - password: None, - } - } -} - -impl RawConfig { - fn parse(args: ArgsOs) -> Result { - let mut opts = Options::new(); - - opts.optopt( - "c", - "config", - "Read configuration from a file. Note that command line arguments will override the configuration file", - "CONFIG_FILE", - ); - - opts.optopt( - "", - "server", - "Set the server address. This address must be included in the certificate", - "SERVER", - ); - - opts.optopt("", "server-port", "Set the server port", "SERVER_PORT"); - - opts.optopt( - "", - "token", - "Set the token for TUIC authentication", - "TOKEN", - ); - - opts.optopt( - "", - "server-ip", - "Set the server IP, for overwriting the DNS lookup result of the server address set in option 'server'", - "SERVER_IP", - ); - - opts.optmulti( - "", - "certificate", - "Set custom X.509 certificate alongside native CA roots for the QUIC handshake. This option can be used multiple times to set multiple certificates", - "CERTIFICATE", - ); - - opts.optopt( - "", - "udp-relay-mode", - r#"Set the UDP relay mode. Available: "native", "quic". Default: "native""#, - "UDP_MODE", - ); - - opts.optopt( - "", - "congestion-controller", - r#"Set the congestion control algorithm. Available: "cubic", "new_reno", "bbr". Default: "cubic""#, - "CONGESTION_CONTROLLER", - ); - - opts.optopt( - "", - "heartbeat-interval", - "Set the heartbeat interval to ensures that the QUIC connection is not closed when there are relay tasks but no data transfer, in milliseconds. This value needs to be smaller than the maximum idle time set at the server side. Default: 10000", - "HEARTBEAT_INTERVAL", - ); - - opts.optmulti( - "", - "alpn", - "Set ALPN protocols included in the TLS client hello. This option can be used multiple times to set multiple ALPN protocols. If not set, no ALPN extension will be sent", - "ALPN_PROTOCOL", - ); - - opts.optflag( - "", - "disable-sni", - "Not sending the Server Name Indication (SNI) extension during the client TLS handshake", - ); - - opts.optflag("", "reduce-rtt", "Enable 0-RTT QUIC handshake"); - - opts.optopt( - "", - "request-timeout", - "Set the timeout for negotiating tasks between client and the server, in milliseconds. Default: 8000", - "REQUEST_TIMEOUT", - ); - - opts.optopt( - "", - "max-udp-relay-packet-size", - "UDP relay mode QUIC can transmit UDP packets larger than the MTU. Set this to a higher value allows inbound to receive larger UDP packet. Default: 1500", - "MAX_UDP_RELAY_PACKET_SIZE", - ); - - opts.optopt( - "", - "local-port", - "Set the listening port for the local socks5 server", - "LOCAL_PORT", - ); - - opts.optopt( - "", - "local-ip", - r#"Set the listening IP for the local socks5 server. Note that the sock5 server socket will be a dual-stack socket if it is IPv6. Default: "127.0.0.1""#, - "LOCAL_IP", - ); - - opts.optopt( - "", - "local-username", - "Set the username for the local socks5 server authentication", - "LOCAL_USERNAME", - ); - - opts.optopt( - "", - "local-password", - "Set the password for the local socks5 server authentication", - "LOCAL_PASSWORD", - ); - - opts.optopt( - "", - "log-level", - r#"Set the log level. Available: "off", "error", "warn", "info", "debug", "trace". Default: "info""#, - "LOG_LEVEL", - ); - - opts.optflag("v", "version", "Print the version"); - opts.optflag("h", "help", "Print this help menu"); - - let matches = opts.parse(args.skip(1))?; - - if matches.opt_present("help") { - return Err(ConfigError::Help(opts.usage(env!("CARGO_PKG_NAME")))); - } - - if matches.opt_present("version") { - return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))); - } - - if !matches.free.is_empty() { - return Err(ConfigError::UnexpectedArguments(matches.free.join(", "))); - } - - let server = matches.opt_str("server"); - let server_port = matches.opt_str("server-port").map(|port| port.parse()); - let token = matches.opt_str("token"); - let local_port = matches.opt_str("local-port").map(|port| port.parse()); - - let mut raw = if let Some(path) = matches.opt_str("config") { - let mut raw = RawConfig::from_file(path)?; - - raw.relay.server = Some( - server - .or(raw.relay.server) - .ok_or(ConfigError::MissingOption("server address"))?, - ); - - raw.relay.port = Some( - server_port - .transpose()? - .or(raw.relay.port) - .ok_or(ConfigError::MissingOption("server port"))?, - ); - - raw.relay.token = Some( - token - .or(raw.relay.token) - .ok_or(ConfigError::MissingOption("token"))?, - ); - - raw.local.port = Some( - local_port - .transpose()? - .or(raw.local.port) - .ok_or(ConfigError::MissingOption("local port"))?, - ); - - raw - } else { - let relay = RawRelayConfig { - server: Some(server.ok_or(ConfigError::MissingOption("server address"))?), - port: Some(server_port.ok_or(ConfigError::MissingOption("server port"))??), - token: Some(token.ok_or(ConfigError::MissingOption("token"))?), - ..Default::default() - }; - - let local = RawLocalConfig { - port: Some(local_port.ok_or(ConfigError::MissingOption("local port"))??), - ..Default::default() - }; - - RawConfig { - relay, - local, - ..Default::default() - } - }; - - if let Some(ip) = matches.opt_str("server-ip") { - raw.relay.ip = Some(ip.parse()?); - }; - - let certificates = matches.opt_strs("certificate"); - - if !certificates.is_empty() { - raw.relay.certificates = certificates; - } - - if let Some(mode) = matches.opt_str("udp-relay-mode") { - raw.relay.udp_relay_mode = mode.parse()?; - }; - - if let Some(cgstn_ctrl) = matches.opt_str("congestion-controller") { - raw.relay.congestion_controller = cgstn_ctrl.parse()?; - }; - - if let Some(interval) = matches.opt_str("heartbeat-interval") { - raw.relay.heartbeat_interval = interval.parse()?; - }; - - let alpn = matches.opt_strs("alpn"); - - if !alpn.is_empty() { - raw.relay.alpn = alpn; - } - - raw.relay.disable_sni |= matches.opt_present("disable-sni"); - raw.relay.reduce_rtt |= matches.opt_present("reduce-rtt"); - - if let Some(timeout) = matches.opt_str("request-timeout") { - raw.relay.request_timeout = timeout.parse()?; - }; - - if let Some(size) = matches.opt_str("max-udp-relay-packet-size") { - raw.relay.max_udp_relay_packet_size = size.parse()?; - }; - - if let Some(local_ip) = matches.opt_str("local-ip") { - raw.local.ip = local_ip.parse()?; - }; - - raw.local.username = matches.opt_str("local-username").or(raw.local.username); - raw.local.password = matches.opt_str("local-password").or(raw.local.password); - - if let Some(log_level) = matches.opt_str("log-level") { - raw.log_level = log_level.parse()?; - }; - - Ok(raw) - } - - fn from_file(path: String) -> Result { - let file = File::open(&path).map_err(|err| ConfigError::Io(path, err))?; - let raw = serde_json::from_reader(file)?; - Ok(raw) - } -} - -enum CongestionController { - Cubic, - NewReno, - Bbr, -} - -impl FromStr for CongestionController { - type Err = ConfigError; - - fn from_str(s: &str) -> Result { - if s.eq_ignore_ascii_case("cubic") { - Ok(Self::Cubic) - } else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") { - Ok(Self::NewReno) - } else if s.eq_ignore_ascii_case("bbr") { - Ok(Self::Bbr) - } else { - Err(ConfigError::InvalidCongestionController) - } - } -} - -impl FromStr for UdpRelayMode<(), ()> { - type Err = ConfigError; - - fn from_str(s: &str) -> Result { - if s.eq_ignore_ascii_case("native") { - Ok(Self::Native(())) - } else if s.eq_ignore_ascii_case("quic") { - Ok(Self::Quic(())) - } else { - Err(ConfigError::InvalidUdpRelayMode) - } - } -} - -fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result -where - T: FromStr, - ::Err: Display, - D: Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - T::from_str(&s).map_err(DeError::custom) -} - -mod default { - use super::*; - - pub(super) const fn certificates() -> Vec { - Vec::new() - } - - pub(super) const fn udp_relay_mode() -> UdpRelayMode<(), ()> { - UdpRelayMode::Native(()) - } - - pub(super) const fn congestion_controller() -> CongestionController { - CongestionController::Cubic - } - - pub(super) const fn heartbeat_interval() -> u64 { - 10000 - } - - pub(super) const fn alpn() -> Vec { - Vec::new() - } - - pub(super) const fn disable_sni() -> bool { - false - } - - pub(super) const fn reduce_rtt() -> bool { - false - } - - pub(super) const fn request_timeout() -> u64 { - 8000 - } - - pub(super) const fn max_udp_relay_packet_size() -> usize { - 1500 - } - - pub(super) const fn local_ip() -> IpAddr { - IpAddr::V4(Ipv4Addr::LOCALHOST) - } - - pub(super) const fn log_level() -> LevelFilter { - LevelFilter::Info - } -} - -#[derive(Error, Debug)] -pub enum ConfigError { - #[error("{0}")] - Help(String), - #[error("{0}")] - Version(&'static str), - #[error("Failed to read '{0}': {1}")] - Io(String, #[source] IoError), - #[error("Failed to parse the config file: {0}")] - ParseConfigJson(#[from] JsonError), - #[error(transparent)] - ParseArgument(#[from] Fail), - #[error("Unexpected arguments: {0}")] - UnexpectedArguments(String), - #[error("Missing option: {0}")] - MissingOption(&'static str), - #[error(transparent)] - ParseInt(#[from] ParseIntError), - #[error(transparent)] - ParseAddr(#[from] AddrParseError), - #[error("Invalid congestion controller")] - InvalidCongestionController, - #[error("Invalid udp relay mode")] - InvalidUdpRelayMode, - #[error("Failed to load the certificate: {0}")] - Certificate(#[from] WebpkiError), - #[error("Could not load platform certs: {0}")] - NativeCertificate(#[source] IoError), - #[error("Username and password must be set together for the local socks5 server")] - LocalAuthentication, - #[error(transparent)] - ParseLogLevel(#[from] ParseLevelError), -} diff --git a/client/src/main.rs b/client/src/main.rs deleted file mode 100644 index e2df8b2..0000000 --- a/client/src/main.rs +++ /dev/null @@ -1,58 +0,0 @@ -use crate::config::{Config, ConfigError}; -use std::{env, process}; - -mod certificate; -mod config; -mod relay; -mod socks5; - -#[tokio::main] -async fn main() { - let args = env::args_os(); - - let config = match Config::parse(args) { - Ok(cfg) => cfg, - Err(err) => { - match err { - ConfigError::Help(help) => println!("{help}"), - ConfigError::Version(version) => println!("{version}"), - err => eprintln!("{err}"), - } - return; - } - }; - - env_logger::builder() - .filter_level(config.log_level) - .format_level(true) - .format_target(false) - .format_module_path(false) - .init(); - - let (relay, req_tx) = relay::init( - config.client_config, - config.server_addr, - config.token_digest, - config.heartbeat_interval, - config.reduce_rtt, - config.udp_relay_mode, - config.request_timeout, - config.max_udp_relay_packet_size, - ) - .await; - - let socks5 = match socks5::init(config.local_addr, config.socks5_auth, req_tx).await { - Ok(socks5) => socks5, - Err(err) => { - eprintln!("{err}"); - return; - } - }; - - tokio::select! { - res = relay => res, - res = socks5 => res, - }; - - process::exit(1); -} diff --git a/client/src/relay/address.rs b/client/src/relay/address.rs deleted file mode 100644 index 44b7168..0000000 --- a/client/src/relay/address.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - net::SocketAddr, -}; -use tuic_protocol::Address as TuicAddress; - -pub enum Address { - DomainAddress(String, u16), - SocketAddress(SocketAddr), -} - -impl From for Address { - fn from(address: TuicAddress) -> Self { - match address { - TuicAddress::DomainAddress(hostname, port) => Self::DomainAddress(hostname, port), - TuicAddress::SocketAddress(socket_addr) => Self::SocketAddress(socket_addr), - } - } -} - -impl From
for TuicAddress { - fn from(address: Address) -> Self { - match address { - Address::DomainAddress(hostname, port) => Self::DomainAddress(hostname, port), - Address::SocketAddress(socket_addr) => Self::SocketAddress(socket_addr), - } - } -} - -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Address::DomainAddress(hostname, port) => write!(f, "{hostname}:{port}"), - Address::SocketAddress(socket_addr) => write!(f, "{socket_addr}"), - } - } -} diff --git a/client/src/relay/connection.rs b/client/src/relay/connection.rs deleted file mode 100644 index 91fcbb6..0000000 --- a/client/src/relay/connection.rs +++ /dev/null @@ -1,389 +0,0 @@ -use super::{ - incoming::{self, Sender as IncomingSender}, - request::Wait as WaitRequest, - stream::{BiStream, IncomingUniStreams, RecvStream, Register as StreamRegister, SendStream}, - Address, ServerAddr, UdpRelayMode, -}; -use bytes::Bytes; -use parking_lot::Mutex; -use quinn::{ClientConfig, Connection as QuinnConnection, Datagrams, Endpoint, NewConnection}; -use std::{ - collections::HashMap, - future::Future, - io::{Error, ErrorKind, Result}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, - ops::{Deref, DerefMut}, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, - time::Duration, -}; -use tokio::{ - net, - sync::{mpsc::Sender as MpscSender, Mutex as AsyncMutex, OwnedMutexGuard}, - time, -}; -use tuic_protocol::Command; - -pub async fn manage_connection( - config: ConnectionConfig, - conn: Arc>>, - lock: OwnedMutexGuard>, - mut next_incoming_tx: UdpRelayMode< - IncomingSender, - IncomingSender, - >, - wait_req: WaitRequest, -) { - let mut lock = Some(lock); - - loop { - // establish a new connection - let new_conn = loop { - // start the procedure only if there is a request waiting - wait_req.clone().await; - - // try to establish a new connection - let (new_conn, dg, uni) = match Connection::connect(&config).await { - Ok(conn) => conn, - Err(err) => { - log::error!("[relay] [connection] {err}"); - - // sleep 1 second to avoid drawing too much CPU - time::sleep(Duration::from_secs(1)).await; - - continue; - } - }; - - // renew the connection mutex - // safety: the mutex must be locked before, so this container must have a lock guard inside - let mut lock = lock.take().unwrap(); - *lock.deref_mut() = Some(new_conn.clone()); - - // send the incoming streams to `incoming::listen_incoming` - match next_incoming_tx { - UdpRelayMode::Native(incoming_tx) => { - let (tx, rx) = incoming::channel::(); - let _ = incoming_tx.send(new_conn.clone(), dg, rx); - next_incoming_tx = UdpRelayMode::Native(tx); - } - UdpRelayMode::Quic(incoming_tx) => { - let (tx, rx) = incoming::channel::(); - let _ = incoming_tx.send(new_conn.clone(), uni, rx); - next_incoming_tx = UdpRelayMode::Quic(tx); - } - } - - new_conn.update_max_udp_relay_packet_size(); - - // connection established, drop the lock implicitly - break new_conn; - }; - - log::debug!("[relay] [connection] [establish]"); - - // wait for the connection to be closed, lock the mutex - new_conn.wait_close().await; - - log::debug!("[relay] [connection] [disconnect]"); - lock = Some(conn.clone().lock_owned().await); - } -} - -#[derive(Clone)] -pub struct Connection { - controller: QuinnConnection, - udp_sessions: Arc, - stream_reg: Arc, - udp_relay_mode: UdpRelayMode<(), ()>, - is_closed: IsClosed, - default_max_udp_relay_packet_size: usize, -} - -impl Connection { - async fn connect(config: &ConnectionConfig) -> Result<(Self, Datagrams, IncomingUniStreams)> { - let (addrs, name) = match &config.server_addr { - ServerAddr::SocketAddr { addr, name } => Ok((vec![*addr], name)), - ServerAddr::DomainAddr { domain, port } => net::lookup_host((domain.as_str(), *port)) - .await - .map(|res| (res.collect(), domain)), - }?; - - let mut conn = None; - let mut last_err = None; - - for addr in addrs { - match Self::connect_addr(config, addr, name).await { - Ok(new_conn) => { - conn = Some(new_conn); - break; - } - Err(err) => last_err = Some(err), - } - } - - conn.ok_or_else(|| { - last_err - .unwrap_or_else(|| Error::new(ErrorKind::Other, "Unable to connect to the server")) - }) - } - - async fn connect_addr( - config: &ConnectionConfig, - addr: SocketAddr, - name: &str, - ) -> Result<(Self, Datagrams, IncomingUniStreams)> { - let bind_addr = match addr { - SocketAddr::V4(_) => SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)), - SocketAddr::V6(_) => SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)), - }; - - let conn = Endpoint::client(bind_addr)? - .connect_with(config.quinn_config.clone(), addr, name) - .map_err(|err| Error::new(ErrorKind::Other, err))?; - - let NewConnection { - connection, - datagrams, - uni_streams, - .. - } = if config.reduce_rtt { - match conn.into_0rtt() { - Ok((conn, _)) => conn, - Err(conn) => { - log::warn!("[relay] [connection] Unable to convert the connection into 0-RTT"); - conn.await? - } - } - } else { - conn.await? - }; - - let conn = Self::new(connection, config).await; - let uni_streams = IncomingUniStreams::new(uni_streams, conn.stream_reg.get_registry()); - - Ok((conn, datagrams, uni_streams)) - } - - async fn new(conn: QuinnConnection, config: &ConnectionConfig) -> Self { - let conn = Self { - controller: conn, - udp_sessions: Arc::new(UdpSessionMap::new()), - stream_reg: Arc::new(StreamRegister::new()), - udp_relay_mode: config.udp_relay_mode, - is_closed: IsClosed::new(), - default_max_udp_relay_packet_size: config.max_udp_relay_packet_size, - }; - - // send auth - tokio::spawn(Self::send_authentication(conn.clone(), config.token_digest)); - - // heartbeat - tokio::spawn(Self::heartbeat(conn.clone(), config.heartbeat_interval)); - - conn - } - - async fn send_authentication(self, token_digest: [u8; 32]) { - async fn send_token(conn: &Connection, token_digest: [u8; 32]) -> Result<()> { - let mut send = conn.get_send_stream().await?; - let cmd = Command::new_authenticate(token_digest); - cmd.write_to(&mut send).await?; - send.finish().await?; - Ok(()) - } - - match send_token(&self, token_digest).await { - Ok(()) => log::debug!("[relay] [connection] [authentication]"), - Err(err) => log::warn!("[relay] [connection] [authentication] {err}"), - } - } - - async fn heartbeat(self, heartbeat_interval: u64) { - async fn send_heartbeat(conn: &Connection) -> Result<()> { - let mut send = conn.get_send_stream().await?; - let cmd = Command::new_heartbeat(); - cmd.write_to(&mut send).await?; - send.finish().await?; - Ok(()) - } - - let mut interval = time::interval(Duration::from_millis(heartbeat_interval)); - - while tokio::select! { - () = self.wait_close() => false, - _ = interval.tick() => true, - } { - if !self.no_active_stream() || !self.no_active_udp_session() { - match send_heartbeat(&self).await { - Ok(()) => log::debug!("[relay] [connection] [heartbeat]"), - Err(err) => log::warn!("[relay] [connection] [heartbeat] {err}"), - } - } - } - } - - pub async fn get_send_stream(&self) -> Result { - let send = self.controller.open_uni().await?; - let reg = (*self.stream_reg).clone(); // clone inner, not itself - Ok(SendStream::new(send, reg)) - } - - pub async fn get_bi_stream(&self) -> Result { - let (send, recv) = self.controller.open_bi().await?; - let reg = (*self.stream_reg).clone(); // clone inner, not itself - - Ok(BiStream::new( - SendStream::new(send, reg.clone()), - RecvStream::new(recv, reg), - )) - } - - pub fn send_datagram(&self, data: Bytes) -> Result<()> { - self.controller - .send_datagram(data) - .map_err(|err| Error::new(ErrorKind::Other, err)) - } - - pub fn udp_sessions(&self) -> &UdpSessionMap { - self.udp_sessions.deref() - } - - pub fn udp_relay_mode(&self) -> UdpRelayMode<(), ()> { - self.udp_relay_mode - } - - pub fn update_max_udp_relay_packet_size(&self) { - let size = match self.udp_relay_mode { - UdpRelayMode::Native(()) => match self.controller.max_datagram_size() { - Some(size) => size, - None => { - log::warn!("[relay] [connection] Failed to detect the max datagram size"); - self.default_max_udp_relay_packet_size - } - }, - UdpRelayMode::Quic(()) => self.default_max_udp_relay_packet_size, - }; - - super::MAX_UDP_RELAY_PACKET_SIZE.store(size, Ordering::Release); - } - - fn no_active_stream(&self) -> bool { - self.stream_reg.count() == 1 - } - - fn no_active_udp_session(&self) -> bool { - self.udp_sessions.is_empty() - } - - pub fn set_closed(&self) { - self.is_closed.set() - } - - fn wait_close(&self) -> IsClosed { - self.is_closed.clone() - } -} - -pub struct ConnectionConfig { - quinn_config: ClientConfig, - server_addr: ServerAddr, - token_digest: [u8; 32], - udp_relay_mode: UdpRelayMode<(), ()>, - heartbeat_interval: u64, - reduce_rtt: bool, - max_udp_relay_packet_size: usize, -} - -impl ConnectionConfig { - pub fn new( - quinn_config: ClientConfig, - server_addr: ServerAddr, - token_digest: [u8; 32], - udp_relay_mode: UdpRelayMode<(), ()>, - heartbeat_interval: u64, - reduce_rtt: bool, - max_udp_relay_packet_size: usize, - ) -> Self { - Self { - quinn_config, - server_addr, - token_digest, - udp_relay_mode, - heartbeat_interval, - reduce_rtt, - max_udp_relay_packet_size, - } - } -} - -pub struct UdpSessionMap(Mutex>>); - -impl UdpSessionMap { - fn new() -> Self { - Self(Mutex::new(HashMap::new())) - } - - pub fn insert( - &self, - id: u32, - tx: MpscSender<(Bytes, Address)>, - ) -> Option> { - self.0.lock().insert(id, tx) - } - - pub fn get(&self, id: &u32) -> Option> { - self.0.lock().get(id).cloned() - } - - pub fn remove(&self, id: &u32) -> Option> { - self.0.lock().remove(id) - } - - fn is_empty(&self) -> bool { - self.0.lock().is_empty() - } -} - -#[derive(Clone)] -struct IsClosed(Arc); - -struct IsClosedInner { - is_closed: AtomicBool, - waker: Mutex>, -} - -impl IsClosed { - fn new() -> Self { - Self(Arc::new(IsClosedInner { - is_closed: AtomicBool::new(false), - // Needs at least 2 slots for `manage_connection()` and `heartbeat()` - waker: Mutex::new(Vec::with_capacity(2)), - })) - } - - fn set(&self) { - self.0.is_closed.store(true, Ordering::Release); - - for waker in self.0.waker.lock().drain(..) { - waker.wake(); - } - } -} - -impl Future for IsClosed { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.0.is_closed.load(Ordering::Acquire) { - Poll::Ready(()) - } else { - self.0.waker.lock().push(cx.waker().clone()); - Poll::Pending - } - } -} diff --git a/client/src/relay/incoming.rs b/client/src/relay/incoming.rs deleted file mode 100644 index 993475f..0000000 --- a/client/src/relay/incoming.rs +++ /dev/null @@ -1,157 +0,0 @@ -use super::{ - stream::{IncomingUniStreams, RecvStream}, - Address, Connection, UdpRelayMode, -}; -use bytes::Bytes; -use futures_util::StreamExt; -use quinn::{ConnectionError, Datagrams}; -use std::{ - io::{Error, ErrorKind, Result}, - result::Result as StdResult, -}; -use tokio::{ - io::AsyncReadExt, - sync::oneshot::{self, error::RecvError, Receiver as OneshotReceiver, Sender as OneshotSender}, -}; -use tuic_protocol::Command as TuicCommand; - -pub async fn listen_incoming( - mut next_incoming_rx: UdpRelayMode, Receiver>, -) { - loop { - let (conn, incoming); - (conn, incoming, next_incoming_rx) = match next_incoming_rx { - UdpRelayMode::Native(incoming_rx) => { - let (conn, incoming, next_incoming_rx) = incoming_rx.next().await.unwrap(); // safety: the channel must not be closed unless the whole program is already terminated - ( - conn, - UdpRelayMode::Native(incoming), - UdpRelayMode::Native(next_incoming_rx), - ) - } - UdpRelayMode::Quic(incoming_rx) => { - let (conn, incoming, next_incoming_rx) = incoming_rx.next().await.unwrap(); // safety: the channel must not be closed unless the whole program is already terminated - ( - conn, - UdpRelayMode::Quic(incoming), - UdpRelayMode::Quic(next_incoming_rx), - ) - } - }; - - let err = match incoming { - UdpRelayMode::Native(mut incoming) => loop { - let pkt = match incoming.next().await { - Some(Ok(pkt)) => pkt, - Some(Err(err)) => break err, - None => break ConnectionError::LocallyClosed, - }; - - // process datagram - tokio::spawn(conn.clone().process_incoming_datagram(pkt)); - }, - UdpRelayMode::Quic(mut uni) => loop { - let recv = match uni.next().await { - Some(Ok(recv)) => recv, - Some(Err(err)) => break err, - None => break ConnectionError::LocallyClosed, - }; - - // process uni stream - tokio::spawn(conn.clone().process_incoming_uni_stream(recv)); - }, - }; - - match err { - ConnectionError::LocallyClosed => log::debug!("[relay] [connection] Locally closed"), - ConnectionError::TimedOut => log::debug!("[relay] [connection] Timeout"), - err => log::error!("[relay] [connection] {err}"), - } - - conn.set_closed(); - } -} - -impl Connection { - async fn process_incoming_datagram(self, pkt: Bytes) { - async fn parse_header(pkt: Bytes) -> Result<(u32, Bytes, Address)> { - let cmd = TuicCommand::read_from(&mut pkt.as_ref()).await?; - let cmd_len = cmd.serialized_len(); - - match cmd { - TuicCommand::Packet { - assoc_id, - len, - addr, - } => Ok(( - assoc_id, - pkt.slice(cmd_len..cmd_len + len as usize), - Address::from(addr), - )), - _ => Err(Error::new( - ErrorKind::InvalidData, - "[relay] [connection] Unexpected incoming datagram", - )), - } - } - - match parse_header(pkt).await { - Ok((assoc_id, pkt, addr)) => self.handle_packet_from(assoc_id, pkt, addr).await, - Err(err) => log::warn!("[relay] [connection] {err}"), - } - } - - async fn process_incoming_uni_stream(self, recv: RecvStream) { - async fn parse_header(mut recv: RecvStream) -> Result<(u32, Bytes, Address)> { - let cmd = TuicCommand::read_from(&mut recv).await?; - - match cmd { - TuicCommand::Packet { - assoc_id, - len, - addr, - } => { - let mut buf = vec![0; len as usize]; - recv.read_exact(&mut buf).await?; - let pkt = Bytes::from(buf); - Ok((assoc_id, pkt, Address::from(addr))) - } - _ => Err(Error::new( - ErrorKind::InvalidData, - "[relay] [connection] Unexpected incoming uni stream", - )), - } - } - - match parse_header(recv).await { - Ok((assoc_id, pkt, addr)) => self.handle_packet_from(assoc_id, pkt, addr).await, - Err(err) => log::warn!("[relay] [connection] {err}"), - } - } -} - -pub fn channel() -> (Sender, Receiver) { - let (tx, rx) = oneshot::channel(); - (Sender(tx), Receiver(rx)) -} - -pub struct Sender(OneshotSender<(Connection, T, Receiver)>); - -impl Sender { - pub fn send( - self, - conn: Connection, - incoming: T, - next_incoming_rx: Receiver, - ) -> StdResult<(), (Connection, T, Receiver)> { - self.0.send((conn, incoming, next_incoming_rx)) - } -} - -pub struct Receiver(OneshotReceiver<(Connection, T, Self)>); - -impl Receiver { - async fn next(self) -> StdResult<(Connection, T, Self), RecvError> { - self.0.await - } -} diff --git a/client/src/relay/mod.rs b/client/src/relay/mod.rs deleted file mode 100644 index a0ea1d1..0000000 --- a/client/src/relay/mod.rs +++ /dev/null @@ -1,100 +0,0 @@ -use self::{connection::ConnectionConfig, stream::IncomingUniStreams}; -use quinn::{ClientConfig, Datagrams}; -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - future::Future, - net::SocketAddr, - sync::{atomic::AtomicUsize, Arc}, -}; -use tokio::sync::{ - mpsc::{self, Sender}, - Mutex as AsyncMutex, -}; - -pub use self::{address::Address, connection::Connection, request::Request}; - -mod address; -mod connection; -mod incoming; -mod request; -mod stream; -mod task; - -pub static MAX_UDP_RELAY_PACKET_SIZE: AtomicUsize = AtomicUsize::new(1500); - -#[allow(clippy::too_many_arguments)] -pub async fn init( - quinn_config: ClientConfig, - server_addr: ServerAddr, - token_digest: [u8; 32], - heartbeat_interval: u64, - reduce_rtt: bool, - udp_relay_mode: UdpRelayMode<(), ()>, - req_timeout: u64, - max_udp_relay_packet_size: usize, -) -> (impl Future, Sender) { - let (req_tx, req_rx) = mpsc::channel(1); - - let config = ConnectionConfig::new( - quinn_config, - server_addr.clone(), - token_digest, - udp_relay_mode, - heartbeat_interval, - reduce_rtt, - max_udp_relay_packet_size, - ); - - let conn = Arc::new(AsyncMutex::new(None)); - let conn_lock = conn.clone().lock_owned().await; - - let (incoming_tx, incoming_rx) = match udp_relay_mode { - UdpRelayMode::Native(()) => { - let (tx, rx) = incoming::channel::(); - (UdpRelayMode::Native(tx), UdpRelayMode::Native(rx)) - } - UdpRelayMode::Quic(()) => { - let (tx, rx) = incoming::channel::(); - (UdpRelayMode::Quic(tx), UdpRelayMode::Quic(rx)) - } - }; - - let (listen_requests, wait_req) = request::listen_requests(conn.clone(), req_rx, req_timeout); - let listen_incoming = incoming::listen_incoming(incoming_rx); - - let manage_connection = - connection::manage_connection(config, conn, conn_lock, incoming_tx, wait_req); - - let task = async move { - log::info!("[relay] Started. Target server: {server_addr}"); - - tokio::select! { - _ = tokio::spawn(manage_connection) => {} - _ = tokio::spawn(listen_requests) => {} - _ = tokio::spawn(listen_incoming) => {} - } - }; - - (task, req_tx) -} - -#[derive(Clone)] -pub enum ServerAddr { - SocketAddr { addr: SocketAddr, name: String }, - DomainAddr { domain: String, port: u16 }, -} - -impl Display for ServerAddr { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - ServerAddr::SocketAddr { addr, name } => write!(f, "{addr} ({name})"), - ServerAddr::DomainAddr { domain, port } => write!(f, "{domain}:{port}"), - } - } -} - -#[derive(Clone, Copy)] -pub enum UdpRelayMode { - Native(N), - Quic(Q), -} diff --git a/client/src/relay/request.rs b/client/src/relay/request.rs deleted file mode 100644 index 030a78a..0000000 --- a/client/src/relay/request.rs +++ /dev/null @@ -1,177 +0,0 @@ -use super::{stream::BiStream, Address, Connection}; -use bytes::Bytes; -use once_cell::sync::Lazy; -use parking_lot::Mutex; -use rand::{rngs::StdRng, RngCore, SeedableRng}; -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - future::Future, - pin::Pin, - sync::{Arc, Weak}, - task::{Context, Poll, Waker}, - time::Duration, -}; -use tokio::{ - sync::{ - mpsc::{self, Receiver as MpscReceiver, Sender as MpscSender}, - oneshot::{self, Receiver as OneshotReceiver, Sender as OneshotSender}, - Mutex as AsyncMutex, - }, - time, -}; - -pub fn listen_requests( - conn: Arc>>, - mut req_rx: MpscReceiver, - timeout: u64, -) -> (impl Future, Wait) { - let (reg, count) = Register::new(); - - let listen = async move { - while let Some(req) = req_rx.recv().await { - tokio::spawn(process_request(conn.clone(), req, timeout, reg.clone())); - } - }; - - (listen, count) -} - -async fn process_request( - conn: Arc>>, - req: Request, - timeout: u64, - _reg: Register, -) { - log::info!("[relay] [task] {req}"); - - // try to get the current connection - if let Ok(lock) = time::timeout(Duration::from_millis(timeout), conn.lock()).await { - let conn = lock.as_ref().unwrap().clone(); // safety: there must be a connection if the lock is aquirable - drop(lock); - - match req { - Request::Connect { addr, tx } => conn.clone().handle_connect(addr, tx).await, - Request::Associate { - assoc_id, - mut pkt_send_rx, - pkt_recv_tx, - } => { - conn.udp_sessions().insert(assoc_id, pkt_recv_tx); - while let Some((pkt, addr)) = pkt_send_rx.recv().await { - tokio::spawn(conn.clone().handle_packet_to( - assoc_id, - pkt, - addr, - conn.udp_relay_mode(), - )); - } - - log::info!("[relay] [task] [dissociate] [{assoc_id}]"); - conn.clone().udp_sessions().remove(&assoc_id); - conn.handle_dissociate(assoc_id).await; - } - } - } else { - log::warn!("[relay] [task] {req} [timeout]"); - } -} - -pub enum Request { - Connect { - addr: Address, - tx: ConnectResponseSender, - }, - Associate { - assoc_id: u32, - pkt_send_rx: AssociateSendPacketReceiver, - pkt_recv_tx: AssociateRecvPacketSender, - }, -} - -type ConnectResponseSender = OneshotSender; -type ConnectResponseReceiver = OneshotReceiver; -type AssociateSendPacketSender = MpscSender<(Bytes, Address)>; -type AssociateSendPacketReceiver = MpscReceiver<(Bytes, Address)>; -type AssociateRecvPacketSender = MpscSender<(Bytes, Address)>; -type AssociateRecvPacketReceiver = MpscReceiver<(Bytes, Address)>; - -impl Request { - pub fn new_connect(addr: Address) -> (Self, ConnectResponseReceiver) { - let (tx, rx) = oneshot::channel(); - (Request::Connect { addr, tx }, rx) - } - - pub fn new_associate() -> (Self, AssociateSendPacketSender, AssociateRecvPacketReceiver) { - let assoc_id = get_random_u32(); - let (pkt_send_tx, pkt_send_rx) = mpsc::channel(1); - let (pkt_recv_tx, pkt_recv_rx) = mpsc::channel(1); - - ( - Self::Associate { - assoc_id, - pkt_send_rx, - pkt_recv_tx, - }, - pkt_send_tx, - pkt_recv_rx, - ) - } -} - -impl Display for Request { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Request::Connect { addr, .. } => write!(f, "[connect] [{addr}]"), - Request::Associate { assoc_id, .. } => write!(f, "[associate] [{assoc_id}]"), - } - } -} - -static RNG: Lazy> = Lazy::new(|| Mutex::new(StdRng::from_entropy())); - -fn get_random_u32() -> u32 { - RNG.lock().next_u32() -} - -pub struct Register(Arc>>); - -impl Register { - pub fn new() -> (Self, Wait) { - let reg = Self(Arc::new(Mutex::new(None))); - let count = Wait(Arc::downgrade(®.0)); - (reg, count) - } -} - -impl Clone for Register { - fn clone(&self) -> Self { - let reg = Self(self.0.clone()); - - // wake the `Wait` hold by `guard_connection` - if let Some(waker) = self.0.lock().take() { - waker.wake(); - } - - reg - } -} - -#[derive(Clone)] -pub struct Wait(Weak>>); - -impl Future for Wait { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.0.strong_count() > 1 { - // there is a request waiting - Poll::Ready(()) - } else { - // there is no request waiting, pend the task - if let Some(reg) = self.0.upgrade() { - *reg.lock() = Some(cx.waker().clone()); - } - Poll::Pending - } - } -} diff --git a/client/src/relay/stream.rs b/client/src/relay/stream.rs deleted file mode 100644 index 535c71c..0000000 --- a/client/src/relay/stream.rs +++ /dev/null @@ -1,208 +0,0 @@ -use futures_util::Stream; -use quinn::{ - ConnectionError, IncomingUniStreams as QuinnIncomingUniStreams, RecvStream as QuinnRecvStream, - SendStream as QuinnSendStream, -}; -use std::{ - io::{Error, IoSlice, Result}, - pin::Pin, - result::Result as StdResult, - sync::{Arc, Weak}, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -pub struct SendStream { - send: QuinnSendStream, - _reg: Register, -} - -impl SendStream { - #[inline] - pub fn new(send: QuinnSendStream, reg: Register) -> Self { - Self { send, _reg: reg } - } - - #[inline] - pub async fn finish(&mut self) -> Result<()> { - self.send.finish().await.map_err(Error::from) - } -} - -impl AsyncWrite for SendStream { - #[inline] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.send).poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.send).poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.send.is_write_vectored() - } - - #[inline] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send).poll_shutdown(cx) - } -} - -pub struct RecvStream { - recv: QuinnRecvStream, - _reg: Register, -} - -impl RecvStream { - #[inline] - pub fn new(recv: QuinnRecvStream, reg: Register) -> Self { - Self { recv, _reg: reg } - } -} - -impl AsyncRead for RecvStream { - #[inline] - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.recv).poll_read(cx, buf) - } -} - -pub struct BiStream { - send: SendStream, - recv: RecvStream, -} - -impl BiStream { - #[inline] - pub fn new(send: SendStream, recv: RecvStream) -> Self { - Self { send, recv } - } - - #[inline] - pub async fn finish(&mut self) -> Result<()> { - self.send.finish().await - } -} - -impl AsyncRead for BiStream { - #[inline] - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.recv).poll_read(cx, buf) - } -} - -impl AsyncWrite for BiStream { - #[inline] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.send).poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.send).poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.send.is_write_vectored() - } - - #[inline] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send).poll_shutdown(cx) - } -} - -pub struct IncomingUniStreams { - incoming: QuinnIncomingUniStreams, - reg: Registry, -} - -impl IncomingUniStreams { - #[inline] - pub fn new(incoming: QuinnIncomingUniStreams, reg: Registry) -> Self { - Self { incoming, reg } - } -} - -impl Stream for IncomingUniStreams { - type Item = StdResult; - - #[inline] - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - if let Some(reg) = self.reg.get_register() { - Pin::new(&mut self.incoming) - .poll_next(cx) - .map_ok(|recv| RecvStream::new(recv, reg)) - } else { - // the connection is already dropped - Poll::Ready(None) - } - } -} - -#[derive(Clone)] -pub struct Register(Arc<()>); - -impl Register { - #[inline] - pub fn new() -> Self { - Self(Arc::new(())) - } - - #[inline] - pub fn get_registry(&self) -> Registry { - Registry(Arc::downgrade(&self.0)) - } - - #[inline] - pub fn count(&self) -> usize { - Arc::strong_count(&self.0) - } -} - -pub struct Registry(Weak<()>); - -impl Registry { - #[inline] - pub fn get_register(&self) -> Option { - self.0.upgrade().map(Register) - } -} diff --git a/client/src/relay/task.rs b/client/src/relay/task.rs deleted file mode 100644 index 2e95d1d..0000000 --- a/client/src/relay/task.rs +++ /dev/null @@ -1,121 +0,0 @@ -use super::{stream::BiStream, Address, Connection, UdpRelayMode}; -use bytes::{Bytes, BytesMut}; -use std::io::Result; -use tokio::{io::AsyncWriteExt, sync::oneshot::Sender as OneshotSender}; -use tuic_protocol::{Address as TuicAddress, Command as TuicCommand}; - -impl Connection { - pub async fn handle_connect(self, addr: Address, tx: OneshotSender) { - async fn negotiate_connect(conn: Connection, addr: Address) -> Result> { - let cmd = TuicCommand::new_connect(TuicAddress::from(addr)); - - let mut stream = conn.get_bi_stream().await?; - cmd.write_to(&mut stream).await?; - - let resp = match TuicCommand::read_from(&mut stream).await { - Ok(resp) => resp, - Err(err) => { - stream.finish().await?; - return Err(err); - } - }; - - if let TuicCommand::Response(true) = resp { - Ok(Some(stream)) - } else { - stream.finish().await?; - Ok(None) - } - } - - let display_addr = format!("{addr}"); - - match negotiate_connect(self, addr).await { - Ok(Some(stream)) => { - log::debug!("[relay] [task] [connect] [{display_addr}] [success]"); - let _ = tx.send(stream); - } - Ok(None) => log::debug!("[relay] [task] [connect] [{display_addr}] [fail]"), - Err(err) => log::warn!("[relay] [task] [connect] [{display_addr}] {err}"), - } - } - - pub async fn handle_packet_to( - self, - assoc_id: u32, - pkt: Bytes, - addr: Address, - mode: UdpRelayMode<(), ()>, - ) { - async fn send_packet( - conn: Connection, - assoc_id: u32, - pkt: Bytes, - addr: Address, - mode: UdpRelayMode<(), ()>, - ) -> Result<()> { - let cmd = TuicCommand::new_packet(assoc_id, pkt.len() as u16, TuicAddress::from(addr)); - - match mode { - UdpRelayMode::Native(()) => { - let mut buf = BytesMut::with_capacity(cmd.serialized_len()); - cmd.write_to_buf(&mut buf); - buf.extend_from_slice(&pkt); - let pkt = buf.freeze(); - conn.send_datagram(pkt)?; - } - UdpRelayMode::Quic(()) => { - let mut send = conn.get_send_stream().await?; - cmd.write_to(&mut send).await?; - send.write_all(&pkt).await?; - send.finish().await?; - } - } - - Ok(()) - } - - self.update_max_udp_relay_packet_size(); - let display_addr = format!("{addr}"); - - match send_packet(self, assoc_id, pkt, addr, mode).await { - Ok(()) => log::debug!( - "[relay] [task] [associate] [{assoc_id}] [send] [{display_addr}] [success]" - ), - Err(err) => { - log::warn!("[relay] [task] [associate] [{assoc_id}] [send] [{display_addr}] {err}") - } - } - } - - pub async fn handle_packet_from(self, assoc_id: u32, pkt: Bytes, addr: Address) { - self.update_max_udp_relay_packet_size(); - let display_addr = format!("{addr}"); - - if let Some(recv_pkt_tx) = self.udp_sessions().get(&assoc_id) { - log::debug!( - "[relay] [task] [associate] [{assoc_id}] [recv] [{display_addr}] [success]" - ); - let _ = recv_pkt_tx.send((pkt, addr)).await; - } else { - log::warn!("[relay] [task] [associate] [{assoc_id}] [recv] [{display_addr}] No corresponding UDP relay session found"); - } - } - - pub async fn handle_dissociate(self, assoc_id: u32) { - async fn send_dissociate(conn: Connection, assoc_id: u32) -> Result<()> { - let cmd = TuicCommand::new_dissociate(assoc_id); - - let mut send = conn.get_send_stream().await?; - cmd.write_to(&mut send).await?; - send.finish().await?; - - Ok(()) - } - - match send_dissociate(self, assoc_id).await { - Ok(()) => log::debug!("[relay] [task] [dissociate] [{assoc_id}] [success]"), - Err(err) => log::warn!("relay] [task] [dissociate] [{assoc_id}] {err}"), - } - } -} diff --git a/client/src/socks5/associate.rs b/client/src/socks5/associate.rs deleted file mode 100644 index 13f1385..0000000 --- a/client/src/socks5/associate.rs +++ /dev/null @@ -1,140 +0,0 @@ -use crate::relay::{self, Address as RelayAddress, Request as RelayRequest}; -use bytes::Bytes; -use socks5_proto::{Address, Reply, UdpHeader}; -use socks5_server::{ - connection::associate::{AssociatedUdpSocket, NeedReply}, - Associate, -}; -use std::{ - io::Result, - net::SocketAddr, - sync::{atomic::Ordering, Arc}, -}; -use tokio::{ - net::UdpSocket, - sync::mpsc::{Receiver, Sender}, -}; -use tuic_protocol::Command as TuicCommand; - -pub async fn handle( - conn: Associate, - req_tx: Sender, - target_addr: Address, -) -> Result<()> { - async fn bind_udp_socket(conn: &Associate) -> Result { - UdpSocket::bind(SocketAddr::from((conn.local_addr()?.ip(), 0))).await - } - - log::info!( - "[socks5] [{}] [associate] [{target_addr}]", - conn.peer_addr()? - ); - - match bind_udp_socket(&conn) - .await - .and_then(|socket| socket.local_addr().map(|addr| (socket, addr))) - { - Ok((socket, socket_addr)) => { - let (relay_req, pkt_send_tx, pkt_recv_rx) = RelayRequest::new_associate(); - let _ = req_tx.send(relay_req).await; - - let mut conn = conn - .reply(Reply::Succeeded, Address::SocketAddress(socket_addr)) - .await?; - - let buf_size = relay::MAX_UDP_RELAY_PACKET_SIZE.load(Ordering::Acquire) - - (TuicCommand::max_serialized_len() - UdpHeader::max_serialized_len()); - let socket = Arc::new(AssociatedUdpSocket::from((socket, buf_size))); - let ctrl_addr = conn.peer_addr()?; - - let res = tokio::select! { - _ = conn.wait_until_closed() => Ok(()), - res = socks5_to_relay(socket.clone(),ctrl_addr, pkt_send_tx) => res, - res = relay_to_socks5(socket,ctrl_addr, pkt_recv_rx) => res, - }; - - let _ = conn.shutdown().await; - - log::info!("[socks5] [{ctrl_addr}] [dissociate] [{target_addr}]"); - - res - } - Err(err) => { - let mut conn = conn - .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - - let _ = conn.shutdown().await; - Err(err) - } - } -} - -async fn socks5_to_relay( - socket: Arc, - ctrl_addr: SocketAddr, - pkt_send_tx: Sender<(Bytes, RelayAddress)>, -) -> Result<()> { - loop { - let buf_size = relay::MAX_UDP_RELAY_PACKET_SIZE.load(Ordering::Acquire) - - (TuicCommand::max_serialized_len() - UdpHeader::max_serialized_len()); - socket.set_max_packet_size(buf_size); - - let (pkt, frag, dst_addr, src_addr) = socket.recv_from().await?; - - if frag == 0 { - log::debug!("[socks5] [{ctrl_addr}] [associate] [packet-to] {dst_addr}"); - - let dst_addr = match dst_addr { - Address::DomainAddress(domain, port) => RelayAddress::DomainAddress(domain, port), - Address::SocketAddress(addr) => RelayAddress::SocketAddress(addr), - }; - - let _ = pkt_send_tx.send((pkt, dst_addr)).await; - socket.connect(src_addr).await?; - break; - } else { - log::warn!("[socks5] [{ctrl_addr}] [associate] [packet-to] socks5 UDP packet fragment is not supported"); - } - } - - loop { - let buf_size = relay::MAX_UDP_RELAY_PACKET_SIZE.load(Ordering::Acquire) - - (TuicCommand::max_serialized_len() - UdpHeader::max_serialized_len()); - socket.set_max_packet_size(buf_size); - - let (pkt, frag, dst_addr) = socket.recv().await?; - - if frag == 0 { - log::debug!("[socks5] [{ctrl_addr}] [associate] [packet-to] {dst_addr}"); - - let dst_addr = match dst_addr { - Address::DomainAddress(domain, port) => RelayAddress::DomainAddress(domain, port), - Address::SocketAddress(addr) => RelayAddress::SocketAddress(addr), - }; - - let _ = pkt_send_tx.send((pkt, dst_addr)).await; - } else { - log::warn!("[socks5] [{ctrl_addr}] [associate] [packet-to] socks5 UDP packet fragment is not supported"); - } - } -} - -async fn relay_to_socks5( - socket: Arc, - ctrl_addr: SocketAddr, - mut pkt_recv_rx: Receiver<(Bytes, RelayAddress)>, -) -> Result<()> { - while let Some((pkt, src_addr)) = pkt_recv_rx.recv().await { - log::debug!("[socks5] [{ctrl_addr}] [associate] [packet-from] {src_addr}"); - - let src_addr = match src_addr { - RelayAddress::DomainAddress(domain, port) => Address::DomainAddress(domain, port), - RelayAddress::SocketAddress(addr) => Address::SocketAddress(addr), - }; - - socket.send(pkt, 0, src_addr).await?; - } - - Ok(()) -} diff --git a/client/src/socks5/bind.rs b/client/src/socks5/bind.rs deleted file mode 100644 index 43b5f38..0000000 --- a/client/src/socks5/bind.rs +++ /dev/null @@ -1,24 +0,0 @@ -use crate::relay::Request as RelayRequest; -use socks5_proto::{Address, Reply}; -use socks5_server::{connection::bind::NeedFirstReply, Bind}; -use std::io::{Error, ErrorKind, Result}; -use tokio::sync::mpsc::Sender; - -pub async fn handle( - conn: Bind, - _req_tx: Sender, - target_addr: Address, -) -> Result<()> { - log::info!("[socks5] [{}] [bind] [{target_addr}]", conn.peer_addr()?); - - let mut conn = conn - .reply(Reply::CommandNotSupported, Address::unspecified()) - .await?; - - let _ = conn.shutdown().await; - - Err(Error::new( - ErrorKind::Unsupported, - "BIND command is not supported", - )) -} diff --git a/client/src/socks5/connect.rs b/client/src/socks5/connect.rs deleted file mode 100644 index 91ef64e..0000000 --- a/client/src/socks5/connect.rs +++ /dev/null @@ -1,34 +0,0 @@ -use crate::relay::{Address as RelayAddress, Request as RelayRequest}; -use socks5_proto::{Address, Reply}; -use socks5_server::{connection::connect::NeedReply, Connect}; -use std::io::Result; -use tokio::{io, sync::mpsc::Sender}; - -pub async fn handle( - conn: Connect, - req_tx: Sender, - target_addr: Address, -) -> Result<()> { - log::info!("[socks5] [{}] [connect] [{target_addr}]", conn.peer_addr()?); - - let target_addr = match target_addr { - Address::DomainAddress(domain, port) => RelayAddress::DomainAddress(domain, port), - Address::SocketAddress(addr) => RelayAddress::SocketAddress(addr), - }; - - let (relay_req, relay_resp_rx) = RelayRequest::new_connect(target_addr); - let _ = req_tx.send(relay_req).await; - - if let Ok(mut relay) = relay_resp_rx.await { - let mut conn = conn.reply(Reply::Succeeded, Address::unspecified()).await?; - io::copy_bidirectional(&mut conn, &mut relay).await?; - } else { - let mut conn = conn - .reply(Reply::NetworkUnreachable, Address::unspecified()) - .await?; - - let _ = conn.shutdown().await; - } - - Ok(()) -} diff --git a/client/src/socks5/mod.rs b/client/src/socks5/mod.rs deleted file mode 100644 index b7ec5e6..0000000 --- a/client/src/socks5/mod.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::relay::Request as RelayRequest; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; -use socks5_server::{Auth, Connection, IncomingConnection, Server}; -use std::{ - future::Future, - io::Result, - net::{SocketAddr, TcpListener as StdTcpListener}, - sync::Arc, -}; -use tokio::{net::TcpListener, sync::mpsc::Sender}; - -mod associate; -mod bind; -mod connect; - -pub async fn init( - local_addr: SocketAddr, - auth: Arc, - req_tx: Sender, -) -> Result> { - let socks5 = Socks5::init(local_addr, auth, req_tx).await?; - Ok(socks5.run()) -} - -struct Socks5 { - server: Server, - req_tx: Sender, -} - -impl Socks5 { - async fn init( - local_addr: SocketAddr, - auth: Arc, - req_tx: Sender, - ) -> Result { - let listener = if local_addr.is_ipv4() { - TcpListener::bind(local_addr).await? - } else { - let socket = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?; - socket.set_only_v6(false)?; - socket.bind(&SockAddr::from(local_addr))?; - socket.listen(128)?; - TcpListener::from_std(StdTcpListener::from(socket))? - }; - - let server = Server::new(listener, auth); - - Ok(Self { server, req_tx }) - } - - async fn run(self) { - async fn handle_connection( - conn: IncomingConnection, - req_tx: Sender, - ) -> Result<()> { - match conn.handshake().await? { - Connection::Connect(conn, addr) => connect::handle(conn, req_tx, addr).await, - Connection::Bind(conn, addr) => bind::handle(conn, req_tx, addr).await, - Connection::Associate(conn, addr) => associate::handle(conn, req_tx, addr).await, - } - } - - match self.server.local_addr() { - Ok(addr) => log::info!("[socks5] Started. Listening: {addr}"), - Err(err) => { - log::error!("[socks5] Failed to get local socks5 server address: {err}"); - return; - } - } - - loop { - let (conn, addr) = match self.server.accept().await { - Ok((conn, addr)) => { - log::debug!("[socks5] [{addr}] [establish]"); - (conn, addr) - } - Err(err) => { - log::warn!("[socks5] Failed to accept connection: {err}"); - continue; - } - }; - - let req_tx = self.req_tx.clone(); - - tokio::spawn(async move { - match handle_connection(conn, req_tx).await { - Ok(()) => log::debug!("[socks5] [{addr}] [disconnect]"), - Err(err) => log::warn!("[socks5] [{addr}] {err}"), - } - }); - } - } -} diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml deleted file mode 100644 index b0ad6c0..0000000 --- a/protocol/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "tuic-protocol" -version = "4.1.2" -authors = ["EAimTY "] -description = "" -categories = ["network-programming"] -keywords = ["tuic", "proxy", "quic"] -edition = "2021" -rust-version = "1.59" -readme = "../README.md" -license = "GPL-3.0-or-later" -repository = "https://github.com/EAimTY/tuic" - -[dependencies] -byteorder = "1.4.*" -bytes = "1.2.*" -tokio = { version = "1.20.*", features = ["io-util"] } diff --git a/protocol/README.md b/protocol/README.md deleted file mode 100644 index cd67fbf..0000000 --- a/protocol/README.md +++ /dev/null @@ -1,202 +0,0 @@ -# tuic-protocol - -TUIC protocol is used to communicate between the TUIC client and the TUIC server. - -## Overview - -TUIC protocol is a stateful protocol. It is designed to be simple yet efficient. The current version is `0x04`. - -## Command - -Relay tasks are negotiated with `Command`s. -All fields are in Big Endian unless otherwise noted. - -```plain -+-----+------+----------+ -| VER | TYPE | OPT | -+-----+------+----------+ -| 1 | 1 | Variable | -+-----+------+----------+ -``` - -where: - -- `VER` - protocol version -- `TYPE` - command type -- `OPT` - command type specific data - -### Command Types - -There are six types of commands: - -- `0x00` - `Authenticate` - used to authenticate the client -- `0x01` - `Connect` - used to request a client-to-server TCP relay -- `0x02` - `Packet` - used to forward a UDP packet -- `0x03` - `Dissociate` - used to stop a UDP relay session -- `0x04` - `Heartbeat` - used to keep a QUIC connection alive -- `0xff` - `Response` - used to respond to a `Command` (currently only used for replying `Connect`) - -### Command Type Specific Data - -#### `Authenticate` - -```plain -+-----+ -| TKN | -+-----+ -| 32 | -+-----+ -``` - -where: - -- `TKN` - authentication token, hashed with [BLAKE3](https://github.com/BLAKE3-team/BLAKE3) - -#### `Connect` - -```plain -+----------+ -| ADDR | -+----------+ -| Variable | -+----------+ -``` - -where: - -- `ADDR` - target address. See [Address](#address) - -#### `Packet` - -```plain -+----------+-----+----------+ -| ASSOC_ID | LEN | ADDR | -+----------+-----+----------+ -| 4 | 2 | Variable | -+----------+-----+----------+ -``` - -where: - -- `ASSOC_ID` - UDP relay session ID. See [UDP relaying](#udp-relaying) -- `LEN` - length of the UDP packet -- `ADDR` - target (command from TUIC client) or source (command from TUIC server) address. See [Address](#address) - -#### `Dissociate` - -```plain -+----------+ -| ASSOC_ID | -+----------+ -| 4 | -+----------+ -``` - -#### `Heartbeat` - -```plain -+-+ -| | -+-+ -| | -+-+ -``` - -#### `Response` - -```plain -+-----+ -| REP | -+-----+ -| 1 | -+-----+ -``` - -where: - -- `REP` - reply code, which can be: - -- `0x00` - SUCCEEDED -- `0xff` - FAILED - -### Address - -```plain -+------+----------+----------+ -| TYPE | ADDR | PORT | -+------+----------+----------+ -| 1 | Variable | 2 | -+------+----------+----------+ -``` - -where: - -- `TYPE` - the address type -- `ADDR` - the address -- `PORT` - the port - -The address type can be one of the following: - -- `0x00` - fully-qualified domain name(the first byte indicates the length of the domain name) -- `0x01` - IPv4 address -- `0x02` - IPv6 address - -## Procedures - -TUIC protocol relies heavily on the multiplex-able trusted channel provided by QUIC. The protocol itself does not provide any security. - -### Authentication - -Once the QUIC connection is established between the server and the client, the client must immediately open a unidirectional stream and send an `Authenticate` command. - -If the authentication token is unmatched, or the server does not receive an authentication request from the client within the set time, the server will close the QUIC connection with specific error code and reason. See [Error Handling](#error-handling) for more details. - -Note that the server will not reply to the `Authenticate` command. The client should close the stream immediately after successfully sending the command. The client can start sending other data without waiting for the `Authenticate` command to be sent. - -The server will accept other streams carrying relay task requests before the authentication is completed, but it will stop after the Command Header is read, and will not do actual processing until the authentication is completed. - -### TCP Relaying - -`Connect` is used to request a client-to-server TCP relay. - -To establish a TCP connection with the target address via the relay server, the client needs to open a bidirectional stream and send a `Connect` command. After the server receives the request, it will try to establish a TCP connection to the target address. Depending on success, the server replies with a `Response` command via the same bidirectional stream. - -If the attempt to connect to the target address fails, the server must close the bidirectional stream as soon as the `Response` transmission is complete. - -If the connection to the target is successful, the server will synchronize the data in the bidirectional stream with the TCP stream between the server and the target address until one of the streams is disconnected. - -### UDP Relaying - -TUIC achieves 0-RTT FullCone UDP forwarding by synchronizing UDP session ID between the client and the server. - -The server should create a UDP session table for each QUIC connection, mapping every associate ID to a UDP socket. - -The associate ID is a 32-bit unsigned integer randomly generated by the client, which is placed in the `Packet` command and appended to the UDP packet data to be sent. When the client wants to send UDP packets using the same UDP socket of the server, the attached associate ID should be the same. - -When the server receives the `Packet` command, it should check whether the attached associate ID is already associated with a UDP socket. If not, the server should allocate a UDP socket for the associate ID. The server will use this UDP socket to send UDP packets requested by the client, and accepting UDP packets from any destination at the same time, appends the `Packet` command then sends back to the client. - -When the client wants to relay a UDP packet, it should send the UDP packet with the `Packet` command attached from: - -- Unidirectional stream (UDP relay mode `quic`) -- Datagram (UDP relay mode `native`) - -When the server receives the first `Packet` command, it will consider that the client is using corresponded UDP relay mode. When the UDP socket associated receives a UDP packet, the server should send the packet back to the client in the same way. - -When a client wants to stop associating a UDP socket, it should notify the server by sending a `Dissociate` command using a unidirectional stream. The server will remove the associate ID and release the UDP socket from the UDP session table. - -When the QUIC connection is disconnected, the server will release all UDP sockets in the connection's UDP session table and delete all sessions. - -### Heartbeat - -Even if there is an unclosed stream between the server and the client, the QUIC connection will still timeout after a period of idle time. This affects the timeout behavior for tasks without persistent data transfer (such as SSH connections). - -To solve this problem, when there is an active relay task (TCP relaying or UDP session), the client should send a `Heartbeat` command to the server every few seconds to keep the connection alive. - -### Error Handling - -When the server detects the following errors, it should close the QUIC connection immediately with the corresponding error code: - -- Protocol Error - `0xfffffff0` - TUIC protocol version mismatch, or the server cannot parse the header -- Authentication Failed - `0xfffffff1` - Authentication token mismatch -- Authentication Timeout - `0xfffffff2` - Authentication timeout -- Bad Command - `0xfffffff3` - Command received from wrong stream / datagram diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs deleted file mode 100644 index 90a05db..0000000 --- a/protocol/src/lib.rs +++ /dev/null @@ -1,355 +0,0 @@ -//! The TUIC protocol - -use byteorder::{BigEndian, ReadBytesExt}; -use bytes::BufMut; -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - io::{Cursor, Error, ErrorKind, Result}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, -}; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -pub const TUIC_PROTOCOL_VERSION: u8 = 0x04; - -/// Command -/// -/// ```plain -/// +-----+------+----------+ -/// | VER | TYPE | OPT | -/// +-----+------+----------+ -/// | 1 | 1 | Variable | -/// +-----+------+----------+ -/// ``` -#[non_exhaustive] -#[derive(Clone)] -pub enum Command { - Response(bool), - Authenticate { - digest: [u8; 32], - }, - Connect { - addr: Address, - }, - Packet { - assoc_id: u32, - len: u16, - addr: Address, - }, - Dissociate { - assoc_id: u32, - }, - Heartbeat, -} - -impl Command { - const TYPE_RESPONSE: u8 = 0xff; - const TYPE_AUTHENTICATE: u8 = 0x00; - const TYPE_CONNECT: u8 = 0x01; - const TYPE_PACKET: u8 = 0x02; - const TYPE_DISSOCIATE: u8 = 0x03; - const TYPE_HEARTBEAT: u8 = 0x04; - - const RESPONSE_SUCCEEDED: u8 = 0x00; - const RESPONSE_FAILED: u8 = 0xff; - - pub fn new_response(is_succeeded: bool) -> Self { - Self::Response(is_succeeded) - } - - pub fn new_authenticate(digest: [u8; 32]) -> Self { - Self::Authenticate { digest } - } - - pub fn new_connect(addr: Address) -> Self { - Self::Connect { addr } - } - - pub fn new_packet(assoc_id: u32, len: u16, addr: Address) -> Self { - Self::Packet { - assoc_id, - len, - addr, - } - } - - pub fn new_dissociate(assoc_id: u32) -> Self { - Self::Dissociate { assoc_id } - } - - pub fn new_heartbeat() -> Self { - Self::Heartbeat - } - - pub async fn read_from(r: &mut R) -> Result - where - R: AsyncRead + Unpin, - { - let ver = r.read_u8().await?; - - if ver != TUIC_PROTOCOL_VERSION { - return Err(Error::new( - ErrorKind::Unsupported, - format!("Unsupported TUIC version: {ver}"), - )); - } - - let cmd = r.read_u8().await?; - match cmd { - Self::TYPE_RESPONSE => { - let resp = r.read_u8().await?; - match resp { - Self::RESPONSE_SUCCEEDED => Ok(Self::new_response(true)), - Self::RESPONSE_FAILED => Ok(Self::new_response(false)), - _ => Err(Error::new( - ErrorKind::InvalidInput, - format!("Invalid response code: {resp}"), - )), - } - } - Self::TYPE_AUTHENTICATE => { - let mut digest = [0; 32]; - r.read_exact(&mut digest).await?; - Ok(Self::new_authenticate(digest)) - } - Self::TYPE_CONNECT => { - let addr = Address::read_from(r).await?; - Ok(Self::new_connect(addr)) - } - Self::TYPE_PACKET => { - let mut buf = [0; 6]; - r.read_exact(&mut buf).await?; - let mut rdr = Cursor::new(buf); - - let assoc_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); - let len = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - let addr = Address::read_from(r).await?; - - Ok(Self::new_packet(assoc_id, len, addr)) - } - Self::TYPE_DISSOCIATE => { - let assoc_id = r.read_u32().await?; - Ok(Self::new_dissociate(assoc_id)) - } - Self::TYPE_HEARTBEAT => Ok(Self::new_heartbeat()), - _ => Err(Error::new( - ErrorKind::InvalidInput, - format!("Invalid command: {cmd}"), - )), - } - } - - pub async fn write_to(&self, w: &mut W) -> Result<()> - where - W: AsyncWrite + Unpin, - { - let mut buf = Vec::with_capacity(self.serialized_len()); - self.write_to_buf(&mut buf); - w.write_all(&buf).await - } - - pub fn write_to_buf(&self, buf: &mut B) { - buf.put_u8(TUIC_PROTOCOL_VERSION); - - match self { - Self::Response(is_succeeded) => { - buf.put_u8(Self::TYPE_RESPONSE); - if *is_succeeded { - buf.put_u8(Self::RESPONSE_SUCCEEDED); - } else { - buf.put_u8(Self::RESPONSE_FAILED); - } - } - Self::Authenticate { digest } => { - buf.put_u8(Self::TYPE_AUTHENTICATE); - buf.put_slice(digest); - } - Self::Connect { addr } => { - buf.put_u8(Self::TYPE_CONNECT); - addr.write_to_buf(buf); - } - Self::Packet { - assoc_id, - len, - addr, - } => { - buf.put_u8(Self::TYPE_PACKET); - buf.put_u32(*assoc_id); - buf.put_u16(*len); - addr.write_to_buf(buf); - } - Self::Dissociate { assoc_id } => { - buf.put_u8(Self::TYPE_DISSOCIATE); - buf.put_u32(*assoc_id); - } - Self::Heartbeat => { - buf.put_u8(Self::TYPE_HEARTBEAT); - } - } - } - - pub fn serialized_len(&self) -> usize { - 2 + match self { - Self::Response(_) => 1, - Self::Authenticate { .. } => 32, - Self::Connect { addr } => addr.serialized_len(), - Self::Packet { addr, .. } => 6 + addr.serialized_len(), - Self::Dissociate { .. } => 4, - Self::Heartbeat => 0, - } - } - - pub const fn max_serialized_len() -> usize { - 2 + 6 + Address::max_serialized_len() - } -} - -/// Address -/// -/// ```plain -/// +------+----------+----------+ -/// | TYPE | ADDR | PORT | -/// +------+----------+----------+ -/// | 1 | Variable | 2 | -/// +------+----------+----------+ -/// ``` -/// -/// The address type can be one of the following: -/// 0x00: fully-qualified domain name (the first byte indicates the length of the domain name) -/// 0x01: IPv4 address -/// 0x02: IPv6 address -#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub enum Address { - DomainAddress(String, u16), - SocketAddress(SocketAddr), -} - -impl Address { - const TYPE_DOMAIN: u8 = 0x00; - const TYPE_IPV4: u8 = 0x01; - const TYPE_IPV6: u8 = 0x02; - - pub async fn read_from(stream: &mut R) -> Result - where - R: AsyncRead + Unpin, - { - let addr_type = stream.read_u8().await?; - - match addr_type { - Self::TYPE_DOMAIN => { - let len = stream.read_u8().await? as usize; - - let mut buf = vec![0; len + 2]; - stream.read_exact(&mut buf).await?; - - let port = ReadBytesExt::read_u16::(&mut &buf[len..]).unwrap(); - buf.truncate(len); - - let addr = String::from_utf8(buf).map_err(|err| { - Error::new( - ErrorKind::InvalidData, - format!("Invalid address encoding: {err}"), - ) - })?; - - Ok(Self::DomainAddress(addr, port)) - } - Self::TYPE_IPV4 => { - let mut buf = [0; 6]; - stream.read_exact(&mut buf).await?; - let mut rdr = Cursor::new(buf); - - let addr = Ipv4Addr::new( - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ); - - let port = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - - Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) - } - Self::TYPE_IPV6 => { - let mut buf = [0; 18]; - stream.read_exact(&mut buf).await?; - let mut rdr = Cursor::new(buf); - - let addr = Ipv6Addr::new( - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ); - - let port = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - - Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) - } - _ => Err(Error::new( - ErrorKind::InvalidInput, - format!("Unsupported address type: {addr_type}"), - )), - } - } - - pub async fn write_to(&self, writer: &mut W) -> Result<()> - where - W: AsyncWrite + Unpin, - { - let mut buf = Vec::with_capacity(self.serialized_len()); - self.write_to_buf(&mut buf); - writer.write_all(&buf).await - } - - pub fn write_to_buf(&self, buf: &mut B) { - match self { - Self::DomainAddress(addr, port) => { - buf.put_u8(Self::TYPE_DOMAIN); - buf.put_u8(addr.len() as u8); - buf.put_slice(addr.as_bytes()); - buf.put_u16(*port); - } - Self::SocketAddress(addr) => match addr { - SocketAddr::V4(addr) => { - buf.put_u8(Self::TYPE_IPV4); - buf.put_slice(&addr.ip().octets()); - buf.put_u16(addr.port()); - } - SocketAddr::V6(addr) => { - buf.put_u8(Self::TYPE_IPV6); - for seg in addr.ip().segments() { - buf.put_u16(seg); - } - buf.put_u16(addr.port()); - } - }, - } - } - - pub fn serialized_len(&self) -> usize { - 1 + match self { - Address::DomainAddress(addr, _) => 1 + addr.len() + 2, - Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 6, - SocketAddr::V6(_) => 18, - }, - } - } - - pub const fn max_serialized_len() -> usize { - 1 + 1 + u8::MAX as usize + 2 - } -} - -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), - Self::SocketAddress(addr) => write!(f, "{addr}"), - } - } -} diff --git a/server/Cargo.toml b/server/Cargo.toml deleted file mode 100644 index a9c399d..0000000 --- a/server/Cargo.toml +++ /dev/null @@ -1,32 +0,0 @@ -[package] -name = "tuic-server" -version = "0.8.5" -authors = ["EAimTY "] -description = "Delicately-TUICed high-performance proxy built on top of the QUIC protocol" -categories = ["network-programming", "command-line-utilities"] -keywords = ["tuic", "proxy", "quic"] -edition = "2021" -rust-version = "1.59" -readme = "../README.md" -license = "GPL-3.0-or-later" -repository = "https://github.com/EAimTY/tuic" - -[dependencies] -tuic-protocol = { path="../protocol" } - -blake3 = "1.3.*" -bytes = "1.2.*" -crossbeam-utils = { version = "0.8.*", default-features = false } -env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } -futures-util = { version = "0.3.*", default-features = false } -getopts = "0.2.*" -log = { version = "0.4.*", features = ["serde", "std"] } -parking_lot = { version = "0.12.*", features = ["send_guard"] } -quinn = "0.8.*" -rustls = { version = "0.20.*", features = ["quic"], default-features = false } -rustls-pemfile = "1.0.*" -serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } -serde_json = { version = "1.0.*", features = ["std"], default-features = false } -socket2 = "0.4.*" -thiserror = "1.0.*" -tokio = { version = "1.20.*", features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "sync", "time"] } diff --git a/server/main.rs b/server/main.rs new file mode 100644 index 0000000..f084790 --- /dev/null +++ b/server/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello World!"); +} \ No newline at end of file diff --git a/server/src/certificate.rs b/server/src/certificate.rs deleted file mode 100644 index 3f754bd..0000000 --- a/server/src/certificate.rs +++ /dev/null @@ -1,39 +0,0 @@ -use rustls::{Certificate, PrivateKey}; -use rustls_pemfile::Item; -use std::{ - fs::{self, File}, - io::{BufReader, Error as IoError}, -}; - -pub fn load_certificates(path: &str) -> Result, IoError> { - let mut file = BufReader::new(File::open(path)?); - let mut certs = Vec::new(); - - while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { - if let Item::X509Certificate(cert) = item { - certs.push(Certificate(cert)); - } - } - - if certs.is_empty() { - certs = vec![Certificate(fs::read(path)?)]; - } - - Ok(certs) -} - -pub fn load_private_key(path: &str) -> Result { - let mut file = BufReader::new(File::open(path)?); - let mut priv_key = None; - - while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { - if let Item::RSAKey(key) | Item::PKCS8Key(key) | Item::ECKey(key) = item { - priv_key = Some(key); - } - } - - priv_key - .map(Ok) - .unwrap_or_else(|| fs::read(path)) - .map(PrivateKey) -} diff --git a/server/src/config.rs b/server/src/config.rs deleted file mode 100644 index 7c8083b..0000000 --- a/server/src/config.rs +++ /dev/null @@ -1,428 +0,0 @@ -use crate::certificate; -use getopts::{Fail, Options}; -use log::{LevelFilter, ParseLevelError}; -use quinn::{ - congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - IdleTimeout, ServerConfig, VarInt, -}; -use rustls::{version::TLS13, Error as RustlsError, ServerConfig as RustlsServerConfig}; -use serde::{de::Error as DeError, Deserialize, Deserializer}; -use serde_json::Error as JsonError; -use std::{ - collections::HashSet, - env::ArgsOs, - fmt::Display, - fs::File, - io::Error as IoError, - net::{AddrParseError, IpAddr, Ipv4Addr, SocketAddr}, - num::ParseIntError, - str::FromStr, - sync::Arc, - time::Duration, -}; -use thiserror::Error; - -pub struct Config { - pub server_config: ServerConfig, - pub listen_addr: SocketAddr, - pub token: HashSet<[u8; 32]>, - pub authentication_timeout: Duration, - pub max_udp_relay_packet_size: usize, - pub log_level: LevelFilter, -} - -impl Config { - pub fn parse(args: ArgsOs) -> Result { - let raw = RawConfig::parse(args)?; - - let server_config = { - let cert_path = raw.certificate.unwrap(); - let certs = certificate::load_certificates(&cert_path) - .map_err(|err| ConfigError::Io(cert_path, err))?; - - let priv_key_path = raw.private_key.unwrap(); - let priv_key = certificate::load_private_key(&priv_key_path) - .map_err(|err| ConfigError::Io(priv_key_path, err))?; - - let mut crypto = RustlsServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&TLS13]) - .unwrap() - .with_no_client_auth() - .with_single_cert(certs, priv_key)?; - - crypto.max_early_data_size = u32::MAX; - crypto.alpn_protocols = raw.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect(); - - let mut config = ServerConfig::with_crypto(Arc::new(crypto)); - let transport = Arc::get_mut(&mut config.transport).unwrap(); - - match raw.congestion_controller { - CongestionController::Bbr => { - transport.congestion_controller_factory(Arc::new(BbrConfig::default())); - } - CongestionController::Cubic => { - transport.congestion_controller_factory(Arc::new(CubicConfig::default())); - } - CongestionController::NewReno => { - transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); - } - } - - transport - .max_idle_timeout(Some(IdleTimeout::from(VarInt::from_u32(raw.max_idle_time)))); - - config - }; - - let listen_addr = SocketAddr::from((raw.ip, raw.port.unwrap())); - - let token = raw - .token - .into_iter() - .map(|token| *blake3::hash(&token.into_bytes()).as_bytes()) - .collect(); - - let authentication_timeout = Duration::from_secs(raw.authentication_timeout); - let max_udp_relay_packet_size = raw.max_udp_relay_packet_size; - let log_level = raw.log_level; - - Ok(Self { - server_config, - listen_addr, - token, - authentication_timeout, - max_udp_relay_packet_size, - log_level, - }) - } -} - -#[derive(Deserialize)] -#[serde(deny_unknown_fields)] -struct RawConfig { - port: Option, - token: Vec, - certificate: Option, - private_key: Option, - - #[serde(default = "default::ip")] - ip: IpAddr, - - #[serde( - default = "default::congestion_controller", - deserialize_with = "deserialize_from_str" - )] - congestion_controller: CongestionController, - - #[serde(default = "default::max_idle_time")] - max_idle_time: u32, - - #[serde(default = "default::authentication_timeout")] - authentication_timeout: u64, - - #[serde(default = "default::alpn")] - alpn: Vec, - - #[serde(default = "default::max_udp_relay_packet_size")] - max_udp_relay_packet_size: usize, - - #[serde(default = "default::log_level")] - log_level: LevelFilter, -} - -impl Default for RawConfig { - fn default() -> Self { - Self { - port: None, - token: Vec::new(), - certificate: None, - private_key: None, - ip: default::ip(), - congestion_controller: default::congestion_controller(), - max_idle_time: default::max_idle_time(), - authentication_timeout: default::authentication_timeout(), - alpn: default::alpn(), - max_udp_relay_packet_size: default::max_udp_relay_packet_size(), - log_level: default::log_level(), - } - } -} - -impl RawConfig { - fn parse(args: ArgsOs) -> Result { - let mut opts = Options::new(); - - opts.optopt( - "c", - "config", - "Read configuration from a file. Note that command line arguments will override the configuration file", - "CONFIG_FILE", - ); - - opts.optopt("", "port", "Set the server listening port", "SERVER_PORT"); - - opts.optmulti( - "", - "token", - "Set the token for TUIC authentication. This option can be used multiple times to set multiple tokens.", - "TOKEN", - ); - - opts.optopt( - "", - "certificate", - "Set the X.509 certificate. This must be an end-entity certificate", - "CERTIFICATE", - ); - - opts.optopt( - "", - "private-key", - "Set the certificate private key", - "PRIVATE_KEY", - ); - - opts.optopt( - "", - "ip", - "Set the server listening IP. Default: 0.0.0.0", - "IP", - ); - - opts.optopt( - "", - "congestion-controller", - r#"Set the congestion control algorithm. Available: "cubic", "new_reno", "bbr". Default: "cubic""#, - "CONGESTION_CONTROLLER", - ); - - opts.optopt( - "", - "max-idle-time", - "Set the maximum idle time for QUIC connections, in milliseconds. Default: 15000", - "MAX_IDLE_TIME", - ); - - opts.optopt( - "", - "authentication-timeout", - "Set the maximum time allowed between a QUIC connection established and the TUIC authentication packet received, in milliseconds. Default: 1000", - "AUTHENTICATION_TIMEOUT", - ); - - opts.optmulti( - "", - "alpn", - "Set ALPN protocols that the server accepts. This option can be used multiple times to set multiple ALPN protocols. If not set, the server will not check ALPN at all", - "ALPN_PROTOCOL", - ); - - opts.optopt( - "", - "max-udp-relay-packet-size", - "UDP relay mode QUIC can transmit UDP packets larger than the MTU. Set this to a higher value allows outbound to receive larger UDP packet. Default: 1500", - "MAX_UDP_RELAY_PACKET_SIZE", - ); - - opts.optopt( - "", - "log-level", - r#"Set the log level. Available: "off", "error", "warn", "info", "debug", "trace". Default: "info""#, - "LOG_LEVEL", - ); - - opts.optflag("v", "version", "Print the version"); - opts.optflag("h", "help", "Print this help menu"); - - let matches = opts.parse(args.skip(1))?; - - if matches.opt_present("help") { - return Err(ConfigError::Help(opts.usage(env!("CARGO_PKG_NAME")))); - } - - if matches.opt_present("version") { - return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))); - } - - if !matches.free.is_empty() { - return Err(ConfigError::UnexpectedArguments(matches.free.join(", "))); - } - - let port = matches.opt_str("port").map(|port| port.parse()); - let token = matches.opt_strs("token"); - let certificate = matches.opt_str("certificate"); - let private_key = matches.opt_str("private-key"); - - let mut raw = if let Some(path) = matches.opt_str("config") { - let mut raw = RawConfig::from_file(path)?; - - raw.port = Some( - port.transpose()? - .or(raw.port) - .ok_or(ConfigError::MissingOption("port"))?, - ); - - if !token.is_empty() { - raw.token = token; - } else if raw.token.is_empty() { - return Err(ConfigError::MissingOption("token")); - } - - raw.certificate = Some( - certificate - .or(raw.certificate) - .ok_or(ConfigError::MissingOption("certificate"))?, - ); - - raw.private_key = Some( - private_key - .or(raw.private_key) - .ok_or(ConfigError::MissingOption("private key"))?, - ); - - raw - } else { - RawConfig { - port: Some(port.ok_or(ConfigError::MissingOption("port"))??), - token: (!token.is_empty()) - .then(|| token) - .ok_or(ConfigError::MissingOption("token"))?, - certificate: Some(certificate.ok_or(ConfigError::MissingOption("certificate"))?), - private_key: Some(private_key.ok_or(ConfigError::MissingOption("private key"))?), - ..Default::default() - } - }; - - if let Some(ip) = matches.opt_str("ip") { - raw.ip = ip.parse()?; - }; - - if let Some(cgstn_ctrl) = matches.opt_str("congestion-controller") { - raw.congestion_controller = cgstn_ctrl.parse()?; - }; - - if let Some(timeout) = matches.opt_str("max-idle-time") { - raw.max_idle_time = timeout.parse()?; - }; - - if let Some(timeout) = matches.opt_str("authentication-timeout") { - raw.authentication_timeout = timeout.parse()?; - }; - - if let Some(size) = matches.opt_str("max-udp-relay-packet-size") { - raw.max_udp_relay_packet_size = size.parse()?; - }; - - let alpn = matches.opt_strs("alpn"); - - if !alpn.is_empty() { - raw.alpn = alpn; - } - - if let Some(log_level) = matches.opt_str("log-level") { - raw.log_level = log_level.parse()?; - }; - - Ok(raw) - } - - fn from_file(path: String) -> Result { - let file = File::open(&path).map_err(|err| ConfigError::Io(path, err))?; - let raw = serde_json::from_reader(file)?; - Ok(raw) - } -} - -enum CongestionController { - Cubic, - NewReno, - Bbr, -} - -impl FromStr for CongestionController { - type Err = ConfigError; - - fn from_str(s: &str) -> Result { - if s.eq_ignore_ascii_case("cubic") { - Ok(CongestionController::Cubic) - } else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") { - Ok(CongestionController::NewReno) - } else if s.eq_ignore_ascii_case("bbr") { - Ok(CongestionController::Bbr) - } else { - Err(ConfigError::InvalidCongestionController) - } - } -} - -fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result -where - T: FromStr, - ::Err: Display, - D: Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - T::from_str(&s).map_err(DeError::custom) -} - -mod default { - use super::*; - - pub(super) fn ip() -> IpAddr { - IpAddr::V4(Ipv4Addr::UNSPECIFIED) - } - - pub(super) const fn congestion_controller() -> CongestionController { - CongestionController::Cubic - } - - pub(super) const fn max_idle_time() -> u32 { - 15000 - } - - pub(super) const fn authentication_timeout() -> u64 { - 1000 - } - - pub(super) const fn alpn() -> Vec { - Vec::new() - } - - pub(super) const fn max_udp_relay_packet_size() -> usize { - 1500 - } - - pub(super) const fn log_level() -> LevelFilter { - LevelFilter::Info - } -} - -#[derive(Error, Debug)] -pub enum ConfigError { - #[error("{0}")] - Help(String), - #[error("{0}")] - Version(&'static str), - #[error("Failed to read '{0}': {1}")] - Io(String, #[source] IoError), - #[error("Failed to parse the config file: {0}")] - ParseConfigJson(#[from] JsonError), - #[error(transparent)] - ParseArgument(#[from] Fail), - #[error("Unexpected arguments: {0}")] - UnexpectedArguments(String), - #[error("Missing option: {0}")] - MissingOption(&'static str), - #[error(transparent)] - ParseInt(#[from] ParseIntError), - #[error(transparent)] - ParseAddr(#[from] AddrParseError), - #[error("Invalid congestion controller")] - InvalidCongestionController, - #[error(transparent)] - ParseLogLevel(#[from] ParseLevelError), - #[error("Failed to load certificate / private key: {0}")] - Rustls(#[from] RustlsError), -} diff --git a/server/src/connection/authenticate.rs b/server/src/connection/authenticate.rs deleted file mode 100644 index 7575382..0000000 --- a/server/src/connection/authenticate.rs +++ /dev/null @@ -1,53 +0,0 @@ -use super::IsClosed; -use parking_lot::Mutex; -use std::{ - future::Future, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, -}; - -#[derive(Clone)] -pub struct IsAuthenticated { - is_connection_closed: IsClosed, - is_authenticated: Arc, - broadcast: Arc>>, -} - -impl IsAuthenticated { - pub fn new(is_closed: IsClosed) -> Self { - Self { - is_connection_closed: is_closed, - is_authenticated: Arc::new(AtomicBool::new(false)), - broadcast: Arc::new(Mutex::new(Vec::new())), - } - } - - pub fn set_authenticated(&self) { - self.is_authenticated.store(true, Ordering::Release); - } - - pub fn wake(&self) { - for waker in self.broadcast.lock().drain(..) { - waker.wake(); - } - } -} - -impl Future for IsAuthenticated { - type Output = bool; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.is_connection_closed.check() { - Poll::Ready(false) - } else if self.is_authenticated.load(Ordering::Relaxed) { - Poll::Ready(true) - } else { - self.broadcast.lock().push(cx.waker().clone()); - Poll::Pending - } - } -} diff --git a/server/src/connection/dispatch.rs b/server/src/connection/dispatch.rs deleted file mode 100644 index 6b5ef8d..0000000 --- a/server/src/connection/dispatch.rs +++ /dev/null @@ -1,226 +0,0 @@ -use super::{task, Connection, UdpPacketSource}; -use bytes::Bytes; -use quinn::{RecvStream, SendStream, VarInt}; -use std::io::Error as IoError; -use thiserror::Error; -use tuic_protocol::{Address, Command}; - -impl Connection { - pub async fn process_uni_stream(&self, mut stream: RecvStream) -> Result<(), DispatchError> { - let rmt_addr = self.controller.remote_address(); - let cmd = Command::read_from(&mut stream).await?; - - if let Command::Authenticate { digest } = cmd { - if self.token.contains(&digest) { - log::debug!("[{rmt_addr}] [authentication]"); - - self.is_authenticated.set_authenticated(); - self.is_authenticated.wake(); - return Ok(()); - } else { - let err = DispatchError::AuthenticationFailed; - self.controller - .close(err.as_error_code(), err.to_string().as_bytes()); - self.is_authenticated.wake(); - return Err(err); - } - } - - if self.is_authenticated.clone().await { - match cmd { - Command::Authenticate { .. } => unreachable!(), - Command::Packet { - assoc_id, - len, - addr, - } => { - if self.udp_packet_from.uni_stream() { - let dst_addr = addr.to_string(); - log::debug!("[{rmt_addr}] [packet-from-quic] [{assoc_id}] [{dst_addr}]"); - - let res = task::packet_from_uni_stream( - stream, - self.udp_sessions.clone(), - assoc_id, - len, - addr, - rmt_addr, - ) - .await; - - match res { - Ok(()) => {} - Err(err) => log::warn!( - "[{rmt_addr}] [packet-from-quic] [{assoc_id}] [{dst_addr}] {err}" - ), - } - - Ok(()) - } else { - Err(DispatchError::BadCommand) - } - } - Command::Dissociate { assoc_id } => { - let res = task::dissociate(self.udp_sessions.clone(), assoc_id, rmt_addr).await; - - match res { - Ok(()) => {} - Err(err) => log::warn!("[{rmt_addr}] [dissociate] {err}"), - } - - Ok(()) - } - Command::Heartbeat => { - log::debug!("[{rmt_addr}] [heartbeat]"); - Ok(()) - } - _ => Err(DispatchError::BadCommand), - } - } else { - Err(DispatchError::AuthenticationTimeout) - } - } - - pub async fn process_bi_stream( - &self, - send: SendStream, - mut recv: RecvStream, - ) -> Result<(), DispatchError> { - let cmd = Command::read_from(&mut recv).await?; - let rmt_addr = self.controller.remote_address(); - - if self.is_authenticated.clone().await { - match cmd { - Command::Connect { addr } => { - let dst_addr = addr.to_string(); - log::info!("[{rmt_addr}] [connect] [{dst_addr}]"); - - let res = task::connect(send, recv, addr).await; - - match res { - Ok(()) => {} - Err(err) => log::warn!("[{rmt_addr}] [connect] [{dst_addr}] {err}"), - } - - Ok(()) - } - _ => Err(DispatchError::BadCommand), - } - } else { - Err(DispatchError::AuthenticationTimeout) - } - } - - pub async fn process_datagram(&self, datagram: Bytes) -> Result<(), DispatchError> { - let cmd = Command::read_from(&mut datagram.as_ref()).await?; - let rmt_addr = self.controller.remote_address(); - let cmd_len = cmd.serialized_len(); - - if self.is_authenticated.clone().await { - match cmd { - Command::Packet { assoc_id, addr, .. } => { - if self.udp_packet_from.datagram() { - let dst_addr = addr.to_string(); - log::debug!("[{rmt_addr}] [packet-from-native] [{assoc_id}] [{dst_addr}]"); - - let res = task::packet_from_datagram( - datagram.slice(cmd_len..), - self.udp_sessions.clone(), - assoc_id, - addr, - rmt_addr, - ) - .await; - - match res { - Ok(()) => {} - Err(err) => { - log::warn!( - "[{rmt_addr}] [packet-from-native] [{assoc_id}] [{dst_addr}] {err}" - ) - } - } - - Ok(()) - } else { - Err(DispatchError::BadCommand) - } - } - _ => Err(DispatchError::BadCommand), - } - } else { - Err(DispatchError::AuthenticationTimeout) - } - } - - pub async fn process_received_udp_packet( - &self, - assoc_id: u32, - pkt: Bytes, - addr: Address, - ) -> Result<(), DispatchError> { - let rmt_addr = self.controller.remote_address(); - let dst_addr = addr.to_string(); - - match self.udp_packet_from.check().unwrap() { - UdpPacketSource::UniStream => { - log::debug!("[{rmt_addr}] [packet-to-quic] [{assoc_id}] [{dst_addr}]"); - - let res = - task::packet_to_uni_stream(self.controller.clone(), assoc_id, pkt, addr).await; - - match res { - Ok(()) => {} - Err(err) => { - log::warn!("[{rmt_addr}] [packet-to-quic] [{assoc_id}] [{dst_addr}] {err}") - } - } - } - UdpPacketSource::Datagram => { - log::debug!("[{rmt_addr}] [packet-to-native] [{assoc_id}] [{dst_addr}]"); - - let res = - task::packet_to_datagram(self.controller.clone(), assoc_id, pkt, addr).await; - - match res { - Ok(()) => {} - Err(err) => { - log::warn!( - "[{rmt_addr}] [packet-to-native] [{assoc_id}] [{dst_addr}] {err}" - ) - } - } - } - } - - Ok(()) - } -} - -#[derive(Error, Debug)] -pub enum DispatchError { - #[error(transparent)] - Io(#[from] IoError), - #[error("authentication failed")] - AuthenticationFailed, - #[error("authentication timeout")] - AuthenticationTimeout, - #[error("bad command")] - BadCommand, -} - -impl DispatchError { - const CODE_PROTOCOL: VarInt = VarInt::from_u32(0xfffffff0); - const CODE_AUTHENTICATION_FAILED: VarInt = VarInt::from_u32(0xfffffff1); - const CODE_AUTHENTICATION_TIMEOUT: VarInt = VarInt::from_u32(0xfffffff2); - const CODE_BAD_COMMAND: VarInt = VarInt::from_u32(0xfffffff3); - - pub fn as_error_code(&self) -> VarInt { - match self { - Self::Io(_) => Self::CODE_PROTOCOL, - Self::AuthenticationFailed => Self::CODE_AUTHENTICATION_FAILED, - Self::AuthenticationTimeout => Self::CODE_AUTHENTICATION_TIMEOUT, - Self::BadCommand => Self::CODE_BAD_COMMAND, - } - } -} diff --git a/server/src/connection/mod.rs b/server/src/connection/mod.rs deleted file mode 100644 index a4c8ff0..0000000 --- a/server/src/connection/mod.rs +++ /dev/null @@ -1,258 +0,0 @@ -use self::{ - authenticate::IsAuthenticated, - dispatch::DispatchError, - udp::{RecvPacketReceiver, UdpPacketFrom, UdpPacketSource, UdpSessionMap}, -}; -use futures_util::StreamExt; -use parking_lot::Mutex; -use quinn::{ - Connecting, Connection as QuinnConnection, ConnectionError, Datagrams, IncomingBiStreams, - IncomingUniStreams, NewConnection, -}; -use std::{ - collections::HashSet, - future::Future, - pin::Pin, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, - time::Duration, -}; -use tokio::time; - -mod authenticate; -mod dispatch; -mod task; -mod udp; - -#[derive(Clone)] -pub struct Connection { - controller: QuinnConnection, - udp_packet_from: UdpPacketFrom, - udp_sessions: Arc, - token: Arc>, - is_authenticated: IsAuthenticated, -} - -impl Connection { - pub async fn handle( - conn: Connecting, - token: Arc>, - auth_timeout: Duration, - max_pkt_size: usize, - ) { - let rmt_addr = conn.remote_address(); - - match conn.await { - Ok(NewConnection { - connection, - uni_streams, - bi_streams, - datagrams, - .. - }) => { - log::debug!("[{rmt_addr}] [establish]"); - - let (udp_sessions, recv_pkt_rx) = UdpSessionMap::new(max_pkt_size); - let is_closed = IsClosed::new(); - let is_authed = IsAuthenticated::new(is_closed.clone()); - - let conn = Self { - controller: connection, - udp_packet_from: UdpPacketFrom::new(), - udp_sessions: Arc::new(udp_sessions), - token, - is_authenticated: is_authed, - }; - - let res = tokio::select! { - res = Self::listen_uni_streams(conn.clone(), uni_streams) => res, - res = Self::listen_bi_streams(conn.clone(), bi_streams) => res, - res = Self::listen_datagrams(conn.clone(), datagrams) => res, - res = Self::listen_received_udp_packet(conn.clone(), recv_pkt_rx) => res, - Err(err) = Self::handle_authentication_timeout(conn, auth_timeout) => Err(err), - }; - - match res { - Ok(()) => unreachable!(), - Err(err) => { - is_closed.set_closed(); - - match err { - ConnectionError::TimedOut => { - log::debug!("[{rmt_addr}] [disconnect] [connection timeout]") - } - ConnectionError::LocallyClosed => { - log::debug!("[{rmt_addr}] [disconnect] [locally closed]") - } - err => log::error!("[{rmt_addr}] [disconnect] {err}"), - } - } - } - } - Err(err) => log::error!("[{rmt_addr}] {err}"), - } - } - - async fn listen_uni_streams( - self, - mut uni_streams: IncomingUniStreams, - ) -> Result<(), ConnectionError> { - while let Some(stream) = uni_streams.next().await { - let stream = stream?; - let conn = self.clone(); - - tokio::spawn(async move { - match conn.process_uni_stream(stream).await { - Ok(()) => {} - Err(err) => { - conn.controller - .close(err.as_error_code(), err.to_string().as_bytes()); - - let rmt_addr = conn.controller.remote_address(); - log::error!("[{rmt_addr}] {err}"); - } - } - }); - } - - Err(ConnectionError::LocallyClosed) - } - - async fn listen_bi_streams( - self, - mut bi_streams: IncomingBiStreams, - ) -> Result<(), ConnectionError> { - while let Some(stream) = bi_streams.next().await { - let (send, recv) = stream?; - let conn = self.clone(); - - tokio::spawn(async move { - match conn.process_bi_stream(send, recv).await { - Ok(()) => {} - Err(err) => { - conn.controller - .close(err.as_error_code(), err.to_string().as_bytes()); - - let rmt_addr = conn.controller.remote_address(); - log::error!("[{rmt_addr}] {err}"); - } - } - }); - } - - Err(ConnectionError::LocallyClosed) - } - - async fn listen_datagrams(self, mut datagrams: Datagrams) -> Result<(), ConnectionError> { - while let Some(datagram) = datagrams.next().await { - let datagram = datagram?; - let conn = self.clone(); - - tokio::spawn(async move { - match conn.process_datagram(datagram).await { - Ok(()) => {} - Err(err) => { - conn.controller - .close(err.as_error_code(), err.to_string().as_bytes()); - - let rmt_addr = conn.controller.remote_address(); - log::error!("[{rmt_addr}] {err}"); - } - } - }); - } - - Err(ConnectionError::LocallyClosed) - } - - async fn listen_received_udp_packet( - self, - mut recv_pkt_rx: RecvPacketReceiver, - ) -> Result<(), ConnectionError> { - while let Some((assoc_id, pkt, addr)) = recv_pkt_rx.recv().await { - let conn = self.clone(); - - tokio::spawn(async move { - match conn.process_received_udp_packet(assoc_id, pkt, addr).await { - Ok(()) => {} - Err(err) => { - conn.controller - .close(err.as_error_code(), err.to_string().as_bytes()); - - let rmt_addr = conn.controller.remote_address(); - log::error!("[{rmt_addr}] {err}"); - } - } - }); - } - - Err(ConnectionError::LocallyClosed) - } - - async fn handle_authentication_timeout(self, timeout: Duration) -> Result<(), ConnectionError> { - let is_timeout = tokio::select! { - _ = self.is_authenticated.clone() => false, - () = time::sleep(timeout) => true, - }; - - if !is_timeout { - Ok(()) - } else { - let err = DispatchError::AuthenticationTimeout; - - self.controller - .close(err.as_error_code(), err.to_string().as_bytes()); - self.is_authenticated.wake(); - - let rmt_addr = self.controller.remote_address(); - log::error!("[{rmt_addr}] {err}"); - - Err(ConnectionError::LocallyClosed) - } - } -} - -#[derive(Clone)] -pub struct IsClosed(Arc); - -struct IsClosedInner { - is_closed: AtomicBool, - waker: Mutex>, -} - -impl IsClosed { - fn new() -> Self { - Self(Arc::new(IsClosedInner { - is_closed: AtomicBool::new(false), - waker: Mutex::new(None), - })) - } - - fn set_closed(&self) { - self.0.is_closed.store(true, Ordering::Release); - - if let Some(waker) = self.0.waker.lock().take() { - waker.wake(); - } - } - - fn check(&self) -> bool { - self.0.is_closed.load(Ordering::Acquire) - } -} - -impl Future for IsClosed { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.0.is_closed.load(Ordering::Acquire) { - Poll::Ready(()) - } else { - *self.0.waker.lock() = Some(cx.waker().clone()); - Poll::Pending - } - } -} diff --git a/server/src/connection/task.rs b/server/src/connection/task.rs deleted file mode 100644 index 2affc7b..0000000 --- a/server/src/connection/task.rs +++ /dev/null @@ -1,186 +0,0 @@ -use super::udp::UdpSessionMap; -use bytes::{Bytes, BytesMut}; -use quinn::{ - Connection as QuinnConnection, ConnectionError, ReadExactError, RecvStream, SendDatagramError, - SendStream, WriteError, -}; -use std::{ - io::{Error as IoError, IoSlice}, - net::SocketAddr, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use thiserror::Error; -use tokio::{ - io::{self, AsyncRead, AsyncWrite, ReadBuf}, - net::{self, TcpStream}, -}; -use tuic_protocol::{Address, Command}; - -pub async fn connect( - mut send: SendStream, - recv: RecvStream, - addr: Address, -) -> Result<(), TaskError> { - let mut target = None; - - let addrs = match addr { - Address::SocketAddress(addr) => Ok(vec![addr]), - Address::DomainAddress(domain, port) => net::lookup_host((domain.as_str(), port)) - .await - .map(|res| res.collect()), - }?; - - for addr in addrs { - if let Ok(target_stream) = TcpStream::connect(addr).await { - target = Some(target_stream); - break; - } - } - - if let Some(mut target) = target { - let resp = Command::new_response(true); - resp.write_to(&mut send).await?; - let mut tunnel = BiStream(send, recv); - io::copy_bidirectional(&mut target, &mut tunnel).await?; - } else { - let resp = Command::new_response(false); - resp.write_to(&mut send).await?; - send.finish().await?; - }; - - Ok(()) -} - -pub async fn packet_from_uni_stream( - mut stream: RecvStream, - udp_sessions: Arc, - assoc_id: u32, - len: u16, - addr: Address, - src_addr: SocketAddr, -) -> Result<(), TaskError> { - let mut buf = vec![0; len as usize]; - stream.read_exact(&mut buf).await?; - - let pkt = Bytes::from(buf); - udp_sessions.send(assoc_id, pkt, addr, src_addr).await?; - - Ok(()) -} - -pub async fn packet_from_datagram( - pkt: Bytes, - udp_sessions: Arc, - assoc_id: u32, - addr: Address, - src_addr: SocketAddr, -) -> Result<(), TaskError> { - udp_sessions.send(assoc_id, pkt, addr, src_addr).await?; - Ok(()) -} - -pub async fn packet_to_uni_stream( - conn: QuinnConnection, - assoc_id: u32, - pkt: Bytes, - addr: Address, -) -> Result<(), TaskError> { - let mut stream = conn.open_uni().await?; - - let cmd = Command::new_packet(assoc_id, pkt.len() as u16, addr); - cmd.write_to(&mut stream).await?; - stream.write_all(&pkt).await?; - stream.finish().await?; - - Ok(()) -} - -pub async fn packet_to_datagram( - conn: QuinnConnection, - assoc_id: u32, - pkt: Bytes, - addr: Address, -) -> Result<(), TaskError> { - let cmd = Command::new_packet(assoc_id, pkt.len() as u16, addr); - - let mut buf = BytesMut::with_capacity(cmd.serialized_len()); - cmd.write_to_buf(&mut buf); - buf.extend_from_slice(&pkt); - - let pkt = buf.freeze(); - conn.send_datagram(pkt)?; - - Ok(()) -} - -pub async fn dissociate( - udp_sessions: Arc, - assoc_id: u32, - src_addr: SocketAddr, -) -> Result<(), TaskError> { - udp_sessions.dissociate(assoc_id, src_addr); - Ok(()) -} - -struct BiStream(SendStream, RecvStream); - -impl AsyncRead for BiStream { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.1).poll_read(cx, buf) - } -} - -impl AsyncWrite for BiStream { - #[inline] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } - - #[inline] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } -} - -#[derive(Error, Debug)] -pub enum TaskError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Connection(#[from] ConnectionError), - #[error(transparent)] - ReadStream(#[from] ReadExactError), - #[error(transparent)] - WriteStream(#[from] WriteError), - #[error(transparent)] - SendDatagram(#[from] SendDatagramError), -} diff --git a/server/src/connection/udp.rs b/server/src/connection/udp.rs deleted file mode 100644 index 77226a8..0000000 --- a/server/src/connection/udp.rs +++ /dev/null @@ -1,178 +0,0 @@ -use bytes::Bytes; -use crossbeam_utils::atomic::AtomicCell; -use parking_lot::Mutex; - -use std::{ - collections::HashMap, - io::Result, - net::{Ipv6Addr, SocketAddr}, - sync::Arc, -}; -use tokio::{ - net::UdpSocket, - sync::mpsc::{self, Receiver, Sender}, -}; -use tuic_protocol::Address; - -#[derive(Clone)] -pub struct UdpPacketFrom(Arc>>); - -impl UdpPacketFrom { - pub fn new() -> Self { - Self(Arc::new(AtomicCell::new(None))) - } - - pub fn check(&self) -> Option { - self.0.load() - } - - pub fn uni_stream(&self) -> bool { - self.0 - .compare_exchange(None, Some(UdpPacketSource::UniStream)) - .map_or_else(|from| from == Some(UdpPacketSource::UniStream), |_| true) - } - - pub fn datagram(&self) -> bool { - self.0 - .compare_exchange(None, Some(UdpPacketSource::Datagram)) - .map_or_else(|from| from == Some(UdpPacketSource::Datagram), |_| true) - } -} - -#[derive(Clone, Copy, Eq, PartialEq)] -pub enum UdpPacketSource { - UniStream, - Datagram, -} - -pub type SendPacketSender = Sender<(Bytes, Address)>; -pub type SendPacketReceiver = Receiver<(Bytes, Address)>; -pub type RecvPacketSender = Sender<(u32, Bytes, Address)>; -pub type RecvPacketReceiver = Receiver<(u32, Bytes, Address)>; - -pub struct UdpSessionMap { - map: Mutex>, - recv_pkt_tx_for_clone: RecvPacketSender, - max_pkt_size: usize, -} - -impl UdpSessionMap { - pub fn new(max_pkt_size: usize) -> (Self, RecvPacketReceiver) { - let (recv_pkt_tx, recv_pkt_rx) = mpsc::channel(1); - - ( - Self { - map: Mutex::new(HashMap::new()), - recv_pkt_tx_for_clone: recv_pkt_tx, - max_pkt_size, - }, - recv_pkt_rx, - ) - } - - #[allow(clippy::await_holding_lock)] - pub async fn send( - &self, - assoc_id: u32, - pkt: Bytes, - addr: Address, - src_addr: SocketAddr, - ) -> Result<()> { - let map = self.map.lock(); - - let send_pkt_tx = if let Some(session) = map.get(&assoc_id) { - let send_pkt_tx = session.0.clone(); - drop(map); - send_pkt_tx - } else { - log::info!("[{src_addr}] [associate] [{assoc_id}]"); - drop(map); - - let assoc = UdpSession::new( - assoc_id, - self.recv_pkt_tx_for_clone.clone(), - src_addr, - self.max_pkt_size, - ) - .await?; - - let send_pkt_tx = assoc.0.clone(); - - let mut map = self.map.lock(); - map.insert(assoc_id, assoc); - - send_pkt_tx - }; - - let _ = send_pkt_tx.send((pkt, addr)).await; - - Ok(()) - } - - pub fn dissociate(&self, assoc_id: u32, src_addr: SocketAddr) { - log::info!("[{src_addr}] [dissociate] [{assoc_id}]"); - self.map.lock().remove(&assoc_id); - } -} - -struct UdpSession(SendPacketSender); - -impl UdpSession { - async fn new( - assoc_id: u32, - recv_pkt_tx: RecvPacketSender, - src_addr: SocketAddr, - max_pkt_size: usize, - ) -> Result { - let socket = Arc::new(UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))).await?); - let (send_pkt_tx, send_pkt_rx) = mpsc::channel(1); - - tokio::spawn(async move { - match tokio::select! { - res = Self::listen_send_packet(socket.clone(), send_pkt_rx) => res, - res = Self::listen_receive_packet(socket, assoc_id, recv_pkt_tx, max_pkt_size) => res, - } { - Ok(()) => (), - Err(err) => log::warn!("[{src_addr}] [udp-session] [{assoc_id}] {err}"), - } - }); - - Ok(Self(send_pkt_tx)) - } - - async fn listen_send_packet( - socket: Arc, - mut send_pkt_rx: SendPacketReceiver, - ) -> Result<()> { - while let Some((pkt, addr)) = send_pkt_rx.recv().await { - match addr { - Address::DomainAddress(hostname, port) => { - socket.send_to(&pkt, (hostname, port)).await?; - } - Address::SocketAddress(addr) => { - socket.send_to(&pkt, addr).await?; - } - } - } - - Ok(()) - } - - async fn listen_receive_packet( - socket: Arc, - assoc_id: u32, - recv_pkt_tx: RecvPacketSender, - max_pkt_size: usize, - ) -> Result<()> { - loop { - let mut buf = vec![0; max_pkt_size]; - let (len, addr) = socket.recv_from(&mut buf).await?; - buf.truncate(len); - - let pkt = Bytes::from(buf); - let _ = recv_pkt_tx - .send((assoc_id, pkt, Address::SocketAddress(addr))) - .await; - } - } -} diff --git a/server/src/main.rs b/server/src/main.rs deleted file mode 100644 index 682ecac..0000000 --- a/server/src/main.rs +++ /dev/null @@ -1,51 +0,0 @@ -use crate::{ - config::{Config, ConfigError}, - server::Server, -}; -use std::{env, process}; - -mod certificate; -mod config; -mod connection; -mod server; - -#[tokio::main] -async fn main() { - let args = env::args_os(); - - let config = match Config::parse(args) { - Ok(cfg) => cfg, - Err(err) => { - match err { - ConfigError::Help(help) => println!("{help}"), - ConfigError::Version(version) => println!("{version}"), - err => eprintln!("{err}"), - } - return; - } - }; - - env_logger::builder() - .filter_level(config.log_level) - .format_level(true) - .format_target(false) - .format_module_path(false) - .init(); - - let server = match Server::init( - config.server_config, - config.listen_addr, - config.token, - config.authentication_timeout, - config.max_udp_relay_packet_size, - ) { - Ok(server) => server, - Err(err) => { - eprintln!("{err}"); - return; - } - }; - - server.run().await; - process::exit(1); -} diff --git a/server/src/server.rs b/server/src/server.rs deleted file mode 100644 index 1c110b2..0000000 --- a/server/src/server.rs +++ /dev/null @@ -1,62 +0,0 @@ -use crate::connection::Connection; -use futures_util::StreamExt; -use quinn::{Endpoint, EndpointConfig, Incoming, ServerConfig}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; -use std::{ - collections::HashSet, - io::Result, - net::{SocketAddr, UdpSocket}, - sync::Arc, - time::Duration, -}; - -pub struct Server { - incoming: Incoming, - listen_addr: SocketAddr, - token: Arc>, - authentication_timeout: Duration, - max_pkt_size: usize, -} - -impl Server { - pub fn init( - config: ServerConfig, - listen_addr: SocketAddr, - token: HashSet<[u8; 32]>, - auth_timeout: Duration, - max_pkt_size: usize, - ) -> Result { - let socket = match listen_addr { - SocketAddr::V4(_) => UdpSocket::bind(listen_addr)?, - SocketAddr::V6(_) => { - let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; - socket.set_only_v6(false)?; - socket.bind(&SockAddr::from(listen_addr))?; - UdpSocket::from(socket) - } - }; - - let (_, incoming) = Endpoint::new(EndpointConfig::default(), Some(config), socket)?; - - Ok(Self { - incoming, - listen_addr, - token: Arc::new(token), - authentication_timeout: auth_timeout, - max_pkt_size, - }) - } - - pub async fn run(mut self) { - log::info!("Server started. Listening: {}", self.listen_addr); - - while let Some(conn) = self.incoming.next().await { - tokio::spawn(Connection::handle( - conn, - self.token.clone(), - self.authentication_timeout, - self.max_pkt_size, - )); - } - } -} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..e69de29 From 75aa2d7d2111443834df51077c60d098f9e8c20c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 31 Jul 2022 14:20:02 +0900 Subject: [PATCH 002/103] adding protocol back --- Cargo.toml | 75 +++++++---- client/main.rs | 2 +- server/main.rs | 2 +- src/lib.rs | 1 + src/protocol/marshaling.rs | 261 +++++++++++++++++++++++++++++++++++++ src/protocol/mod.rs | 97 ++++++++++++++ 6 files changed, 409 insertions(+), 29 deletions(-) create mode 100644 src/protocol/marshaling.rs create mode 100644 src/protocol/mod.rs diff --git a/Cargo.toml b/Cargo.toml index e78e84e..4f96f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,40 +11,61 @@ readme = "README.md" license = "GPL-3.0-or-later" repository = "https://github.com/EAimTY/tuic" -[[bin]] -name = "client" -path = "client/main.rs" - [[bin]] name = "server" path = "server/main.rs" +required-features = ["server"] + +[[bin]] +name = "client" +path = "client/main.rs" +required-features = ["client"] [dependencies] -blake3 = "1.3.*" -bytes = "1.2.*" -crossbeam-utils = { version = "0.8.*", default-features = false } -env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } -futures-util = { version = "0.3.*", default-features = false } -getopts = "0.2.*" -log = { version = "0.4.*", features = ["serde", "std"] } -once_cell = { version = "1.13.*", features = ["parking_lot"] } -parking_lot = "0.12.*" -quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false } -rand = "0.8.*" -rustls = { version = "0.20.*", features = ["quic"], default-features = false } -rustls-native-certs = "0.6.*" -rustls-pemfile = "1.0.*" -serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } -serde_json = { version = "1.0.*", features = ["std"], default-features = false } -socket2 = "0.4.*" -socks5-proto = "0.3.*" -socks5-server = "0.8.*" -thiserror = "1.0.*" -tokio = { version = "1.20.*", features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "sync", "time"] } -webpki = { version = "0.22.*", default-features = false } +# blake3 = "1.3.*" +byteorder = { version = "1.4.*", default-features = false, optional = true } +bytes = { version = "1.2.*", default-features = false, optional = true } +# crossbeam-utils = { version = "0.8.*", default-features = false } +# env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } +# futures-util = { version = "0.3.*", default-features = false } +# getopts = "0.2.*" +# log = { version = "0.4.*", features = ["serde", "std"] } +# once_cell = { version = "1.13.*", features = ["parking_lot"] } +# parking_lot = "0.12.*" +# quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false } +# rand = "0.8.*" +# rustls = { version = "0.20.*", features = ["quic"], default-features = false } +# rustls-native-certs = "0.6.*" +# rustls-pemfile = "1.0.*" +# serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } +# serde_json = { version = "1.0.*", features = ["std"], default-features = false } +# socket2 = "0.4.*" +# socks5-proto = "0.3.*" +# socks5-server = "0.8.*" +thiserror = { version = "1.0.*", default-features = false, optional = true } +tokio = { version = "1.20.*", default-features = false, optional = true } +# tokio = { version = "1.20.*", features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "sync", "time"] } +# webpki = { version = "0.22.*", default-features = false } + +[features] +default = [] + +all = ["protocol_marshaling_tokio", "server", "client"] + +protocol_marshaling_tokio = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] + +server = ["protocol_marshaling_tokio"] +client = ["protocol_marshaling_tokio"] + +[dev-dependencies] +tuic = { path = ".", features = ["all"] } [profile.release] lto = true strip = true +incremental = false codegen-units = 1 -panic = "abort" \ No newline at end of file +panic = "abort" + +[package.metadata.docs.rs] +all-features = true \ No newline at end of file diff --git a/client/main.rs b/client/main.rs index f084790..47ad8c6 100644 --- a/client/main.rs +++ b/client/main.rs @@ -1,3 +1,3 @@ fn main() { println!("Hello World!"); -} \ No newline at end of file +} diff --git a/server/main.rs b/server/main.rs index f084790..47ad8c6 100644 --- a/server/main.rs +++ b/server/main.rs @@ -1,3 +1,3 @@ fn main() { println!("Hello World!"); -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index e69de29..1b800ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -0,0 +1 @@ +pub mod protocol; diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs new file mode 100644 index 0000000..a8eb39b --- /dev/null +++ b/src/protocol/marshaling.rs @@ -0,0 +1,261 @@ +use super::{Address, Command, TUIC_PROTOCOL_VERSION}; +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::BufMut; +use std::{ + io::{Cursor, Error as IoError, ErrorKind as IoErrorKind}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + string::FromUtf8Error, +}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +impl Command { + pub async fn read_from(r: &mut R) -> Result + where + R: AsyncRead + Unpin, + { + let ver = r.read_u8().await?; + + if ver != TUIC_PROTOCOL_VERSION { + return Err(Error::UnsupportedVersion(ver)); + } + + let cmd = r.read_u8().await?; + + match cmd { + Self::TYPE_RESPONSE => { + let resp = r.read_u8().await?; + match resp { + Self::RESPONSE_SUCCEEDED => Ok(Self::Response(true)), + Self::RESPONSE_FAILED => Ok(Self::Response(false)), + _ => Err(Error::InvalidResponse(resp)), + } + } + Self::TYPE_AUTHENTICATE => { + let mut digest = [0; 32]; + r.read_exact(&mut digest).await?; + Ok(Self::Authenticate(digest)) + } + Self::TYPE_CONNECT => { + let addr = Address::read_from(r).await?; + Ok(Self::Connect { addr }) + } + Self::TYPE_PACKET => { + let mut buf = [0; 6]; + r.read_exact(&mut buf).await?; + let mut rdr = Cursor::new(buf); + + let assoc_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); + let len = ReadBytesExt::read_u16::(&mut rdr).unwrap(); + let addr = Address::read_from(r).await?; + + Ok(Self::Packet { + assoc_id, + len, + addr, + }) + } + Self::TYPE_DISSOCIATE => { + let assoc_id = r.read_u32().await?; + Ok(Self::Dissociate { assoc_id }) + } + Self::TYPE_HEARTBEAT => Ok(Self::Heartbeat), + _ => Err(Error::InvalidCommand(cmd)), + } + } + + pub async fn write_to(&self, w: &mut W) -> Result<(), Error> + where + W: AsyncWrite + Unpin, + { + let mut buf = Vec::with_capacity(self.serialized_len()); + self.write_to_buf(&mut buf); + w.write_all(&buf).await?; + Ok(()) + } + + pub fn write_to_buf(&self, buf: &mut B) { + buf.put_u8(TUIC_PROTOCOL_VERSION); + + match self { + Self::Response(is_succeeded) => { + buf.put_u8(Self::TYPE_RESPONSE); + if *is_succeeded { + buf.put_u8(Self::RESPONSE_SUCCEEDED); + } else { + buf.put_u8(Self::RESPONSE_FAILED); + } + } + Self::Authenticate(digest) => { + buf.put_u8(Self::TYPE_AUTHENTICATE); + buf.put_slice(digest); + } + Self::Connect { addr } => { + buf.put_u8(Self::TYPE_CONNECT); + addr.write_to_buf(buf); + } + Self::Packet { + assoc_id, + len, + addr, + } => { + buf.put_u8(Self::TYPE_PACKET); + buf.put_u32(*assoc_id); + buf.put_u16(*len); + addr.write_to_buf(buf); + } + Self::Dissociate { assoc_id } => { + buf.put_u8(Self::TYPE_DISSOCIATE); + buf.put_u32(*assoc_id); + } + Self::Heartbeat => { + buf.put_u8(Self::TYPE_HEARTBEAT); + } + } + } + + pub fn serialized_len(&self) -> usize { + 2 + match self { + Self::Response(_) => 1, + Self::Authenticate { .. } => 32, + Self::Connect { addr } => addr.serialized_len(), + Self::Packet { addr, .. } => 6 + addr.serialized_len(), + Self::Dissociate { .. } => 4, + Self::Heartbeat => 0, + } + } +} + +impl Address { + pub async fn read_from(stream: &mut R) -> Result + where + R: AsyncRead + Unpin, + { + let addr_type = stream.read_u8().await?; + + match addr_type { + Self::TYPE_DOMAIN => { + let len = stream.read_u8().await? as usize; + + let mut buf = vec![0; len + 2]; + stream.read_exact(&mut buf).await?; + + let port = ReadBytesExt::read_u16::(&mut &buf[len..]).unwrap(); + buf.truncate(len); + + let addr = String::from_utf8(buf)?; + + Ok(Self::DomainAddress(addr, port)) + } + Self::TYPE_IPV4 => { + let mut buf = [0; 6]; + stream.read_exact(&mut buf).await?; + let mut rdr = Cursor::new(buf); + + let addr = Ipv4Addr::new( + ReadBytesExt::read_u8(&mut rdr).unwrap(), + ReadBytesExt::read_u8(&mut rdr).unwrap(), + ReadBytesExt::read_u8(&mut rdr).unwrap(), + ReadBytesExt::read_u8(&mut rdr).unwrap(), + ); + + let port = ReadBytesExt::read_u16::(&mut rdr).unwrap(); + + Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) + } + Self::TYPE_IPV6 => { + let mut buf = [0; 18]; + stream.read_exact(&mut buf).await?; + let mut rdr = Cursor::new(buf); + + let addr = Ipv6Addr::new( + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ReadBytesExt::read_u16::(&mut rdr).unwrap(), + ); + + let port = ReadBytesExt::read_u16::(&mut rdr).unwrap(); + + Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) + } + _ => Err(Error::InvalidAddressType(addr_type)), + } + } + + pub async fn write_to(&self, writer: &mut W) -> Result<(), Error> + where + W: AsyncWrite + Unpin, + { + let mut buf = Vec::with_capacity(self.serialized_len()); + self.write_to_buf(&mut buf); + writer.write_all(&buf).await?; + Ok(()) + } + + pub fn write_to_buf(&self, buf: &mut B) { + match self { + Self::DomainAddress(addr, port) => { + buf.put_u8(Self::TYPE_DOMAIN); + buf.put_u8(addr.len() as u8); + buf.put_slice(addr.as_bytes()); + buf.put_u16(*port); + } + Self::SocketAddress(addr) => match addr { + SocketAddr::V4(addr) => { + buf.put_u8(Self::TYPE_IPV4); + buf.put_slice(&addr.ip().octets()); + buf.put_u16(addr.port()); + } + SocketAddr::V6(addr) => { + buf.put_u8(Self::TYPE_IPV6); + for seg in addr.ip().segments() { + buf.put_u16(seg); + } + buf.put_u16(addr.port()); + } + }, + } + } + + pub fn serialized_len(&self) -> usize { + 1 + match self { + Address::DomainAddress(addr, _) => 1 + addr.len() + 2, + Address::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => 6, + SocketAddr::V6(_) => 18, + }, + } + } +} + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error("unsupported TUIC version: {0:#x}")] + UnsupportedVersion(u8), + #[error("invalid response: {0:#x}")] + InvalidResponse(u8), + #[error("invalid command: {0:#x}")] + InvalidCommand(u8), + #[error("invalid address type: {0:#x}")] + InvalidAddressType(u8), + #[error("invalid address encoding: {0}")] + InvalidAddressEncoding(#[from] FromUtf8Error), +} + +impl From for IoError { + fn from(err: Error) -> Self { + let kind = match err { + Error::Io(err) => return err, + _ => IoErrorKind::Other, + }; + + Self::new(kind, err) + } +} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs new file mode 100644 index 0000000..be4631a --- /dev/null +++ b/src/protocol/mod.rs @@ -0,0 +1,97 @@ +//! The TUIC protocol + +#[cfg(feature = "protocol_marshaling_tokio")] +mod marshaling; + +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + net::SocketAddr, +}; + +pub const TUIC_PROTOCOL_VERSION: u8 = 0x04; + +#[cfg(feature = "protocol_marshaling_tokio")] +pub use self::marshaling::Error; + +/// Command +/// +/// ```plain +/// +-----+------+----------+ +/// | VER | TYPE | OPT | +/// +-----+------+----------+ +/// | 1 | 1 | Variable | +/// +-----+------+----------+ +/// ``` +#[non_exhaustive] +#[derive(Clone, Debug)] +pub enum Command { + Response(bool), + Authenticate([u8; 32]), + Connect { + addr: Address, + }, + Packet { + assoc_id: u32, + len: u16, + addr: Address, + }, + Dissociate { + assoc_id: u32, + }, + Heartbeat, +} + +impl Command { + pub const TYPE_RESPONSE: u8 = 0xff; + pub const TYPE_AUTHENTICATE: u8 = 0x00; + pub const TYPE_CONNECT: u8 = 0x01; + pub const TYPE_PACKET: u8 = 0x02; + pub const TYPE_DISSOCIATE: u8 = 0x03; + pub const TYPE_HEARTBEAT: u8 = 0x04; + + pub const RESPONSE_SUCCEEDED: u8 = 0x00; + pub const RESPONSE_FAILED: u8 = 0xff; + + pub const fn max_serialized_len() -> usize { + 2 + 6 + Address::max_serialized_len() + } +} + +/// Address +/// +/// ```plain +/// +------+----------+----------+ +/// | TYPE | ADDR | PORT | +/// +------+----------+----------+ +/// | 1 | Variable | 2 | +/// +------+----------+----------+ +/// ``` +/// +/// The address type can be one of the following: +/// 0x00: fully-qualified domain name (the first byte indicates the length of the domain name) +/// 0x01: IPv4 address +/// 0x02: IPv6 address +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Address { + DomainAddress(String, u16), + SocketAddress(SocketAddr), +} + +impl Address { + pub const TYPE_DOMAIN: u8 = 0x00; + pub const TYPE_IPV4: u8 = 0x01; + pub const TYPE_IPV6: u8 = 0x02; + + pub const fn max_serialized_len() -> usize { + 1 + 1 + u8::MAX as usize + 2 + } +} + +impl Display for Address { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), + Self::SocketAddress(addr) => write!(f, "{addr}"), + } + } +} From 7292c78c6c39d23123d77ebea33b7a3dbe552b0c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 31 Jul 2022 14:22:53 +0900 Subject: [PATCH 003/103] ignoring `.DS_Store` --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 96ef6c0..16d5636 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +.DS_Store From 0ad3ca1a48062b7f464157023a822de740b806a5 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 31 Jul 2022 17:56:33 +0900 Subject: [PATCH 004/103] adding client endpoint --- Cargo.toml | 9 +- src/client/connection.rs | 68 ++++++++++++ src/client/mod.rs | 216 +++++++++++++++++++++++++++++++++++++++ src/common/mod.rs | 3 + src/common/udp.rs | 5 + src/lib.rs | 15 +++ src/server.rs | 1 + 7 files changed, 313 insertions(+), 4 deletions(-) create mode 100644 src/client/connection.rs create mode 100644 src/client/mod.rs create mode 100644 src/common/mod.rs create mode 100644 src/common/udp.rs create mode 100644 src/server.rs diff --git a/Cargo.toml b/Cargo.toml index 4f96f1a..93cb60e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,9 +32,10 @@ bytes = { version = "1.2.*", default-features = false, optional = true } # log = { version = "0.4.*", features = ["serde", "std"] } # once_cell = { version = "1.13.*", features = ["parking_lot"] } # parking_lot = "0.12.*" -# quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false } +quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false, optional = true } +quinn-proto = { version = "0.8.*", default-features = false, optional = true } # rand = "0.8.*" -# rustls = { version = "0.20.*", features = ["quic"], default-features = false } +rustls = { version = "0.20.*", default-features = false, optional = true } # rustls-native-certs = "0.6.*" # rustls-pemfile = "1.0.*" # serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } @@ -54,8 +55,8 @@ all = ["protocol_marshaling_tokio", "server", "client"] protocol_marshaling_tokio = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["protocol_marshaling_tokio"] -client = ["protocol_marshaling_tokio"] +server = ["protocol_marshaling_tokio", "quinn", "rustls"] +client = ["protocol_marshaling_tokio", "quinn", "quinn-proto", "rustls", "thiserror", "tokio/net"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/connection.rs b/src/client/connection.rs new file mode 100644 index 0000000..30fdf8e --- /dev/null +++ b/src/client/connection.rs @@ -0,0 +1,68 @@ +use quinn::{ + Connecting as QuinnConnecting, Connection as QuinnConnection, Datagrams, IncomingUniStreams, +}; + +use crate::UdpRelayMode; + +#[derive(Debug)] +pub struct Connecting { + conn: QuinnConnecting, + token: [u8; 32], + udp_relay_mode: UdpRelayMode, +} + +impl Connecting { + pub(super) fn new( + conn: QuinnConnecting, + token: [u8; 32], + udp_relay_mode: UdpRelayMode, + ) -> Self { + Self { + conn, + token, + udp_relay_mode, + } + } +} + +#[derive(Debug)] +pub struct Connection { + conn: QuinnConnection, + token: [u8; 32], + udp_relay_mode: UdpRelayMode, +} + +impl Connection { + pub(super) fn new( + conn: QuinnConnection, + token: [u8; 32], + udp_relay_mode: UdpRelayMode, + ) -> Self { + Self { + conn, + token, + udp_relay_mode, + } + } +} + +#[derive(Debug)] +pub struct IncomingPackets { + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + udp_relay_mode: UdpRelayMode, +} + +impl IncomingPackets { + pub(super) fn new( + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + udp_relay_mode: UdpRelayMode, + ) -> Self { + Self { + uni_streams, + datagrams, + udp_relay_mode, + } + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs new file mode 100644 index 0000000..5e03059 --- /dev/null +++ b/src/client/mod.rs @@ -0,0 +1,216 @@ +mod connection; + +pub use self::connection::{Connecting, Connection, IncomingPackets}; + +use crate::UdpRelayMode; +use quinn::{ + congestion::ControllerFactory, ApplicationClose, ClientConfig as QuinnClientConfig, + ConnectError as QuinnConnectError, ConnectionClose, ConnectionError as QuinnConnectionError, + Endpoint, EndpointConfig, NewConnection as QuinnNewConnection, +}; +use quinn_proto::TransportError; +use rustls::{version, ClientConfig as RustlsClientConfig, RootCertStore}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + io::Result as IoResult, + net::{SocketAddr, ToSocketAddrs, UdpSocket}, + sync::Arc, +}; +use thiserror::Error; + +pub struct Client { + endpoint: Endpoint, + enable_0rtt: bool, + udp_relay_mode: UdpRelayMode, +} + +impl Client { + pub fn bind(cfg: ClientConfig, addr: impl ToSocketAddrs) -> IoResult + where + C: ControllerFactory + Send + Sync + 'static, + { + let socket = UdpSocket::bind(addr)?; + let (mut ep, _) = Endpoint::new(EndpointConfig::default(), None, socket)?; + + let mut crypto = RustlsClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_root_certificates(cfg.certs) + .with_no_client_auth(); + + crypto.alpn_protocols = cfg.alpn_protocols; + crypto.enable_early_data = cfg.enable_0rtt; + crypto.enable_sni = !cfg.disable_sni; + + let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); + + let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); + transport.congestion_controller_factory(cfg.congestion_controller); + transport.max_idle_timeout(None); + + ep.set_default_client_config(quinn_config); + + Ok(Self { + endpoint: ep, + udp_relay_mode: cfg.udp_relay_mode, + enable_0rtt: cfg.enable_0rtt, + }) + } + + pub fn reconfigure(&mut self, cfg: ClientConfig) + where + C: ControllerFactory + Send + Sync + 'static, + { + let mut crypto = RustlsClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_root_certificates(cfg.certs) + .with_no_client_auth(); + + crypto.alpn_protocols = cfg.alpn_protocols; + crypto.enable_early_data = cfg.enable_0rtt; + crypto.enable_sni = !cfg.disable_sni; + + let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); + + let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); + transport.congestion_controller_factory(cfg.congestion_controller); + transport.max_idle_timeout(None); + + self.endpoint.set_default_client_config(quinn_config); + + self.udp_relay_mode = cfg.udp_relay_mode; + self.enable_0rtt = cfg.enable_0rtt; + } + + pub fn rebind(&mut self, addr: impl ToSocketAddrs) -> IoResult<()> { + let socket = UdpSocket::bind(addr)?; + self.endpoint.rebind(socket) + } + + pub async fn connect( + &self, + addr: SocketAddr, + server_name: &str, + token: [u8; 32], + ) -> Result<(Connection, IncomingPackets), ConnectError> { + let conn = match self.endpoint.connect(addr, server_name) { + Ok(conn) => conn, + Err(err) => { + return Err(match err { + QuinnConnectError::UnsupportedVersion => ConnectError::UnsupportedQUICVersion, + QuinnConnectError::EndpointStopping => ConnectError::EndpointStopping, + QuinnConnectError::TooManyConnections => ConnectError::TooManyConnections, + QuinnConnectError::InvalidDnsName(err) => ConnectError::InvalidDomainName(err), + QuinnConnectError::InvalidRemoteAddress(err) => { + ConnectError::InvalidRemoteAddress(err) + } + QuinnConnectError::NoDefaultClientConfig => unreachable!(), + }) + } + }; + + let QuinnNewConnection { + connection, + datagrams, + uni_streams, + .. + } = if self.enable_0rtt { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => { + return Err(ConnectError::Convert0Rtt(Connecting::new( + conn, + token, + self.udp_relay_mode, + ))) + } + } + } else { + match conn.await { + Ok(conn) => conn, + Err(err) => { + return Err(match err { + QuinnConnectionError::VersionMismatch => { + ConnectError::UnsupportedQUICVersion + } + QuinnConnectionError::TransportError(err) => { + ConnectError::TransportError(err) + } + QuinnConnectionError::ConnectionClosed(err) => { + ConnectError::ConnectionClosed(err) + } + QuinnConnectionError::ApplicationClosed(err) => { + ConnectError::ApplicationClosed(err) + } + QuinnConnectionError::Reset => ConnectError::Reset, + QuinnConnectionError::TimedOut => ConnectError::TimedOut, + QuinnConnectionError::LocallyClosed => ConnectError::LocallyClosed, + }) + } + } + }; + + let conn = Connection::new(connection, token, self.udp_relay_mode); + let pkts = IncomingPackets::new(uni_streams, datagrams, self.udp_relay_mode); + + Ok((conn, pkts)) + } +} + +#[derive(Clone, Debug)] +pub struct ClientConfig { + pub certs: RootCertStore, + pub alpn_protocols: Vec>, + pub disable_sni: bool, + pub enable_0rtt: bool, + pub udp_relay_mode: UdpRelayMode, + pub congestion_controller: C, +} + +#[derive(Clone, Debug)] +pub enum ServerAddr { + SocketAddr { addr: SocketAddr, name: String }, + DomainAddr { domain: String, port: u16 }, +} + +impl Display for ServerAddr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + ServerAddr::SocketAddr { addr, name } => write!(f, "{addr} ({name})"), + ServerAddr::DomainAddr { domain, port } => write!(f, "{domain}:{port}"), + } + } +} + +#[derive(Error, Debug)] +pub enum ConnectError { + #[error("failed to convert QUIC connection into 0-RTT")] + Convert0Rtt(Connecting), + #[error("unsupported QUIC version")] + UnsupportedQUICVersion, + #[error("endpoint stopping")] + EndpointStopping, + #[error("too many connections")] + TooManyConnections, + #[error("invalid domain name: {0}")] + InvalidDomainName(String), + #[error("invalid remote address: {0}")] + InvalidRemoteAddress(SocketAddr), + #[error(transparent)] + TransportError(#[from] TransportError), + #[error("aborted by peer: {0}")] + ConnectionClosed(ConnectionClose), + #[error("closed by peer: {0}")] + ApplicationClosed(ApplicationClose), + #[error("reset by peer")] + Reset, + #[error("timed out")] + TimedOut, + #[error("closed")] + LocallyClosed, +} diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..88e0d83 --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,3 @@ +mod udp; + +pub use self::udp::UdpRelayMode; diff --git a/src/common/udp.rs b/src/common/udp.rs new file mode 100644 index 0000000..331c133 --- /dev/null +++ b/src/common/udp.rs @@ -0,0 +1,5 @@ +#[derive(Clone, Copy, Debug)] +pub enum UdpRelayMode { + Native, + Quic, +} diff --git a/src/lib.rs b/src/lib.rs index 1b800ec..62da5ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,16 @@ pub mod protocol; + +#[cfg(any(feature = "server", feature = "client"))] +mod common; + +#[cfg(feature = "server")] +mod server; + +#[cfg(feature = "client")] +pub mod client; + +#[cfg(any(feature = "server", feature = "client"))] +pub use crate::common::UdpRelayMode; + +#[cfg(feature = "client")] +pub use crate::client::{Client, ClientConfig, ServerAddr}; diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/server.rs @@ -0,0 +1 @@ + From 8671e749149ae5c45931b1bd8c5640bcd53083ed Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 31 Jul 2022 19:56:13 +0900 Subject: [PATCH 005/103] adding `CONNECT` command support in client conn --- src/client/connection.rs | 129 ++++++++++++++++++++++++++++++------ src/client/mod.rs | 96 ++++++++++++--------------- src/client/stream.rs | 137 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- src/protocol/mod.rs | 11 ++++ 5 files changed, 302 insertions(+), 73 deletions(-) create mode 100644 src/client/stream.rs diff --git a/src/client/connection.rs b/src/client/connection.rs index 30fdf8e..7befe0a 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -1,8 +1,19 @@ +use super::{ + stream::{RecvStream, SendStream, StreamReg}, + ConnectError, Stream, +}; +use crate::{ + protocol::{Address, Command, Error as TuicError}, + UdpRelayMode, +}; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, Datagrams, IncomingUniStreams, + NewConnection as QuinnNewConnection, +}; +use std::{ + io::{Error as IoError, Result as IoResult}, + sync::Arc, }; - -use crate::UdpRelayMode; #[derive(Debug)] pub struct Connecting { @@ -23,6 +34,26 @@ impl Connecting { udp_relay_mode, } } + + pub async fn establish(self) -> Result<(Connection, IncomingPackets), ConnectError> { + let QuinnNewConnection { + connection, + datagrams, + uni_streams, + .. + } = match self.conn.await { + Ok(conn) => conn, + Err(err) => return Err(ConnectError::from_quinn_connection_error(err)), + }; + + Ok(Connection::new( + connection, + uni_streams, + datagrams, + self.token, + self.udp_relay_mode, + )) + } } #[derive(Debug)] @@ -30,19 +61,92 @@ pub struct Connection { conn: QuinnConnection, token: [u8; 32], udp_relay_mode: UdpRelayMode, + stream_reg: Arc, } impl Connection { pub(super) fn new( conn: QuinnConnection, + uni_streams: IncomingUniStreams, + datagrams: Datagrams, token: [u8; 32], udp_relay_mode: UdpRelayMode, - ) -> Self { - Self { + ) -> (Self, IncomingPackets) { + let stream_reg = Arc::new(Arc::new(())); + + let conn = Self { conn, token, udp_relay_mode, - } + stream_reg: stream_reg.clone(), + }; + + let incoming = IncomingPackets { + uni_streams, + datagrams, + udp_relay_mode, + stream_reg, + }; + + (conn, incoming) + } + + pub async fn authenticate(&self) -> IoResult<()> { + let mut send = self.get_send_stream().await?; + let cmd = Command::Authenticate(self.token); + cmd.write_to(&mut send).await?; + send.finish().await?; + Ok(()) + } + + pub async fn heartbeat(&self) -> IoResult<()> { + let mut send = self.get_send_stream().await?; + let cmd = Command::Heartbeat; + cmd.write_to(&mut send).await?; + send.finish().await?; + Ok(()) + } + + pub async fn connect(&self, addr: Address) -> IoResult> { + let mut stream = self.get_bi_stream().await?; + + let cmd = Command::Connect { addr }; + cmd.write_to(&mut stream).await?; + + let resp = match Command::read_from(&mut stream).await { + Ok(Command::Response(resp)) => Ok(resp), + Ok(cmd) => Err(TuicError::InvalidCommand(cmd.type_code())), + Err(err) => Err(err), + }; + + let res = match resp { + Ok(true) => return Ok(Some(stream)), + Ok(false) => Ok(None), + Err(err) => Err(IoError::from(err)), + }; + + stream.finish().await?; + res + } + + pub async fn packet(&self) -> IoResult<()> { + todo!() + } + + pub async fn dissociate(&self) -> IoResult<()> { + todo!() + } + + async fn get_send_stream(&self) -> IoResult { + let send = self.conn.open_uni().await?; + Ok(SendStream::new(send, self.stream_reg.as_ref().clone())) + } + + async fn get_bi_stream(&self) -> IoResult { + let (send, recv) = self.conn.open_bi().await?; + let send = SendStream::new(send, self.stream_reg.as_ref().clone()); + let recv = RecvStream::new(recv, self.stream_reg.as_ref().clone()); + Ok(Stream::new(send, recv)) } } @@ -51,18 +155,5 @@ pub struct IncomingPackets { uni_streams: IncomingUniStreams, datagrams: Datagrams, udp_relay_mode: UdpRelayMode, -} - -impl IncomingPackets { - pub(super) fn new( - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - udp_relay_mode: UdpRelayMode, - ) -> Self { - Self { - uni_streams, - datagrams, - udp_relay_mode, - } - } + stream_reg: Arc, } diff --git a/src/client/mod.rs b/src/client/mod.rs index 5e03059..063b604 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,6 +1,10 @@ mod connection; +mod stream; -pub use self::connection::{Connecting, Connection, IncomingPackets}; +pub use self::{ + connection::{Connecting, Connection, IncomingPackets}, + stream::Stream, +}; use crate::UdpRelayMode; use quinn::{ @@ -11,7 +15,6 @@ use quinn::{ use quinn_proto::TransportError; use rustls::{version, ClientConfig as RustlsClientConfig, RootCertStore}; use std::{ - fmt::{Display, Formatter, Result as FmtResult}, io::Result as IoResult, net::{SocketAddr, ToSocketAddrs, UdpSocket}, sync::Arc, @@ -100,18 +103,7 @@ impl Client { ) -> Result<(Connection, IncomingPackets), ConnectError> { let conn = match self.endpoint.connect(addr, server_name) { Ok(conn) => conn, - Err(err) => { - return Err(match err { - QuinnConnectError::UnsupportedVersion => ConnectError::UnsupportedQUICVersion, - QuinnConnectError::EndpointStopping => ConnectError::EndpointStopping, - QuinnConnectError::TooManyConnections => ConnectError::TooManyConnections, - QuinnConnectError::InvalidDnsName(err) => ConnectError::InvalidDomainName(err), - QuinnConnectError::InvalidRemoteAddress(err) => { - ConnectError::InvalidRemoteAddress(err) - } - QuinnConnectError::NoDefaultClientConfig => unreachable!(), - }) - } + Err(err) => return Err(ConnectError::from_quinn_connect_error(err)), }; let QuinnNewConnection { @@ -133,37 +125,25 @@ impl Client { } else { match conn.await { Ok(conn) => conn, - Err(err) => { - return Err(match err { - QuinnConnectionError::VersionMismatch => { - ConnectError::UnsupportedQUICVersion - } - QuinnConnectionError::TransportError(err) => { - ConnectError::TransportError(err) - } - QuinnConnectionError::ConnectionClosed(err) => { - ConnectError::ConnectionClosed(err) - } - QuinnConnectionError::ApplicationClosed(err) => { - ConnectError::ApplicationClosed(err) - } - QuinnConnectionError::Reset => ConnectError::Reset, - QuinnConnectionError::TimedOut => ConnectError::TimedOut, - QuinnConnectionError::LocallyClosed => ConnectError::LocallyClosed, - }) - } + Err(err) => return Err(ConnectError::from_quinn_connection_error(err)), } }; - let conn = Connection::new(connection, token, self.udp_relay_mode); - let pkts = IncomingPackets::new(uni_streams, datagrams, self.udp_relay_mode); - - Ok((conn, pkts)) + Ok(Connection::new( + connection, + uni_streams, + datagrams, + token, + self.udp_relay_mode, + )) } } #[derive(Clone, Debug)] -pub struct ClientConfig { +pub struct ClientConfig +where + C: ControllerFactory + Send + Sync + 'static, +{ pub certs: RootCertStore, pub alpn_protocols: Vec>, pub disable_sni: bool, @@ -172,21 +152,6 @@ pub struct ClientConfig { pub congestion_controller: C, } -#[derive(Clone, Debug)] -pub enum ServerAddr { - SocketAddr { addr: SocketAddr, name: String }, - DomainAddr { domain: String, port: u16 }, -} - -impl Display for ServerAddr { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - ServerAddr::SocketAddr { addr, name } => write!(f, "{addr} ({name})"), - ServerAddr::DomainAddr { domain, port } => write!(f, "{domain}:{port}"), - } - } -} - #[derive(Error, Debug)] pub enum ConnectError { #[error("failed to convert QUIC connection into 0-RTT")] @@ -214,3 +179,28 @@ pub enum ConnectError { #[error("closed")] LocallyClosed, } + +impl ConnectError { + fn from_quinn_connect_error(err: QuinnConnectError) -> Self { + match err { + QuinnConnectError::UnsupportedVersion => Self::UnsupportedQUICVersion, + QuinnConnectError::EndpointStopping => Self::EndpointStopping, + QuinnConnectError::TooManyConnections => Self::TooManyConnections, + QuinnConnectError::InvalidDnsName(err) => Self::InvalidDomainName(err), + QuinnConnectError::InvalidRemoteAddress(err) => Self::InvalidRemoteAddress(err), + QuinnConnectError::NoDefaultClientConfig => unreachable!(), + } + } + + fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { + match err { + QuinnConnectionError::VersionMismatch => Self::UnsupportedQUICVersion, + QuinnConnectionError::TransportError(err) => Self::TransportError(err), + QuinnConnectionError::ConnectionClosed(err) => Self::ConnectionClosed(err), + QuinnConnectionError::ApplicationClosed(err) => Self::ApplicationClosed(err), + QuinnConnectionError::Reset => Self::Reset, + QuinnConnectionError::TimedOut => Self::TimedOut, + QuinnConnectionError::LocallyClosed => Self::LocallyClosed, + } + } +} diff --git a/src/client/stream.rs b/src/client/stream.rs new file mode 100644 index 0000000..35fa9c6 --- /dev/null +++ b/src/client/stream.rs @@ -0,0 +1,137 @@ +use quinn::{RecvStream as QuinnRecvStream, SendStream as QuinnSendStream}; +use std::{ + io::{IoSlice, Result}, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(super) type StreamReg = Arc<()>; + +pub(super) struct SendStream(QuinnSendStream, StreamReg); + +impl SendStream { + pub(super) fn new(send: QuinnSendStream, reg: StreamReg) -> Self { + Self(send, reg) + } + + #[inline] + pub async fn finish(&mut self) -> Result<()> { + self.0.finish().await?; + Ok(()) + } +} + +pub(super) struct RecvStream(QuinnRecvStream, StreamReg); + +impl RecvStream { + pub(super) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { + Self(recv, reg) + } +} + +pub struct Stream(SendStream, RecvStream); + +impl Stream { + pub(super) fn new(send: SendStream, recv: RecvStream) -> Self { + Self(send, recv) + } + + #[inline] + pub async fn finish(&mut self) -> Result<()> { + self.0.finish().await + } +} + +impl AsyncWrite for SendStream { + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl AsyncRead for RecvStream { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } +} + +impl AsyncWrite for Stream { + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.0).poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } + + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} + +impl AsyncRead for Stream { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.1).poll_read(cx, buf) + } +} diff --git a/src/lib.rs b/src/lib.rs index 62da5ec..fb6f47c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,4 +13,4 @@ pub mod client; pub use crate::common::UdpRelayMode; #[cfg(feature = "client")] -pub use crate::client::{Client, ClientConfig, ServerAddr}; +pub use crate::client::{Client, ClientConfig}; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index be4631a..402b5f4 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -52,6 +52,17 @@ impl Command { pub const RESPONSE_SUCCEEDED: u8 = 0x00; pub const RESPONSE_FAILED: u8 = 0xff; + pub const fn type_code(&self) -> u8 { + match self { + Command::Response(_) => Self::TYPE_RESPONSE, + Command::Authenticate(_) => Self::TYPE_AUTHENTICATE, + Command::Connect { .. } => Self::TYPE_CONNECT, + Command::Packet { .. } => Self::TYPE_PACKET, + Command::Dissociate { .. } => Self::TYPE_DISSOCIATE, + Command::Heartbeat => Self::TYPE_HEARTBEAT, + } + } + pub const fn max_serialized_len() -> usize { 2 + 6 + Address::max_serialized_len() } From ed6dea83d3d1961fb81bb6b823ae679a6d389140 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 1 Aug 2022 00:31:47 +0900 Subject: [PATCH 006/103] support UDP fragmentation in protocol --- Cargo.toml | 8 +++---- src/protocol/marshaling.rs | 16 ++++++++++++-- src/protocol/mod.rs | 44 +++++++++++++++++++++++++++++++++++--- 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 93cb60e..dfe24e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,12 +51,12 @@ tokio = { version = "1.20.*", default-features = false, optional = true } [features] default = [] -all = ["protocol_marshaling_tokio", "server", "client"] +all = ["protocol_marshaling", "server", "client"] -protocol_marshaling_tokio = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] +protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["protocol_marshaling_tokio", "quinn", "rustls"] -client = ["protocol_marshaling_tokio", "quinn", "quinn-proto", "rustls", "thiserror", "tokio/net"] +server = ["protocol_marshaling", "quinn", "rustls"] +client = ["protocol_marshaling", "quinn", "quinn-proto", "rustls", "thiserror"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs index a8eb39b..764684b 100644 --- a/src/protocol/marshaling.rs +++ b/src/protocol/marshaling.rs @@ -41,16 +41,22 @@ impl Command { Ok(Self::Connect { addr }) } Self::TYPE_PACKET => { - let mut buf = [0; 6]; + let mut buf = [0; 12]; r.read_exact(&mut buf).await?; let mut rdr = Cursor::new(buf); let assoc_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); + let pkt_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); + let frag_total = ReadBytesExt::read_u8(&mut rdr).unwrap(); + let frag_id = ReadBytesExt::read_u8(&mut rdr).unwrap(); let len = ReadBytesExt::read_u16::(&mut rdr).unwrap(); let addr = Address::read_from(r).await?; Ok(Self::Packet { assoc_id, + pkt_id, + frag_total, + frag_id, len, addr, }) @@ -96,11 +102,17 @@ impl Command { } Self::Packet { assoc_id, + pkt_id, + frag_total, + frag_id, len, addr, } => { buf.put_u8(Self::TYPE_PACKET); buf.put_u32(*assoc_id); + buf.put_u32(*pkt_id); + buf.put_u8(*frag_total); + buf.put_u8(*frag_id); buf.put_u16(*len); addr.write_to_buf(buf); } @@ -119,7 +131,7 @@ impl Command { Self::Response(_) => 1, Self::Authenticate { .. } => 32, Self::Connect { addr } => addr.serialized_len(), - Self::Packet { addr, .. } => 6 + addr.serialized_len(), + Self::Packet { addr, .. } => 12 + addr.serialized_len(), Self::Dissociate { .. } => 4, Self::Heartbeat => 0, } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 402b5f4..ef49389 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -1,6 +1,6 @@ //! The TUIC protocol -#[cfg(feature = "protocol_marshaling_tokio")] +#[cfg(feature = "protocol_marshaling")] mod marshaling; use std::{ @@ -10,7 +10,7 @@ use std::{ pub const TUIC_PROTOCOL_VERSION: u8 = 0x04; -#[cfg(feature = "protocol_marshaling_tokio")] +#[cfg(feature = "protocol_marshaling")] pub use self::marshaling::Error; /// Command @@ -25,19 +25,57 @@ pub use self::marshaling::Error; #[non_exhaustive] #[derive(Clone, Debug)] pub enum Command { + // +-----+ + // | REP | + // +-----+ + // | 1 | + // +-----+ Response(bool), + + // +-----+ + // | TKN | + // +-----+ + // | 32 | + // +-----+ Authenticate([u8; 32]), + + // +----------+ + // | ADDR | + // +----------+ + // | Variable | + // +----------+ Connect { addr: Address, }, + + // +----------+--------+------------+---------+-----+----------+ + // | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | LEN | ADDR | + // +----------+--------+------------+---------+-----+----------+ + // | 4 | 4 | 1 | 1 | 2 | Variable | + // +----------+--------+------------+---------+-----+----------+ Packet { assoc_id: u32, + pkt_id: u32, + frag_total: u8, + frag_id: u8, len: u16, addr: Address, }, + + // +----------+ + // | ASSOC_ID | + // +----------+ + // | 4 | + // +----------+ Dissociate { assoc_id: u32, }, + + // +-+ + // | | + // +-+ + // | | + // +-+ Heartbeat, } @@ -64,7 +102,7 @@ impl Command { } pub const fn max_serialized_len() -> usize { - 2 + 6 + Address::max_serialized_len() + 2 + 12 + Address::max_serialized_len() } } From b55c995c80628eb90500552644ce9dbd123f6b0d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 1 Aug 2022 00:32:10 +0900 Subject: [PATCH 007/103] bump protocol version --- src/protocol/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index ef49389..f9293ca 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -8,7 +8,7 @@ use std::{ net::SocketAddr, }; -pub const TUIC_PROTOCOL_VERSION: u8 = 0x04; +pub const TUIC_PROTOCOL_VERSION: u8 = 0x05; #[cfg(feature = "protocol_marshaling")] pub use self::marshaling::Error; From cb9049531b8b03b85ad2e3034a6ee3c459db279a Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 1 Aug 2022 01:53:06 +0900 Subject: [PATCH 008/103] support sending packet through uni_streams --- src/client/connection.rs | 75 +++++++++++++++++++++++++++++++++++--- src/protocol/marshaling.rs | 18 ++++++--- src/protocol/mod.rs | 8 ++-- 3 files changed, 86 insertions(+), 15 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index 7befe0a..1f259d3 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -6,14 +6,19 @@ use crate::{ protocol::{Address, Command, Error as TuicError}, UdpRelayMode, }; +use bytes::Bytes; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, Datagrams, IncomingUniStreams, NewConnection as QuinnNewConnection, }; use std::{ - io::{Error as IoError, Result as IoResult}, - sync::Arc, + io::{Error as IoError, ErrorKind, Result as IoResult}, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, + }, }; +use tokio::io::AsyncWriteExt; #[derive(Debug)] pub struct Connecting { @@ -62,6 +67,7 @@ pub struct Connection { token: [u8; 32], udp_relay_mode: UdpRelayMode, stream_reg: Arc, + next_pkt_id: Arc, } impl Connection { @@ -79,6 +85,7 @@ impl Connection { token, udp_relay_mode, stream_reg: stream_reg.clone(), + next_pkt_id: Arc::new(AtomicU16::new(0)), }; let incoming = IncomingPackets { @@ -129,12 +136,19 @@ impl Connection { res } - pub async fn packet(&self) -> IoResult<()> { - todo!() + pub async fn send_packet(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { + match self.udp_relay_mode { + UdpRelayMode::Native => self.send_packet_to_datagram(assoc_id, addr, pkt), + UdpRelayMode::Quic => self.send_packet_to_uni_stream(assoc_id, addr, pkt).await, + } } - pub async fn dissociate(&self) -> IoResult<()> { - todo!() + pub async fn dissociate(&self, assoc_id: u32) -> IoResult<()> { + let mut send = self.get_send_stream().await?; + let cmd = Command::Dissociate { assoc_id }; + cmd.write_to(&mut send).await?; + send.finish().await?; + Ok(()) } async fn get_send_stream(&self) -> IoResult { @@ -148,6 +162,55 @@ impl Connection { let recv = RecvStream::new(recv, self.stream_reg.as_ref().clone()); Ok(Stream::new(send, recv)) } + + fn send_packet_to_datagram(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { + let max_dg_size = if let Some(size) = self.conn.max_datagram_size() { + size + } else { + return Err(IoError::new(ErrorKind::Other, "datagram not supported")); + }; + + let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); + + let header_without_addr_len = Command::Packet { + assoc_id, + pkt_id, + frag_total: 0, + frag_id: 0, + len: 0, + addr: None, + } + .serialized_len(); + + let first_frag_len = max_dg_size - header_without_addr_len - addr.serialized_len(); + + todo!() + } + + async fn send_packet_to_uni_stream( + &self, + assoc_id: u32, + addr: Address, + pkt: Bytes, + ) -> IoResult<()> { + let mut send = self.get_send_stream().await?; + + let cmd = Command::Packet { + assoc_id, + pkt_id: self.next_pkt_id.fetch_add(1, Ordering::SeqCst), + frag_total: 1, + frag_id: 0, + len: pkt.len() as u16, + addr: Some(addr), + }; + + cmd.write_to(&mut send).await?; + send.write_all(&pkt).await?; + + send.finish().await?; + + Ok(()) + } } #[derive(Debug)] diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs index 764684b..f2edfab 100644 --- a/src/protocol/marshaling.rs +++ b/src/protocol/marshaling.rs @@ -46,11 +46,16 @@ impl Command { let mut rdr = Cursor::new(buf); let assoc_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); - let pkt_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); + let pkt_id = ReadBytesExt::read_u16::(&mut rdr).unwrap(); let frag_total = ReadBytesExt::read_u8(&mut rdr).unwrap(); let frag_id = ReadBytesExt::read_u8(&mut rdr).unwrap(); let len = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - let addr = Address::read_from(r).await?; + + let addr = if frag_id == 0 { + Some(Address::read_from(r).await?) + } else { + None + }; Ok(Self::Packet { assoc_id, @@ -110,11 +115,14 @@ impl Command { } => { buf.put_u8(Self::TYPE_PACKET); buf.put_u32(*assoc_id); - buf.put_u32(*pkt_id); + buf.put_u16(*pkt_id); buf.put_u8(*frag_total); buf.put_u8(*frag_id); buf.put_u16(*len); - addr.write_to_buf(buf); + + if *frag_id == 0 { + addr.as_ref().unwrap().write_to_buf(buf); + } } Self::Dissociate { assoc_id } => { buf.put_u8(Self::TYPE_DISSOCIATE); @@ -131,7 +139,7 @@ impl Command { Self::Response(_) => 1, Self::Authenticate { .. } => 32, Self::Connect { addr } => addr.serialized_len(), - Self::Packet { addr, .. } => 12 + addr.serialized_len(), + Self::Packet { addr, .. } => 10 + addr.as_ref().map_or(0, |addr| addr.serialized_len()), Self::Dissociate { .. } => 4, Self::Heartbeat => 0, } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index f9293ca..dc0eb77 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -51,15 +51,15 @@ pub enum Command { // +----------+--------+------------+---------+-----+----------+ // | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | LEN | ADDR | // +----------+--------+------------+---------+-----+----------+ - // | 4 | 4 | 1 | 1 | 2 | Variable | + // | 4 | 2 | 1 | 1 | 2 | Variable | // +----------+--------+------------+---------+-----+----------+ Packet { assoc_id: u32, - pkt_id: u32, + pkt_id: u16, frag_total: u8, frag_id: u8, len: u16, - addr: Address, + addr: Option
, }, // +----------+ @@ -102,7 +102,7 @@ impl Command { } pub const fn max_serialized_len() -> usize { - 2 + 12 + Address::max_serialized_len() + 2 + 10 + Address::max_serialized_len() } } From d66c360567e82f6ec19eb3e021d7bcae975c3e50 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 02:45:45 +0900 Subject: [PATCH 009/103] support sending packet through datagram --- src/client/connection.rs | 47 ++++++++++++++++++++------ src/common/mod.rs | 4 +-- src/common/udp.rs | 68 ++++++++++++++++++++++++++++++++++++++ src/lib.rs | 2 +- src/protocol/marshaling.rs | 21 ------------ src/protocol/mod.rs | 21 +++++++++--- 6 files changed, 123 insertions(+), 40 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index 1f259d3..8b12dbd 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -4,9 +4,9 @@ use super::{ }; use crate::{ protocol::{Address, Command, Error as TuicError}, - UdpRelayMode, + udp, UdpRelayMode, }; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, Datagrams, IncomingUniStreams, NewConnection as QuinnNewConnection, @@ -164,27 +164,52 @@ impl Connection { } fn send_packet_to_datagram(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { - let max_dg_size = if let Some(size) = self.conn.max_datagram_size() { + let max_datagram_size = if let Some(size) = self.conn.max_datagram_size() { size } else { return Err(IoError::new(ErrorKind::Other, "datagram not supported")); }; let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); + let mut pkts = udp::split_packet(pkt, &addr, max_datagram_size); + let frag_total = pkts.len() as u8; - let header_without_addr_len = Command::Packet { + let first_pkt = pkts.next().unwrap(); + let first_pkt_header = Command::Packet { assoc_id, pkt_id, - frag_total: 0, + frag_total, frag_id: 0, - len: 0, - addr: None, + len: first_pkt.len() as u16, + addr: Some(addr), + }; + + let mut buf = BytesMut::with_capacity(first_pkt_header.serialized_len() + first_pkt.len()); + first_pkt_header.write_to_buf(&mut buf); + buf.extend_from_slice(&first_pkt); + let buf = buf.freeze(); + + self.conn.send_datagram(buf).unwrap(); // TODO: error handling + + for (id, pkt) in pkts.enumerate() { + let pkt_header = Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id: id as u8 + 1, + len: pkt.len() as u16, + addr: None, + }; + + let mut buf = BytesMut::with_capacity(pkt_header.serialized_len() + pkt.len()); + pkt_header.write_to_buf(&mut buf); + buf.extend_from_slice(&pkt); + let buf = buf.freeze(); + + self.conn.send_datagram(buf).unwrap(); // TODO: error handling } - .serialized_len(); - let first_frag_len = max_dg_size - header_without_addr_len - addr.serialized_len(); - - todo!() + Ok(()) } async fn send_packet_to_uni_stream( diff --git a/src/common/mod.rs b/src/common/mod.rs index 88e0d83..7e5aaa1 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,3 +1 @@ -mod udp; - -pub use self::udp::UdpRelayMode; +pub mod udp; diff --git a/src/common/udp.rs b/src/common/udp.rs index 331c133..61065e5 100644 --- a/src/common/udp.rs +++ b/src/common/udp.rs @@ -1,5 +1,73 @@ +use crate::protocol::{Address, Command}; +use bytes::Bytes; + #[derive(Clone, Copy, Debug)] pub enum UdpRelayMode { Native, Quic, } + +pub fn split_packet(pkt: Bytes, addr: &Address, max_datagram_len: usize) -> SplitPacket { + SplitPacket::new(pkt, addr, max_datagram_len) +} + +#[derive(Debug)] +pub struct SplitPacket { + pkt: Bytes, + max_pkt_size: usize, + start: usize, + end: usize, + len: usize, +} + +impl SplitPacket { + #[inline] + fn new(pkt: Bytes, addr: &Address, max_datagram_size: usize) -> Self { + const DEFAULT_HEADER: Command = Command::Packet { + assoc_id: 0, + pkt_id: 0, + frag_total: 0, + frag_id: 0, + len: 0, + addr: None, + }; + + let first_pkt_size = + max_datagram_size - DEFAULT_HEADER.serialized_len() - addr.serialized_len(); + let max_pkt_size = max_datagram_size - DEFAULT_HEADER.serialized_len(); + let len = if first_pkt_size > pkt.len() { + 1 + (pkt.len() - first_pkt_size) / max_pkt_size + 1 + } else { + 1 + }; + + Self { + pkt, + max_pkt_size, + start: 0, + end: first_pkt_size, + len, + } + } +} + +impl Iterator for SplitPacket { + type Item = Bytes; + + fn next(&mut self) -> Option { + if self.start <= self.pkt.len() { + let next = self.pkt.slice(self.start..self.end.min(self.pkt.len())); + self.start += self.max_pkt_size; + self.end += self.max_pkt_size; + Some(next) + } else { + None + } + } +} + +impl ExactSizeIterator for SplitPacket { + fn len(&self) -> usize { + self.len + } +} diff --git a/src/lib.rs b/src/lib.rs index fb6f47c..952f68f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::UdpRelayMode; +pub use crate::common::udp::{self, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig}; diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs index f2edfab..c86b12e 100644 --- a/src/protocol/marshaling.rs +++ b/src/protocol/marshaling.rs @@ -133,17 +133,6 @@ impl Command { } } } - - pub fn serialized_len(&self) -> usize { - 2 + match self { - Self::Response(_) => 1, - Self::Authenticate { .. } => 32, - Self::Connect { addr } => addr.serialized_len(), - Self::Packet { addr, .. } => 10 + addr.as_ref().map_or(0, |addr| addr.serialized_len()), - Self::Dissociate { .. } => 4, - Self::Heartbeat => 0, - } - } } impl Address { @@ -241,16 +230,6 @@ impl Address { }, } } - - pub fn serialized_len(&self) -> usize { - 1 + match self { - Address::DomainAddress(addr, _) => 1 + addr.len() + 2, - Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 6, - SocketAddr::V6(_) => 18, - }, - } - } } #[derive(Error, Debug)] diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index dc0eb77..8e6eb49 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -101,8 +101,15 @@ impl Command { } } - pub const fn max_serialized_len() -> usize { - 2 + 10 + Address::max_serialized_len() + pub fn serialized_len(&self) -> usize { + 2 + match self { + Self::Response(_) => 1, + Self::Authenticate { .. } => 32, + Self::Connect { addr } => addr.serialized_len(), + Self::Packet { addr, .. } => 10 + addr.as_ref().map_or(0, |addr| addr.serialized_len()), + Self::Dissociate { .. } => 4, + Self::Heartbeat => 0, + } } } @@ -131,8 +138,14 @@ impl Address { pub const TYPE_IPV4: u8 = 0x01; pub const TYPE_IPV6: u8 = 0x02; - pub const fn max_serialized_len() -> usize { - 1 + 1 + u8::MAX as usize + 2 + pub fn serialized_len(&self) -> usize { + 1 + match self { + Address::DomainAddress(addr, _) => 1 + addr.len() + 2, + Address::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => 6, + SocketAddr::V6(_) => 18, + }, + } } } From f6ecd5edf2245caed3f25317b25e829cea97e79a Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 02:49:49 +0900 Subject: [PATCH 010/103] optimizing `SplitPacket` --- src/common/udp.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/common/udp.rs b/src/common/udp.rs index 61065e5..e6d45e7 100644 --- a/src/common/udp.rs +++ b/src/common/udp.rs @@ -15,8 +15,8 @@ pub fn split_packet(pkt: Bytes, addr: &Address, max_datagram_len: usize) -> Spli pub struct SplitPacket { pkt: Bytes, max_pkt_size: usize, - start: usize, - end: usize, + next_start: usize, + next_end: usize, len: usize, } @@ -44,8 +44,8 @@ impl SplitPacket { Self { pkt, max_pkt_size, - start: 0, - end: first_pkt_size, + next_start: 0, + next_end: first_pkt_size, len, } } @@ -55,10 +55,14 @@ impl Iterator for SplitPacket { type Item = Bytes; fn next(&mut self) -> Option { - if self.start <= self.pkt.len() { - let next = self.pkt.slice(self.start..self.end.min(self.pkt.len())); - self.start += self.max_pkt_size; - self.end += self.max_pkt_size; + if self.next_start <= self.pkt.len() { + let next = self + .pkt + .slice(self.next_start..self.next_end.min(self.pkt.len())); + + self.next_start += self.max_pkt_size; + self.next_end += self.max_pkt_size; + Some(next) } else { None @@ -67,6 +71,7 @@ impl Iterator for SplitPacket { } impl ExactSizeIterator for SplitPacket { + #[inline] fn len(&self) -> usize { self.len } From c64af47e0705a6ee7b5473555254099259ce6243 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 07:27:23 +0900 Subject: [PATCH 011/103] add `PacketBuffer` for reassembling packets --- Cargo.toml | 2 +- src/client/connection.rs | 7 +- src/common.rs | 199 +++++++++++++++++++++++++++++++++++++++ src/common/mod.rs | 1 - src/common/udp.rs | 78 --------------- src/lib.rs | 2 +- 6 files changed, 206 insertions(+), 83 deletions(-) create mode 100644 src/common.rs delete mode 100644 src/common/mod.rs delete mode 100644 src/common/udp.rs diff --git a/Cargo.toml b/Cargo.toml index dfe24e4..de914ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,7 +56,7 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] server = ["protocol_marshaling", "quinn", "rustls"] -client = ["protocol_marshaling", "quinn", "quinn-proto", "rustls", "thiserror"] +client = ["protocol_marshaling", "quinn", "quinn-proto", "rustls", "thiserror", "tokio/io-util"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/connection.rs b/src/client/connection.rs index 8b12dbd..ea9ad1b 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -3,8 +3,9 @@ use super::{ ConnectError, Stream, }; use crate::{ + common, protocol::{Address, Command, Error as TuicError}, - udp, UdpRelayMode, + PacketBuffer, UdpRelayMode, }; use bytes::{Bytes, BytesMut}; use quinn::{ @@ -93,6 +94,7 @@ impl Connection { datagrams, udp_relay_mode, stream_reg, + pkt_buf: PacketBuffer::new(), }; (conn, incoming) @@ -171,7 +173,7 @@ impl Connection { }; let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); - let mut pkts = udp::split_packet(pkt, &addr, max_datagram_size); + let mut pkts = common::split_packet(pkt, &addr, max_datagram_size); let frag_total = pkts.len() as u8; let first_pkt = pkts.next().unwrap(); @@ -244,4 +246,5 @@ pub struct IncomingPackets { datagrams: Datagrams, udp_relay_mode: UdpRelayMode, stream_reg: Arc, + pkt_buf: PacketBuffer, } diff --git a/src/common.rs b/src/common.rs new file mode 100644 index 0000000..988a3bc --- /dev/null +++ b/src/common.rs @@ -0,0 +1,199 @@ +use crate::protocol::{Address, Command}; +use bytes::{Bytes, BytesMut}; +use std::{ + collections::{hash_map::Entry, HashMap}, + time::{Duration, Instant}, +}; +use thiserror::Error; + +#[derive(Clone, Copy, Debug)] +pub enum UdpRelayMode { + Native, + Quic, +} + +#[derive(Debug)] +pub struct PacketBuffer(HashMap); + +impl PacketBuffer { + pub(crate) fn new() -> Self { + Self(HashMap::new()) + } + + fn insert( + &mut self, + assoc_id: u32, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + addr: Option
, + pkt: Bytes, + ) -> Result, PacketBufferError> { + let key = PacketBufferKey { assoc_id, pkt_id }; + + if frag_id == 0 && addr.is_none() { + self.0.remove(&key); + return Err(PacketBufferError::NoAddress); + } + + if frag_id != 0 && addr.is_some() { + self.0.remove(&key); + return Err(PacketBufferError::UnexpectedAddress); + } + + match self.0.entry(key) { + Entry::Occupied(mut entry) => { + let v = entry.get_mut(); + + if v.buf[frag_id as usize].is_some() { + entry.remove(); + return Err(PacketBufferError::DuplicatedFragId); + } + + if v.buf.len() != frag_total as usize { + entry.remove(); + return Err(PacketBufferError::FragTotalNotMatch); + } + + v.total_len += pkt.len(); + v.buf[frag_id as usize] = Some(pkt); + v.recv_count += 1; + + if v.recv_count == frag_total as usize { + let v = entry.remove(); + let mut res = BytesMut::with_capacity(v.total_len); + + for pkt in v.buf { + res.extend_from_slice(&pkt.unwrap()); + } + + Ok(Some((assoc_id, v.addr.unwrap(), res.freeze()))) + } else { + Ok(None) + } + } + Entry::Vacant(entry) => { + if frag_total == 1 { + return Ok(Some((assoc_id, addr.unwrap(), pkt))); + } + + let mut v = PacketBufferValue { + buf: vec![None; frag_total as usize], + addr, + recv_count: 0, + total_len: 0, + create_time: Instant::now(), + }; + + v.total_len += pkt.len(); + v.buf[frag_id as usize] = Some(pkt); + v.recv_count += 1; + entry.insert(v); + + Ok(None) + } + } + } + + fn collect_garbage(&mut self, timeout: Duration) { + self.0.retain(|_, v| v.create_time.elapsed() < timeout); + } +} + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +struct PacketBufferKey { + assoc_id: u32, + pkt_id: u16, +} + +#[derive(Debug)] +struct PacketBufferValue { + buf: Vec>, + addr: Option
, + recv_count: usize, + total_len: usize, + create_time: Instant, +} + +#[derive(Error, Debug)] +pub enum PacketBufferError { + #[error("missing address in packet with frag_id 0")] + NoAddress, + #[error("unexpected address in packet")] + UnexpectedAddress, + #[error("duplicated frag_id")] + DuplicatedFragId, + #[error("frag_total not match")] + FragTotalNotMatch, +} + +#[inline] +pub(crate) fn split_packet(pkt: Bytes, addr: &Address, max_datagram_len: usize) -> SplitPacket { + SplitPacket::new(pkt, addr, max_datagram_len) +} + +#[derive(Debug)] +pub(crate) struct SplitPacket { + pkt: Bytes, + max_pkt_size: usize, + next_start: usize, + next_end: usize, + len: usize, +} + +impl SplitPacket { + #[inline] + fn new(pkt: Bytes, addr: &Address, max_datagram_size: usize) -> Self { + const DEFAULT_HEADER: Command = Command::Packet { + assoc_id: 0, + pkt_id: 0, + frag_total: 0, + frag_id: 0, + len: 0, + addr: None, + }; + + let first_pkt_size = + max_datagram_size - DEFAULT_HEADER.serialized_len() - addr.serialized_len(); + let max_pkt_size = max_datagram_size - DEFAULT_HEADER.serialized_len(); + let len = if first_pkt_size > pkt.len() { + 1 + (pkt.len() - first_pkt_size) / max_pkt_size + 1 + } else { + 1 + }; + + Self { + pkt, + max_pkt_size, + next_start: 0, + next_end: first_pkt_size, + len, + } + } +} + +impl Iterator for SplitPacket { + type Item = Bytes; + + fn next(&mut self) -> Option { + if self.next_start <= self.pkt.len() { + let next = self + .pkt + .slice(self.next_start..self.next_end.min(self.pkt.len())); + + self.next_start += self.max_pkt_size; + self.next_end += self.max_pkt_size; + + Some(next) + } else { + None + } + } +} + +impl ExactSizeIterator for SplitPacket { + #[inline] + fn len(&self) -> usize { + self.len + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs deleted file mode 100644 index 7e5aaa1..0000000 --- a/src/common/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod udp; diff --git a/src/common/udp.rs b/src/common/udp.rs deleted file mode 100644 index e6d45e7..0000000 --- a/src/common/udp.rs +++ /dev/null @@ -1,78 +0,0 @@ -use crate::protocol::{Address, Command}; -use bytes::Bytes; - -#[derive(Clone, Copy, Debug)] -pub enum UdpRelayMode { - Native, - Quic, -} - -pub fn split_packet(pkt: Bytes, addr: &Address, max_datagram_len: usize) -> SplitPacket { - SplitPacket::new(pkt, addr, max_datagram_len) -} - -#[derive(Debug)] -pub struct SplitPacket { - pkt: Bytes, - max_pkt_size: usize, - next_start: usize, - next_end: usize, - len: usize, -} - -impl SplitPacket { - #[inline] - fn new(pkt: Bytes, addr: &Address, max_datagram_size: usize) -> Self { - const DEFAULT_HEADER: Command = Command::Packet { - assoc_id: 0, - pkt_id: 0, - frag_total: 0, - frag_id: 0, - len: 0, - addr: None, - }; - - let first_pkt_size = - max_datagram_size - DEFAULT_HEADER.serialized_len() - addr.serialized_len(); - let max_pkt_size = max_datagram_size - DEFAULT_HEADER.serialized_len(); - let len = if first_pkt_size > pkt.len() { - 1 + (pkt.len() - first_pkt_size) / max_pkt_size + 1 - } else { - 1 - }; - - Self { - pkt, - max_pkt_size, - next_start: 0, - next_end: first_pkt_size, - len, - } - } -} - -impl Iterator for SplitPacket { - type Item = Bytes; - - fn next(&mut self) -> Option { - if self.next_start <= self.pkt.len() { - let next = self - .pkt - .slice(self.next_start..self.next_end.min(self.pkt.len())); - - self.next_start += self.max_pkt_size; - self.next_end += self.max_pkt_size; - - Some(next) - } else { - None - } - } -} - -impl ExactSizeIterator for SplitPacket { - #[inline] - fn len(&self) -> usize { - self.len - } -} diff --git a/src/lib.rs b/src/lib.rs index 952f68f..6bfdd14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::udp::{self, UdpRelayMode}; +pub use crate::common::{PacketBuffer, PacketBufferError, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig}; From b1435d4a048cfba5e6148b0c9fc0d1efe7064a08 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 08:31:56 +0900 Subject: [PATCH 012/103] support receiving packets from server --- Cargo.toml | 4 +- src/client/connection.rs | 123 +++++++++++++++++++++++++++++++++++++-- src/common.rs | 18 +++--- src/lib.rs | 2 +- 4 files changed, 131 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index de914ef..175c41a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ byteorder = { version = "1.4.*", default-features = false, optional = true } bytes = { version = "1.2.*", default-features = false, optional = true } # crossbeam-utils = { version = "0.8.*", default-features = false } # env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } -# futures-util = { version = "0.3.*", default-features = false } +futures-util = { version = "0.3.*", default-features = false, optional = true } # getopts = "0.2.*" # log = { version = "0.4.*", features = ["serde", "std"] } # once_cell = { version = "1.13.*", features = ["parking_lot"] } @@ -56,7 +56,7 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] server = ["protocol_marshaling", "quinn", "rustls"] -client = ["protocol_marshaling", "quinn", "quinn-proto", "rustls", "thiserror", "tokio/io-util"] +client = ["futures-util", "protocol_marshaling", "quinn", "quinn-proto", "rustls", "thiserror", "tokio/io-util", "tokio/macros", "tokio/time"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/connection.rs b/src/client/connection.rs index ea9ad1b..f458c31 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -3,11 +3,12 @@ use super::{ ConnectError, Stream, }; use crate::{ - common, + common::{self, PacketBuffer}, protocol::{Address, Command, Error as TuicError}, - PacketBuffer, UdpRelayMode, + UdpRelayMode, }; use bytes::{Bytes, BytesMut}; +use futures_util::StreamExt; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, Datagrams, IncomingUniStreams, NewConnection as QuinnNewConnection, @@ -18,8 +19,12 @@ use std::{ atomic::{AtomicU16, Ordering}, Arc, }, + time::{Duration, Instant}, +}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + time, }; -use tokio::io::AsyncWriteExt; #[derive(Debug)] pub struct Connecting { @@ -95,6 +100,7 @@ impl Connection { udp_relay_mode, stream_reg, pkt_buf: PacketBuffer::new(), + last_gc_time: Instant::now(), }; (conn, incoming) @@ -233,7 +239,6 @@ impl Connection { cmd.write_to(&mut send).await?; send.write_all(&pkt).await?; - send.finish().await?; Ok(()) @@ -247,4 +252,114 @@ pub struct IncomingPackets { udp_relay_mode: UdpRelayMode, stream_reg: Arc, pkt_buf: PacketBuffer, + last_gc_time: Instant, +} + +impl IncomingPackets { + pub async fn accept( + &mut self, + gc_interval: Duration, + gc_timeout: Duration, + ) -> Option<(u32, u16, Address, Bytes)> { + match self.udp_relay_mode { + UdpRelayMode::Native => self.accept_from_datagrams(gc_interval, gc_timeout).await, + UdpRelayMode::Quic => self.accept_from_uni_streams().await, + } + } + + async fn accept_from_datagrams( + &mut self, + gc_interval: Duration, + gc_timeout: Duration, + ) -> Option<(u32, u16, Address, Bytes)> { + if self.last_gc_time.elapsed() > gc_interval { + self.pkt_buf.collect_garbage(gc_timeout); + self.last_gc_time = Instant::now(); + } + + let mut gc_interval = time::interval(gc_interval); + + loop { + tokio::select! { + dg = self.datagrams.next() => { + if let Some(dg) = dg { + let dg = dg.unwrap(); + let cmd = Command::read_from(&mut dg.as_ref()).await.unwrap(); + let cmd_len = cmd.serialized_len(); + + if let Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } = cmd + { + if let Some(pkt) = self + .pkt_buf + .insert( + assoc_id, + pkt_id, + frag_total, + frag_id, + addr, + dg.slice(cmd_len..cmd_len + len as usize), + ) + .unwrap() + { + break Some(pkt); + }; + } else { + todo!() + } + } else { + break None; + } + } + _ = gc_interval.tick() => { + self.pkt_buf.collect_garbage(gc_timeout); + self.last_gc_time = Instant::now(); + } + } + } + } + + async fn accept_from_uni_streams(&mut self) -> Option<(u32, u16, Address, Bytes)> { + if let Some(stream) = self.uni_streams.next().await { + let mut recv = RecvStream::new(stream.unwrap(), self.stream_reg.as_ref().clone()); + + if let Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } = Command::read_from(&mut recv).await.unwrap() + { + if frag_id != 0 { + todo!() + } + + if frag_total != 1 { + todo!() + } + + if addr.is_none() { + todo!() + } + + let mut buf = vec![0; len as usize]; + recv.read_exact(&mut buf).await.unwrap(); + let pkt = Bytes::from(buf); + + Some((assoc_id, pkt_id, addr.unwrap(), pkt)) + } else { + todo!() + } + } else { + None + } + } } diff --git a/src/common.rs b/src/common.rs index 988a3bc..c1767d4 100644 --- a/src/common.rs +++ b/src/common.rs @@ -13,14 +13,14 @@ pub enum UdpRelayMode { } #[derive(Debug)] -pub struct PacketBuffer(HashMap); +pub(crate) struct PacketBuffer(HashMap); impl PacketBuffer { pub(crate) fn new() -> Self { Self(HashMap::new()) } - fn insert( + pub(crate) fn insert( &mut self, assoc_id: u32, pkt_id: u16, @@ -28,7 +28,7 @@ impl PacketBuffer { frag_id: u8, addr: Option
, pkt: Bytes, - ) -> Result, PacketBufferError> { + ) -> Result, PacketBufferError> { let key = PacketBufferKey { assoc_id, pkt_id }; if frag_id == 0 && addr.is_none() { @@ -67,14 +67,14 @@ impl PacketBuffer { res.extend_from_slice(&pkt.unwrap()); } - Ok(Some((assoc_id, v.addr.unwrap(), res.freeze()))) + Ok(Some((assoc_id, pkt_id, v.addr.unwrap(), res.freeze()))) } else { Ok(None) } } Entry::Vacant(entry) => { if frag_total == 1 { - return Ok(Some((assoc_id, addr.unwrap(), pkt))); + return Ok(Some((assoc_id, pkt_id, addr.unwrap(), pkt))); } let mut v = PacketBufferValue { @@ -82,7 +82,7 @@ impl PacketBuffer { addr, recv_count: 0, total_len: 0, - create_time: Instant::now(), + c_time: Instant::now(), }; v.total_len += pkt.len(); @@ -95,8 +95,8 @@ impl PacketBuffer { } } - fn collect_garbage(&mut self, timeout: Duration) { - self.0.retain(|_, v| v.create_time.elapsed() < timeout); + pub(crate) fn collect_garbage(&mut self, timeout: Duration) { + self.0.retain(|_, v| v.c_time.elapsed() < timeout); } } @@ -112,7 +112,7 @@ struct PacketBufferValue { addr: Option
, recv_count: usize, total_len: usize, - create_time: Instant, + c_time: Instant, } #[derive(Error, Debug)] diff --git a/src/lib.rs b/src/lib.rs index 6bfdd14..8a704c8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{PacketBuffer, PacketBufferError, UdpRelayMode}; +pub use crate::common::{PacketBufferError, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig}; From 5cf0c6637bb90c88ef987e908c2fe430bb6674de Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 08:38:53 +0900 Subject: [PATCH 013/103] workspace member `client` and `server` --- Cargo.toml | 16 ++++++---------- client/Cargo.toml | 8 ++++++++ client/main.rs | 3 --- client/src/main.rs | 3 +++ server/Cargo.toml | 8 ++++++++ server/main.rs | 3 --- server/src/main.rs | 3 +++ 7 files changed, 28 insertions(+), 16 deletions(-) create mode 100644 client/Cargo.toml delete mode 100644 client/main.rs create mode 100644 client/src/main.rs create mode 100644 server/Cargo.toml delete mode 100644 server/main.rs create mode 100644 server/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 175c41a..924f30e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tuic" -version = "0.9.0" +version = "5.0.0" authors = ["EAimTY "] description = "Delicately-TUICed high-performance proxy built on top of the QUIC protocol" categories = ["network-programming", "command-line-utilities"] @@ -11,15 +11,11 @@ readme = "README.md" license = "GPL-3.0-or-later" repository = "https://github.com/EAimTY/tuic" -[[bin]] -name = "server" -path = "server/main.rs" -required-features = ["server"] - -[[bin]] -name = "client" -path = "client/main.rs" -required-features = ["client"] +[workspace] +members = [ + "client", + "server", +] [dependencies] # blake3 = "1.3.*" diff --git a/client/Cargo.toml b/client/Cargo.toml new file mode 100644 index 0000000..035a0de --- /dev/null +++ b/client/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tuic-client" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/client/main.rs b/client/main.rs deleted file mode 100644 index 47ad8c6..0000000 --- a/client/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello World!"); -} diff --git a/client/src/main.rs b/client/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/client/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} diff --git a/server/Cargo.toml b/server/Cargo.toml new file mode 100644 index 0000000..23e9215 --- /dev/null +++ b/server/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tuic-server" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/server/main.rs b/server/main.rs deleted file mode 100644 index 47ad8c6..0000000 --- a/server/main.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - println!("Hello World!"); -} diff --git a/server/src/main.rs b/server/src/main.rs new file mode 100644 index 0000000..e7a11a9 --- /dev/null +++ b/server/src/main.rs @@ -0,0 +1,3 @@ +fn main() { + println!("Hello, world!"); +} From 4ef0c959aab817387d784c06aaaeb0175d948790 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 10:14:13 +0900 Subject: [PATCH 014/103] error handling for `IncomingPackets` --- src/client/connection.rs | 197 ++++++++++++++++++++++++++------------- src/client/mod.rs | 2 + src/common.rs | 30 +++++- src/lib.rs | 2 +- src/protocol/mod.rs | 2 +- 5 files changed, 162 insertions(+), 71 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index f458c31..a8c2db3 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -5,13 +5,14 @@ use super::{ use crate::{ common::{self, PacketBuffer}, protocol::{Address, Command, Error as TuicError}, - UdpRelayMode, + Packet, PacketBufferError, UdpRelayMode, }; use bytes::{Bytes, BytesMut}; use futures_util::StreamExt; use quinn::{ - Connecting as QuinnConnecting, Connection as QuinnConnection, Datagrams, IncomingUniStreams, - NewConnection as QuinnNewConnection, + Connecting as QuinnConnecting, Connection as QuinnConnection, + ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, + NewConnection as QuinnNewConnection, RecvStream as QuinnRecvStream, }; use std::{ io::{Error as IoError, ErrorKind, Result as IoResult}, @@ -21,6 +22,7 @@ use std::{ }, time::{Duration, Instant}, }; +use thiserror::Error; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, time, @@ -130,7 +132,7 @@ impl Connection { let resp = match Command::read_from(&mut stream).await { Ok(Command::Response(resp)) => Ok(resp), - Ok(cmd) => Err(TuicError::InvalidCommand(cmd.type_code())), + Ok(cmd) => Err(TuicError::InvalidCommand(cmd.as_type_code())), Err(err) => Err(err), }; @@ -144,7 +146,7 @@ impl Connection { res } - pub async fn send_packet(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { + pub async fn packet(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { match self.udp_relay_mode { UdpRelayMode::Native => self.send_packet_to_datagram(assoc_id, addr, pkt), UdpRelayMode::Quic => self.send_packet_to_uni_stream(assoc_id, addr, pkt).await, @@ -260,7 +262,7 @@ impl IncomingPackets { &mut self, gc_interval: Duration, gc_timeout: Duration, - ) -> Option<(u32, u16, Address, Bytes)> { + ) -> Option> { match self.udp_relay_mode { UdpRelayMode::Native => self.accept_from_datagrams(gc_interval, gc_timeout).await, UdpRelayMode::Quic => self.accept_from_uni_streams().await, @@ -271,7 +273,43 @@ impl IncomingPackets { &mut self, gc_interval: Duration, gc_timeout: Duration, - ) -> Option<(u32, u16, Address, Bytes)> { + ) -> Option> { + async fn process_datagram( + pkt_buf: &mut PacketBuffer, + dg: Result, + ) -> Result, IncomingPacketsError> { + let dg = dg.unwrap(); + let cmd = Command::read_from(&mut dg.as_ref()).await.unwrap(); + let cmd_len = cmd.serialized_len(); + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => { + if let Some(pkt) = pkt_buf.insert( + assoc_id, + pkt_id, + frag_total, + frag_id, + addr, + dg.slice(cmd_len..cmd_len + len as usize), + )? { + Ok(Some(pkt)) + } else { + Ok(None) + } + } + cmd => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( + cmd.as_type_code(), + ))), + } + } + if self.last_gc_time.elapsed() > gc_interval { self.pkt_buf.collect_garbage(gc_timeout); self.last_gc_time = Instant::now(); @@ -283,35 +321,10 @@ impl IncomingPackets { tokio::select! { dg = self.datagrams.next() => { if let Some(dg) = dg { - let dg = dg.unwrap(); - let cmd = Command::read_from(&mut dg.as_ref()).await.unwrap(); - let cmd_len = cmd.serialized_len(); - - if let Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } = cmd - { - if let Some(pkt) = self - .pkt_buf - .insert( - assoc_id, - pkt_id, - frag_total, - frag_id, - addr, - dg.slice(cmd_len..cmd_len + len as usize), - ) - .unwrap() - { - break Some(pkt); - }; - } else { - todo!() + match process_datagram(&mut self.pkt_buf, dg).await { + Ok(Some(pkt)) => break Some(Ok(pkt)), + Ok(None) => {} + Err(err) => break Some(Err(err)), } } else { break None; @@ -325,41 +338,93 @@ impl IncomingPackets { } } - async fn accept_from_uni_streams(&mut self) -> Option<(u32, u16, Address, Bytes)> { - if let Some(stream) = self.uni_streams.next().await { - let mut recv = RecvStream::new(stream.unwrap(), self.stream_reg.as_ref().clone()); + async fn accept_from_uni_streams(&mut self) -> Option> { + async fn process_uni_stream( + recv: Result, + stream_reg: StreamReg, + ) -> Result { + let recv = match recv { + Ok(recv) => recv, + Err(err) => return Err(IncomingPacketsError::from_quinn_connection_error(err)), + }; - if let Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } = Command::read_from(&mut recv).await.unwrap() - { - if frag_id != 0 { - todo!() + let mut recv = RecvStream::new(recv, stream_reg); + let cmd = Command::read_from(&mut recv).await?; + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => { + if frag_id != 0 || frag_total != 1 { + return Err(IncomingPacketsError::FragmentedPacketFromUniStream); + } + + if addr.is_none() { + return Err(IncomingPacketsError::NoAddressPacketFromUniStream); + } + + let mut buf = vec![0; len as usize]; + recv.read_exact(&mut buf).await?; + let pkt = Bytes::from(buf); + + Ok(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt)) } - - if frag_total != 1 { - todo!() - } - - if addr.is_none() { - todo!() - } - - let mut buf = vec![0; len as usize]; - recv.read_exact(&mut buf).await.unwrap(); - let pkt = Bytes::from(buf); - - Some((assoc_id, pkt_id, addr.unwrap(), pkt)) - } else { - todo!() + _ => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( + cmd.as_type_code(), + ))), } + } + + if let Some(recv) = self.uni_streams.next().await { + Some(process_uni_stream(recv, self.stream_reg.as_ref().clone()).await) } else { None } } } + +#[derive(Error, Debug)] +pub enum IncomingPacketsError { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Tuic(TuicError), + #[error(transparent)] + PacketBuffer(#[from] PacketBufferError), + #[error("received fragmented packet from uni stream")] + FragmentedPacketFromUniStream, + #[error("received packet without address from uni stream")] + NoAddressPacketFromUniStream, +} + +impl IncomingPacketsError { + #[inline] + fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { + Self::Io(IoError::from(err)) + } +} + +impl From for IncomingPacketsError { + #[inline] + fn from(err: TuicError) -> Self { + match err { + TuicError::Io(err) => Self::Io(err), + err => Self::Tuic(err), + } + } +} + +impl From for IoError { + #[inline] + fn from(err: IncomingPacketsError) -> Self { + match err { + IncomingPacketsError::Io(err) => Self::from(err), + err => Self::new(ErrorKind::Other, err), + } + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 063b604..4178609 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -181,6 +181,7 @@ pub enum ConnectError { } impl ConnectError { + #[inline] fn from_quinn_connect_error(err: QuinnConnectError) -> Self { match err { QuinnConnectError::UnsupportedVersion => Self::UnsupportedQUICVersion, @@ -192,6 +193,7 @@ impl ConnectError { } } + #[inline] fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { match err { QuinnConnectionError::VersionMismatch => Self::UnsupportedQUICVersion, diff --git a/src/common.rs b/src/common.rs index c1767d4..9bdd084 100644 --- a/src/common.rs +++ b/src/common.rs @@ -12,6 +12,25 @@ pub enum UdpRelayMode { Quic, } +#[derive(Debug)] +pub struct Packet { + pub id: u16, + pub associate_id: u32, + pub address: Address, + pub data: Bytes, +} + +impl Packet { + pub(crate) fn new(assoc_id: u32, pkt_id: u16, addr: Address, pkt: Bytes) -> Self { + Self { + id: pkt_id, + associate_id: assoc_id, + address: addr, + data: pkt, + } + } +} + #[derive(Debug)] pub(crate) struct PacketBuffer(HashMap); @@ -28,7 +47,7 @@ impl PacketBuffer { frag_id: u8, addr: Option
, pkt: Bytes, - ) -> Result, PacketBufferError> { + ) -> Result, PacketBufferError> { let key = PacketBufferKey { assoc_id, pkt_id }; if frag_id == 0 && addr.is_none() { @@ -67,14 +86,19 @@ impl PacketBuffer { res.extend_from_slice(&pkt.unwrap()); } - Ok(Some((assoc_id, pkt_id, v.addr.unwrap(), res.freeze()))) + Ok(Some(Packet::new( + assoc_id, + pkt_id, + v.addr.unwrap(), + res.freeze(), + ))) } else { Ok(None) } } Entry::Vacant(entry) => { if frag_total == 1 { - return Ok(Some((assoc_id, pkt_id, addr.unwrap(), pkt))); + return Ok(Some(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt))); } let mut v = PacketBufferValue { diff --git a/src/lib.rs b/src/lib.rs index 8a704c8..0617542 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{PacketBufferError, UdpRelayMode}; +pub use crate::common::{Packet, PacketBufferError, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig}; diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 8e6eb49..df1323f 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -90,7 +90,7 @@ impl Command { pub const RESPONSE_SUCCEEDED: u8 = 0x00; pub const RESPONSE_FAILED: u8 = 0xff; - pub const fn type_code(&self) -> u8 { + pub const fn as_type_code(&self) -> u8 { match self { Command::Response(_) => Self::TYPE_RESPONSE, Command::Authenticate(_) => Self::TYPE_AUTHENTICATE, From 1bbac1ce3bcc7760bb78fe3cb9e4201b714a243e Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 2 Aug 2022 10:17:42 +0900 Subject: [PATCH 015/103] re-export `IncomingPacketsError` --- src/client/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 4178609..c404928 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -2,7 +2,7 @@ mod connection; mod stream; pub use self::{ - connection::{Connecting, Connection, IncomingPackets}, + connection::{Connecting, Connection, IncomingPackets, IncomingPacketsError}, stream::Stream, }; From abad5613d721f6042188f3a0dd3618f1ba3c5ef6 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 3 Aug 2022 12:11:30 +0900 Subject: [PATCH 016/103] error handling for client `Connection` --- src/client/connection.rs | 140 +++++++++++++++++++++++++++++---------- src/client/mod.rs | 14 ++-- src/common.rs | 14 +++- 3 files changed, 122 insertions(+), 46 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index a8c2db3..62259a8 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -12,10 +12,10 @@ use futures_util::StreamExt; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, - NewConnection as QuinnNewConnection, RecvStream as QuinnRecvStream, + NewConnection as QuinnNewConnection, SendDatagramError as QuinnSendDatagramError, }; use std::{ - io::{Error as IoError, ErrorKind, Result as IoResult}, + io::{Error as IoError, ErrorKind}, sync::{ atomic::{AtomicU16, Ordering}, Arc, @@ -54,10 +54,10 @@ impl Connecting { datagrams, uni_streams, .. - } = match self.conn.await { - Ok(conn) => conn, - Err(err) => return Err(ConnectError::from_quinn_connection_error(err)), - }; + } = self + .conn + .await + .map_err(ConnectError::from_quinn_connection_error)?; Ok(Connection::new( connection, @@ -108,7 +108,7 @@ impl Connection { (conn, incoming) } - pub async fn authenticate(&self) -> IoResult<()> { + pub async fn authenticate(&self) -> Result<(), ConnectionError> { let mut send = self.get_send_stream().await?; let cmd = Command::Authenticate(self.token); cmd.write_to(&mut send).await?; @@ -116,7 +116,7 @@ impl Connection { Ok(()) } - pub async fn heartbeat(&self) -> IoResult<()> { + pub async fn heartbeat(&self) -> Result<(), ConnectionError> { let mut send = self.get_send_stream().await?; let cmd = Command::Heartbeat; cmd.write_to(&mut send).await?; @@ -124,7 +124,7 @@ impl Connection { Ok(()) } - pub async fn connect(&self, addr: Address) -> IoResult> { + pub async fn connect(&self, addr: Address) -> Result, ConnectionError> { let mut stream = self.get_bi_stream().await?; let cmd = Command::Connect { addr }; @@ -139,21 +139,26 @@ impl Connection { let res = match resp { Ok(true) => return Ok(Some(stream)), Ok(false) => Ok(None), - Err(err) => Err(IoError::from(err)), + Err(err) => Err(ConnectionError::from(err)), }; stream.finish().await?; res } - pub async fn packet(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { + pub async fn packet( + &self, + assoc_id: u32, + addr: Address, + pkt: Bytes, + ) -> Result<(), ConnectionError> { match self.udp_relay_mode { UdpRelayMode::Native => self.send_packet_to_datagram(assoc_id, addr, pkt), UdpRelayMode::Quic => self.send_packet_to_uni_stream(assoc_id, addr, pkt).await, } } - pub async fn dissociate(&self, assoc_id: u32) -> IoResult<()> { + pub async fn dissociate(&self, assoc_id: u32) -> Result<(), ConnectionError> { let mut send = self.get_send_stream().await?; let cmd = Command::Dissociate { assoc_id }; cmd.write_to(&mut send).await?; @@ -161,23 +166,39 @@ impl Connection { Ok(()) } - async fn get_send_stream(&self) -> IoResult { - let send = self.conn.open_uni().await?; + async fn get_send_stream(&self) -> Result { + let send = self + .conn + .open_uni() + .await + .map_err(ConnectionError::from_quinn_connection_error)?; + Ok(SendStream::new(send, self.stream_reg.as_ref().clone())) } - async fn get_bi_stream(&self) -> IoResult { - let (send, recv) = self.conn.open_bi().await?; + async fn get_bi_stream(&self) -> Result { + let (send, recv) = self + .conn + .open_bi() + .await + .map_err(ConnectionError::from_quinn_connection_error)?; + let send = SendStream::new(send, self.stream_reg.as_ref().clone()); let recv = RecvStream::new(recv, self.stream_reg.as_ref().clone()); + Ok(Stream::new(send, recv)) } - fn send_packet_to_datagram(&self, assoc_id: u32, addr: Address, pkt: Bytes) -> IoResult<()> { + fn send_packet_to_datagram( + &self, + assoc_id: u32, + addr: Address, + pkt: Bytes, + ) -> Result<(), ConnectionError> { let max_datagram_size = if let Some(size) = self.conn.max_datagram_size() { size } else { - return Err(IoError::new(ErrorKind::Other, "datagram not supported")); + return Err(ConnectionError::DatagramDisabled); }; let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); @@ -199,7 +220,9 @@ impl Connection { buf.extend_from_slice(&first_pkt); let buf = buf.freeze(); - self.conn.send_datagram(buf).unwrap(); // TODO: error handling + self.conn + .send_datagram(buf) + .map_err(ConnectionError::from_quinn_send_datagram_error)?; for (id, pkt) in pkts.enumerate() { let pkt_header = Command::Packet { @@ -216,7 +239,9 @@ impl Connection { buf.extend_from_slice(&pkt); let buf = buf.freeze(); - self.conn.send_datagram(buf).unwrap(); // TODO: error handling + self.conn + .send_datagram(buf) + .map_err(ConnectionError::from_quinn_send_datagram_error)?; } Ok(()) @@ -227,7 +252,7 @@ impl Connection { assoc_id: u32, addr: Address, pkt: Bytes, - ) -> IoResult<()> { + ) -> Result<(), ConnectionError> { let mut send = self.get_send_stream().await?; let cmd = Command::Packet { @@ -274,12 +299,13 @@ impl IncomingPackets { gc_interval: Duration, gc_timeout: Duration, ) -> Option> { + #[inline] async fn process_datagram( pkt_buf: &mut PacketBuffer, - dg: Result, + dg: Result, ) -> Result, IncomingPacketsError> { - let dg = dg.unwrap(); - let cmd = Command::read_from(&mut dg.as_ref()).await.unwrap(); + let dg = dg?; + let cmd = Command::read_from(&mut dg.as_ref()).await?; let cmd_len = cmd.serialized_len(); match cmd { @@ -321,6 +347,7 @@ impl IncomingPackets { tokio::select! { dg = self.datagrams.next() => { if let Some(dg) = dg { + let dg = dg.map_err(IncomingPacketsError::from_quinn_connection_error); match process_datagram(&mut self.pkt_buf, dg).await { Ok(Some(pkt)) => break Some(Ok(pkt)), Ok(None) => {} @@ -339,16 +366,11 @@ impl IncomingPackets { } async fn accept_from_uni_streams(&mut self) -> Option> { + #[inline] async fn process_uni_stream( - recv: Result, - stream_reg: StreamReg, + recv: Result, ) -> Result { - let recv = match recv { - Ok(recv) => recv, - Err(err) => return Err(IncomingPacketsError::from_quinn_connection_error(err)), - }; - - let mut recv = RecvStream::new(recv, stream_reg); + let mut recv = recv?; let cmd = Command::read_from(&mut recv).await?; match cmd { @@ -361,7 +383,9 @@ impl IncomingPackets { addr, } => { if frag_id != 0 || frag_total != 1 { - return Err(IncomingPacketsError::FragmentedPacketFromUniStream); + return Err(IncomingPacketsError::FragmentedPacketFromUniStream( + frag_id, frag_total, + )); } if addr.is_none() { @@ -381,13 +405,57 @@ impl IncomingPackets { } if let Some(recv) = self.uni_streams.next().await { - Some(process_uni_stream(recv, self.stream_reg.as_ref().clone()).await) + let recv = recv + .map(|recv| RecvStream::new(recv, self.stream_reg.as_ref().clone())) + .map_err(IncomingPacketsError::from_quinn_connection_error); + Some(process_uni_stream(recv).await) } else { None } } } +#[derive(Error, Debug)] +pub enum ConnectionError { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Tuic(TuicError), + #[error("datagrams not supported by peer")] + DatagramUnsupportedByPeer, + #[error("datagram support disabled")] + DatagramDisabled, + #[error("datagram too large")] + DatagramTooLarge, +} + +impl ConnectionError { + #[inline] + fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { + Self::Io(IoError::from(err)) + } + + #[inline] + fn from_quinn_send_datagram_error(err: QuinnSendDatagramError) -> Self { + match err { + QuinnSendDatagramError::UnsupportedByPeer => Self::DatagramUnsupportedByPeer, + QuinnSendDatagramError::Disabled => Self::DatagramDisabled, + QuinnSendDatagramError::TooLarge => Self::DatagramTooLarge, + QuinnSendDatagramError::ConnectionLost(err) => Self::Io(IoError::from(err)), + } + } +} + +impl From for ConnectionError { + #[inline] + fn from(err: TuicError) -> Self { + match err { + TuicError::Io(err) => Self::Io(err), + err => Self::Tuic(err), + } + } +} + #[derive(Error, Debug)] pub enum IncomingPacketsError { #[error(transparent)] @@ -396,8 +464,8 @@ pub enum IncomingPacketsError { Tuic(TuicError), #[error(transparent)] PacketBuffer(#[from] PacketBufferError), - #[error("received fragmented packet from uni stream")] - FragmentedPacketFromUniStream, + #[error("received fragmented packet from uni stream: {0} in {1} packets")] + FragmentedPacketFromUniStream(u8, u8), #[error("received packet without address from uni stream")] NoAddressPacketFromUniStream, } diff --git a/src/client/mod.rs b/src/client/mod.rs index c404928..4fbe134 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -101,10 +101,10 @@ impl Client { server_name: &str, token: [u8; 32], ) -> Result<(Connection, IncomingPackets), ConnectError> { - let conn = match self.endpoint.connect(addr, server_name) { - Ok(conn) => conn, - Err(err) => return Err(ConnectError::from_quinn_connect_error(err)), - }; + let conn = self + .endpoint + .connect(addr, server_name) + .map_err(ConnectError::from_quinn_connect_error)?; let QuinnNewConnection { connection, @@ -123,10 +123,8 @@ impl Client { } } } else { - match conn.await { - Ok(conn) => conn, - Err(err) => return Err(ConnectError::from_quinn_connection_error(err)), - } + conn.await + .map_err(ConnectError::from_quinn_connection_error)? }; Ok(Connection::new( diff --git a/src/common.rs b/src/common.rs index 9bdd084..afd15bd 100644 --- a/src/common.rs +++ b/src/common.rs @@ -64,6 +64,10 @@ impl PacketBuffer { Entry::Occupied(mut entry) => { let v = entry.get_mut(); + if frag_id >= frag_total { + return Err(PacketBufferError::FragIdExceed); + } + if v.buf[frag_id as usize].is_some() { entry.remove(); return Err(PacketBufferError::DuplicatedFragId); @@ -97,6 +101,10 @@ impl PacketBuffer { } } Entry::Vacant(entry) => { + if frag_id >= frag_total { + return Err(PacketBufferError::FragIdExceed); + } + if frag_total == 1 { return Ok(Some(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt))); } @@ -149,11 +157,13 @@ pub enum PacketBufferError { DuplicatedFragId, #[error("frag_total not match")] FragTotalNotMatch, + #[error("frag_id exceed")] + FragIdExceed, } #[inline] -pub(crate) fn split_packet(pkt: Bytes, addr: &Address, max_datagram_len: usize) -> SplitPacket { - SplitPacket::new(pkt, addr, max_datagram_len) +pub(crate) fn split_packet(pkt: Bytes, addr: &Address, max_datagram_size: usize) -> SplitPacket { + SplitPacket::new(pkt, addr, max_datagram_size) } #[derive(Debug)] From 48ee5ec42bbeb849c76715820321295a9410a654 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 3 Aug 2022 12:13:40 +0900 Subject: [PATCH 017/103] re-export `ConnectionError` --- src/client/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index 4fbe134..8d773f0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -2,7 +2,7 @@ mod connection; mod stream; pub use self::{ - connection::{Connecting, Connection, IncomingPackets, IncomingPacketsError}, + connection::{Connecting, Connection, ConnectionError, IncomingPackets, IncomingPacketsError}, stream::Stream, }; From 4430d8663649856b15d36380f0c9135a8711d2e4 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 3 Aug 2022 12:34:35 +0900 Subject: [PATCH 018/103] simplify `IncomingPacketsError` --- src/client/connection.rs | 33 +++++++++++++++++++++------------ src/common.rs | 32 +++++++++++--------------------- src/lib.rs | 2 +- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index 62259a8..69d174b 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -3,9 +3,9 @@ use super::{ ConnectError, Stream, }; use crate::{ - common::{self, PacketBuffer}, + common::{self, PacketBuffer, PacketBufferError}, protocol::{Address, Command, Error as TuicError}, - Packet, PacketBufferError, UdpRelayMode, + Packet, UdpRelayMode, }; use bytes::{Bytes, BytesMut}; use futures_util::StreamExt; @@ -383,13 +383,11 @@ impl IncomingPackets { addr, } => { if frag_id != 0 || frag_total != 1 { - return Err(IncomingPacketsError::FragmentedPacketFromUniStream( - frag_id, frag_total, - )); + return Err(IncomingPacketsError::BadFragment); } if addr.is_none() { - return Err(IncomingPacketsError::NoAddressPacketFromUniStream); + return Err(IncomingPacketsError::NoAddress); } let mut buf = vec![0; len as usize]; @@ -462,12 +460,12 @@ pub enum IncomingPacketsError { Io(#[from] IoError), #[error(transparent)] Tuic(TuicError), - #[error(transparent)] - PacketBuffer(#[from] PacketBufferError), - #[error("received fragmented packet from uni stream: {0} in {1} packets")] - FragmentedPacketFromUniStream(u8, u8), - #[error("received packet without address from uni stream")] - NoAddressPacketFromUniStream, + #[error("received bad-fragmented packet")] + BadFragment, + #[error("missing address in packet with frag_id 0")] + NoAddress, + #[error("unexpected address in packet")] + UnexpectedAddress, } impl IncomingPacketsError { @@ -477,6 +475,17 @@ impl IncomingPacketsError { } } +impl From for IncomingPacketsError { + #[inline] + fn from(err: PacketBufferError) -> Self { + match err { + PacketBufferError::NoAddress => Self::NoAddress, + PacketBufferError::UnexpectedAddress => Self::UnexpectedAddress, + PacketBufferError::BadFragment => Self::BadFragment, + } + } +} + impl From for IncomingPacketsError { #[inline] fn from(err: TuicError) -> Self { diff --git a/src/common.rs b/src/common.rs index afd15bd..3d31c03 100644 --- a/src/common.rs +++ b/src/common.rs @@ -64,18 +64,12 @@ impl PacketBuffer { Entry::Occupied(mut entry) => { let v = entry.get_mut(); - if frag_id >= frag_total { - return Err(PacketBufferError::FragIdExceed); - } - - if v.buf[frag_id as usize].is_some() { - entry.remove(); - return Err(PacketBufferError::DuplicatedFragId); - } - - if v.buf.len() != frag_total as usize { - entry.remove(); - return Err(PacketBufferError::FragTotalNotMatch); + if frag_total == 0 + || frag_id >= frag_total + || v.buf.len() != frag_total as usize + || v.buf[frag_id as usize].is_some() + { + return Err(PacketBufferError::BadFragment); } v.total_len += pkt.len(); @@ -101,8 +95,8 @@ impl PacketBuffer { } } Entry::Vacant(entry) => { - if frag_id >= frag_total { - return Err(PacketBufferError::FragIdExceed); + if frag_total == 0 || frag_id >= frag_total { + return Err(PacketBufferError::BadFragment); } if frag_total == 1 { @@ -148,17 +142,13 @@ struct PacketBufferValue { } #[derive(Error, Debug)] -pub enum PacketBufferError { +pub(crate) enum PacketBufferError { #[error("missing address in packet with frag_id 0")] NoAddress, #[error("unexpected address in packet")] UnexpectedAddress, - #[error("duplicated frag_id")] - DuplicatedFragId, - #[error("frag_total not match")] - FragTotalNotMatch, - #[error("frag_id exceed")] - FragIdExceed, + #[error("received bad-fragmented packet")] + BadFragment, } #[inline] diff --git a/src/lib.rs b/src/lib.rs index 0617542..4814ff6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{Packet, PacketBufferError, UdpRelayMode}; +pub use crate::common::{Packet, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig}; From dc08a7cb770f3c394196d3328348333b78d6f185 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 3 Aug 2022 12:42:02 +0900 Subject: [PATCH 019/103] move `IncomingPackets` into a separate mod --- src/client/connection.rs | 211 ++---------------------------------- src/client/incoming.rs | 225 +++++++++++++++++++++++++++++++++++++++ src/client/mod.rs | 4 +- 3 files changed, 236 insertions(+), 204 deletions(-) create mode 100644 src/client/incoming.rs diff --git a/src/client/connection.rs b/src/client/connection.rs index 69d174b..1ff77aa 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -1,14 +1,13 @@ use super::{ stream::{RecvStream, SendStream, StreamReg}, - ConnectError, Stream, + ConnectError, IncomingPackets, Stream, }; use crate::{ - common::{self, PacketBuffer, PacketBufferError}, + common, protocol::{Address, Command, Error as TuicError}, - Packet, UdpRelayMode, + UdpRelayMode, }; use bytes::{Bytes, BytesMut}; -use futures_util::StreamExt; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, @@ -20,13 +19,9 @@ use std::{ atomic::{AtomicU16, Ordering}, Arc, }, - time::{Duration, Instant}, }; use thiserror::Error; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - time, -}; +use tokio::io::AsyncWriteExt; #[derive(Debug)] pub struct Connecting { @@ -96,14 +91,7 @@ impl Connection { next_pkt_id: Arc::new(AtomicU16::new(0)), }; - let incoming = IncomingPackets { - uni_streams, - datagrams, - udp_relay_mode, - stream_reg, - pkt_buf: PacketBuffer::new(), - last_gc_time: Instant::now(), - }; + let incoming = IncomingPackets::new(uni_streams, datagrams, udp_relay_mode, stream_reg); (conn, incoming) } @@ -272,147 +260,6 @@ impl Connection { } } -#[derive(Debug)] -pub struct IncomingPackets { - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - udp_relay_mode: UdpRelayMode, - stream_reg: Arc, - pkt_buf: PacketBuffer, - last_gc_time: Instant, -} - -impl IncomingPackets { - pub async fn accept( - &mut self, - gc_interval: Duration, - gc_timeout: Duration, - ) -> Option> { - match self.udp_relay_mode { - UdpRelayMode::Native => self.accept_from_datagrams(gc_interval, gc_timeout).await, - UdpRelayMode::Quic => self.accept_from_uni_streams().await, - } - } - - async fn accept_from_datagrams( - &mut self, - gc_interval: Duration, - gc_timeout: Duration, - ) -> Option> { - #[inline] - async fn process_datagram( - pkt_buf: &mut PacketBuffer, - dg: Result, - ) -> Result, IncomingPacketsError> { - let dg = dg?; - let cmd = Command::read_from(&mut dg.as_ref()).await?; - let cmd_len = cmd.serialized_len(); - - match cmd { - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => { - if let Some(pkt) = pkt_buf.insert( - assoc_id, - pkt_id, - frag_total, - frag_id, - addr, - dg.slice(cmd_len..cmd_len + len as usize), - )? { - Ok(Some(pkt)) - } else { - Ok(None) - } - } - cmd => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( - cmd.as_type_code(), - ))), - } - } - - if self.last_gc_time.elapsed() > gc_interval { - self.pkt_buf.collect_garbage(gc_timeout); - self.last_gc_time = Instant::now(); - } - - let mut gc_interval = time::interval(gc_interval); - - loop { - tokio::select! { - dg = self.datagrams.next() => { - if let Some(dg) = dg { - let dg = dg.map_err(IncomingPacketsError::from_quinn_connection_error); - match process_datagram(&mut self.pkt_buf, dg).await { - Ok(Some(pkt)) => break Some(Ok(pkt)), - Ok(None) => {} - Err(err) => break Some(Err(err)), - } - } else { - break None; - } - } - _ = gc_interval.tick() => { - self.pkt_buf.collect_garbage(gc_timeout); - self.last_gc_time = Instant::now(); - } - } - } - } - - async fn accept_from_uni_streams(&mut self) -> Option> { - #[inline] - async fn process_uni_stream( - recv: Result, - ) -> Result { - let mut recv = recv?; - let cmd = Command::read_from(&mut recv).await?; - - match cmd { - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => { - if frag_id != 0 || frag_total != 1 { - return Err(IncomingPacketsError::BadFragment); - } - - if addr.is_none() { - return Err(IncomingPacketsError::NoAddress); - } - - let mut buf = vec![0; len as usize]; - recv.read_exact(&mut buf).await?; - let pkt = Bytes::from(buf); - - Ok(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt)) - } - _ => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( - cmd.as_type_code(), - ))), - } - } - - if let Some(recv) = self.uni_streams.next().await { - let recv = recv - .map(|recv| RecvStream::new(recv, self.stream_reg.as_ref().clone())) - .map_err(IncomingPacketsError::from_quinn_connection_error); - Some(process_uni_stream(recv).await) - } else { - None - } - } -} - #[derive(Error, Debug)] pub enum ConnectionError { #[error(transparent)] @@ -454,53 +301,11 @@ impl From for ConnectionError { } } -#[derive(Error, Debug)] -pub enum IncomingPacketsError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Tuic(TuicError), - #[error("received bad-fragmented packet")] - BadFragment, - #[error("missing address in packet with frag_id 0")] - NoAddress, - #[error("unexpected address in packet")] - UnexpectedAddress, -} - -impl IncomingPacketsError { +impl From for IoError { #[inline] - fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { - Self::Io(IoError::from(err)) - } -} - -impl From for IncomingPacketsError { - #[inline] - fn from(err: PacketBufferError) -> Self { + fn from(err: ConnectionError) -> Self { match err { - PacketBufferError::NoAddress => Self::NoAddress, - PacketBufferError::UnexpectedAddress => Self::UnexpectedAddress, - PacketBufferError::BadFragment => Self::BadFragment, - } - } -} - -impl From for IncomingPacketsError { - #[inline] - fn from(err: TuicError) -> Self { - match err { - TuicError::Io(err) => Self::Io(err), - err => Self::Tuic(err), - } - } -} - -impl From for IoError { - #[inline] - fn from(err: IncomingPacketsError) -> Self { - match err { - IncomingPacketsError::Io(err) => Self::from(err), + ConnectionError::Io(err) => Self::from(err), err => Self::new(ErrorKind::Other, err), } } diff --git a/src/client/incoming.rs b/src/client/incoming.rs new file mode 100644 index 0000000..9dddc50 --- /dev/null +++ b/src/client/incoming.rs @@ -0,0 +1,225 @@ +use super::stream::{RecvStream, StreamReg}; +use crate::{ + common::{PacketBuffer, PacketBufferError}, + protocol::{Command, Error as TuicError}, + Packet, UdpRelayMode, +}; +use bytes::Bytes; +use futures_util::StreamExt; +use quinn::{ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams}; +use std::{ + io::{Error as IoError, ErrorKind}, + sync::Arc, + time::{Duration, Instant}, +}; +use thiserror::Error; +use tokio::{io::AsyncReadExt, time}; + +#[derive(Debug)] +pub struct IncomingPackets { + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + udp_relay_mode: UdpRelayMode, + stream_reg: Arc, + pkt_buf: PacketBuffer, + last_gc_time: Instant, +} + +impl IncomingPackets { + pub(super) fn new( + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + udp_relay_mode: UdpRelayMode, + stream_reg: Arc, + ) -> Self { + Self { + uni_streams, + datagrams, + udp_relay_mode, + stream_reg, + pkt_buf: PacketBuffer::new(), + last_gc_time: Instant::now(), + } + } + + pub async fn accept( + &mut self, + gc_interval: Duration, + gc_timeout: Duration, + ) -> Option> { + match self.udp_relay_mode { + UdpRelayMode::Native => self.accept_from_datagrams(gc_interval, gc_timeout).await, + UdpRelayMode::Quic => self.accept_from_uni_streams().await, + } + } + + async fn accept_from_datagrams( + &mut self, + gc_interval: Duration, + gc_timeout: Duration, + ) -> Option> { + #[inline] + async fn process_datagram( + pkt_buf: &mut PacketBuffer, + dg: Result, + ) -> Result, IncomingPacketsError> { + let dg = dg?; + let cmd = Command::read_from(&mut dg.as_ref()).await?; + let cmd_len = cmd.serialized_len(); + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => { + if let Some(pkt) = pkt_buf.insert( + assoc_id, + pkt_id, + frag_total, + frag_id, + addr, + dg.slice(cmd_len..cmd_len + len as usize), + )? { + Ok(Some(pkt)) + } else { + Ok(None) + } + } + cmd => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( + cmd.as_type_code(), + ))), + } + } + + if self.last_gc_time.elapsed() > gc_interval { + self.pkt_buf.collect_garbage(gc_timeout); + self.last_gc_time = Instant::now(); + } + + let mut gc_interval = time::interval(gc_interval); + + loop { + tokio::select! { + dg = self.datagrams.next() => { + if let Some(dg) = dg { + let dg = dg.map_err(IncomingPacketsError::from_quinn_connection_error); + match process_datagram(&mut self.pkt_buf, dg).await { + Ok(Some(pkt)) => break Some(Ok(pkt)), + Ok(None) => {} + Err(err) => break Some(Err(err)), + } + } else { + break None; + } + } + _ = gc_interval.tick() => { + self.pkt_buf.collect_garbage(gc_timeout); + self.last_gc_time = Instant::now(); + } + } + } + } + + async fn accept_from_uni_streams(&mut self) -> Option> { + #[inline] + async fn process_uni_stream( + recv: Result, + ) -> Result { + let mut recv = recv?; + let cmd = Command::read_from(&mut recv).await?; + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => { + if frag_id != 0 || frag_total != 1 { + return Err(IncomingPacketsError::BadFragment); + } + + if addr.is_none() { + return Err(IncomingPacketsError::NoAddress); + } + + let mut buf = vec![0; len as usize]; + recv.read_exact(&mut buf).await?; + let pkt = Bytes::from(buf); + + Ok(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt)) + } + _ => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( + cmd.as_type_code(), + ))), + } + } + + if let Some(recv) = self.uni_streams.next().await { + let recv = recv + .map(|recv| RecvStream::new(recv, self.stream_reg.as_ref().clone())) + .map_err(IncomingPacketsError::from_quinn_connection_error); + Some(process_uni_stream(recv).await) + } else { + None + } + } +} + +#[derive(Error, Debug)] +pub enum IncomingPacketsError { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Tuic(TuicError), + #[error("received bad-fragmented packet")] + BadFragment, + #[error("missing address in packet with frag_id 0")] + NoAddress, + #[error("unexpected address in packet")] + UnexpectedAddress, +} + +impl IncomingPacketsError { + #[inline] + fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { + Self::Io(IoError::from(err)) + } +} + +impl From for IncomingPacketsError { + #[inline] + fn from(err: PacketBufferError) -> Self { + match err { + PacketBufferError::NoAddress => Self::NoAddress, + PacketBufferError::UnexpectedAddress => Self::UnexpectedAddress, + PacketBufferError::BadFragment => Self::BadFragment, + } + } +} + +impl From for IncomingPacketsError { + #[inline] + fn from(err: TuicError) -> Self { + match err { + TuicError::Io(err) => Self::Io(err), + err => Self::Tuic(err), + } + } +} + +impl From for IoError { + #[inline] + fn from(err: IncomingPacketsError) -> Self { + match err { + IncomingPacketsError::Io(err) => Self::from(err), + err => Self::new(ErrorKind::Other, err), + } + } +} diff --git a/src/client/mod.rs b/src/client/mod.rs index 8d773f0..b75718b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,8 +1,10 @@ mod connection; +mod incoming; mod stream; pub use self::{ - connection::{Connecting, Connection, ConnectionError, IncomingPackets, IncomingPacketsError}, + connection::{Connecting, Connection, ConnectionError}, + incoming::{IncomingPackets, IncomingPacketsError}, stream::Stream, }; From ed7fe0bdc11f8f83c4403c1e5d304ae9e27099a9 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 4 Aug 2022 12:38:49 +0900 Subject: [PATCH 020/103] abstract out `CongestionController` --- src/client/mod.rs | 61 ++++++++++++++++++++++++++++++++--------------- src/common.rs | 9 ++++++- src/lib.rs | 2 +- 3 files changed, 51 insertions(+), 21 deletions(-) diff --git a/src/client/mod.rs b/src/client/mod.rs index b75718b..0c88e90 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,15 +8,17 @@ pub use self::{ stream::Stream, }; -use crate::UdpRelayMode; +use crate::{CongestionController, UdpRelayMode}; use quinn::{ - congestion::ControllerFactory, ApplicationClose, ClientConfig as QuinnClientConfig, - ConnectError as QuinnConnectError, ConnectionClose, ConnectionError as QuinnConnectionError, - Endpoint, EndpointConfig, NewConnection as QuinnNewConnection, + congestion::{BbrConfig, CubicConfig, NewRenoConfig}, + ApplicationClose, ClientConfig as QuinnClientConfig, ConnectError as QuinnConnectError, + ConnectionClose, ConnectionError as QuinnConnectionError, Endpoint, EndpointConfig, + NewConnection as QuinnNewConnection, }; use quinn_proto::TransportError; use rustls::{version, ClientConfig as RustlsClientConfig, RootCertStore}; use std::{ + fmt::Debug, io::Result as IoResult, net::{SocketAddr, ToSocketAddrs, UdpSocket}, sync::Arc, @@ -30,10 +32,7 @@ pub struct Client { } impl Client { - pub fn bind(cfg: ClientConfig, addr: impl ToSocketAddrs) -> IoResult - where - C: ControllerFactory + Send + Sync + 'static, - { + pub fn bind(cfg: ClientConfig, addr: impl ToSocketAddrs) -> IoResult { let socket = UdpSocket::bind(addr)?; let (mut ep, _) = Endpoint::new(EndpointConfig::default(), None, socket)?; @@ -52,9 +51,20 @@ impl Client { let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); - transport.congestion_controller_factory(cfg.congestion_controller); transport.max_idle_timeout(None); + match cfg.congestion_controller { + CongestionController::Cubic => { + transport.congestion_controller_factory(Arc::new(CubicConfig::default())); + } + CongestionController::NewReno => { + transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); + } + CongestionController::Bbr => { + transport.congestion_controller_factory(Arc::new(BbrConfig::default())); + } + } + ep.set_default_client_config(quinn_config); Ok(Self { @@ -64,10 +74,7 @@ impl Client { }) } - pub fn reconfigure(&mut self, cfg: ClientConfig) - where - C: ControllerFactory + Send + Sync + 'static, - { + pub fn reconfigure(&mut self, cfg: ClientConfig) { let mut crypto = RustlsClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() @@ -83,9 +90,20 @@ impl Client { let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); - transport.congestion_controller_factory(cfg.congestion_controller); transport.max_idle_timeout(None); + match cfg.congestion_controller { + CongestionController::Cubic => { + transport.congestion_controller_factory(Arc::new(CubicConfig::default())); + } + CongestionController::NewReno => { + transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); + } + CongestionController::Bbr => { + transport.congestion_controller_factory(Arc::new(BbrConfig::default())); + } + } + self.endpoint.set_default_client_config(quinn_config); self.udp_relay_mode = cfg.udp_relay_mode; @@ -139,17 +157,22 @@ impl Client { } } +impl Debug for Client { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Client") + .field("endpoint", &self.endpoint) + .finish() + } +} + #[derive(Clone, Debug)] -pub struct ClientConfig -where - C: ControllerFactory + Send + Sync + 'static, -{ +pub struct ClientConfig { pub certs: RootCertStore, pub alpn_protocols: Vec>, pub disable_sni: bool, pub enable_0rtt: bool, pub udp_relay_mode: UdpRelayMode, - pub congestion_controller: C, + pub congestion_controller: CongestionController, } #[derive(Error, Debug)] diff --git a/src/common.rs b/src/common.rs index 3d31c03..923eec4 100644 --- a/src/common.rs +++ b/src/common.rs @@ -6,13 +6,20 @@ use std::{ }; use thiserror::Error; +#[derive(Clone, Copy, Debug)] +pub enum CongestionController { + Cubic, + NewReno, + Bbr, +} + #[derive(Clone, Copy, Debug)] pub enum UdpRelayMode { Native, Quic, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Packet { pub id: u16, pub associate_id: u32, diff --git a/src/lib.rs b/src/lib.rs index 4814ff6..4d4248a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{Packet, UdpRelayMode}; +pub use crate::common::{CongestionController, Packet, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig}; From d7002a5e3d3900e4fccfb5bf9bbb28c4d04ef5e2 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 4 Aug 2022 16:32:48 +0900 Subject: [PATCH 021/103] adding the server endpoint --- Cargo.toml | 5 +- src/client/connection.rs | 6 +- src/client/mod.rs | 108 +++++++++------------ src/common.rs | 2 +- src/lib.rs | 9 +- src/server/connection.rs | 1 + src/{server.rs => server/incoming.rs} | 0 src/server/mod.rs | 133 ++++++++++++++++++++++++++ 8 files changed, 191 insertions(+), 73 deletions(-) create mode 100644 src/server/connection.rs rename src/{server.rs => server/incoming.rs} (100%) create mode 100644 src/server/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 924f30e..d836065 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,6 @@ futures-util = { version = "0.3.*", default-features = false, optional = true } # once_cell = { version = "1.13.*", features = ["parking_lot"] } # parking_lot = "0.12.*" quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false, optional = true } -quinn-proto = { version = "0.8.*", default-features = false, optional = true } # rand = "0.8.*" rustls = { version = "0.20.*", default-features = false, optional = true } # rustls-native-certs = "0.6.*" @@ -51,8 +50,8 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["protocol_marshaling", "quinn", "rustls"] -client = ["futures-util", "protocol_marshaling", "quinn", "quinn-proto", "rustls", "thiserror", "tokio/io-util", "tokio/macros", "tokio/time"] +server = ["protocol_marshaling", "quinn", "rustls", "thiserror"] +client = ["futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros", "tokio/time"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/connection.rs b/src/client/connection.rs index 1ff77aa..c4580d2 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -1,6 +1,6 @@ use super::{ stream::{RecvStream, SendStream, StreamReg}, - ConnectError, IncomingPackets, Stream, + IncomingPackets, Stream, }; use crate::{ common, @@ -43,7 +43,7 @@ impl Connecting { } } - pub async fn establish(self) -> Result<(Connection, IncomingPackets), ConnectError> { + pub async fn establish(self) -> Result<(Connection, IncomingPackets), ConnectionError> { let QuinnNewConnection { connection, datagrams, @@ -52,7 +52,7 @@ impl Connecting { } = self .conn .await - .map_err(ConnectError::from_quinn_connection_error)?; + .map_err(ConnectionError::from_quinn_connection_error)?; Ok(Connection::new( connection, diff --git a/src/client/mod.rs b/src/client/mod.rs index 0c88e90..a9304db 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -8,44 +8,41 @@ pub use self::{ stream::Stream, }; -use crate::{CongestionController, UdpRelayMode}; +use crate::{CongestionControl, UdpRelayMode}; use quinn::{ congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - ApplicationClose, ClientConfig as QuinnClientConfig, ConnectError as QuinnConnectError, - ConnectionClose, ConnectionError as QuinnConnectionError, Endpoint, EndpointConfig, + ClientConfig as QuinnClientConfig, ConnectError as QuinnConnectError, + ConnectionError as QuinnConnectionError, Endpoint, EndpointConfig, NewConnection as QuinnNewConnection, }; -use quinn_proto::TransportError; use rustls::{version, ClientConfig as RustlsClientConfig, RootCertStore}; use std::{ + convert::Infallible, fmt::Debug, - io::Result as IoResult, - net::{SocketAddr, ToSocketAddrs, UdpSocket}, + io::Error as IoError, + net::{SocketAddr, UdpSocket}, sync::Arc, }; use thiserror::Error; pub struct Client { endpoint: Endpoint, - enable_0rtt: bool, + enable_quic_0rtt: bool, udp_relay_mode: UdpRelayMode, } impl Client { - pub fn bind(cfg: ClientConfig, addr: impl ToSocketAddrs) -> IoResult { - let socket = UdpSocket::bind(addr)?; - let (mut ep, _) = Endpoint::new(EndpointConfig::default(), None, socket)?; - + pub fn bind(cfg: ClientConfig, socket: UdpSocket) -> Result { let mut crypto = RustlsClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(&[&version::TLS13]) .unwrap() - .with_root_certificates(cfg.certs) + .with_root_certificates(cfg.certificates) .with_no_client_auth(); crypto.alpn_protocols = cfg.alpn_protocols; - crypto.enable_early_data = cfg.enable_0rtt; + crypto.enable_early_data = cfg.enable_quic_0rtt; crypto.enable_sni = !cfg.disable_sni; let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); @@ -54,37 +51,38 @@ impl Client { transport.max_idle_timeout(None); match cfg.congestion_controller { - CongestionController::Cubic => { + CongestionControl::Cubic => { transport.congestion_controller_factory(Arc::new(CubicConfig::default())); } - CongestionController::NewReno => { + CongestionControl::NewReno => { transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); } - CongestionController::Bbr => { + CongestionControl::Bbr => { transport.congestion_controller_factory(Arc::new(BbrConfig::default())); } } + let (mut ep, _) = Endpoint::new(EndpointConfig::default(), None, socket)?; ep.set_default_client_config(quinn_config); Ok(Self { endpoint: ep, udp_relay_mode: cfg.udp_relay_mode, - enable_0rtt: cfg.enable_0rtt, + enable_quic_0rtt: cfg.enable_quic_0rtt, }) } - pub fn reconfigure(&mut self, cfg: ClientConfig) { + pub fn reconfigure(&mut self, cfg: ClientConfig) -> Result<(), Infallible> { let mut crypto = RustlsClientConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() .with_protocol_versions(&[&version::TLS13]) .unwrap() - .with_root_certificates(cfg.certs) + .with_root_certificates(cfg.certificates) .with_no_client_auth(); crypto.alpn_protocols = cfg.alpn_protocols; - crypto.enable_early_data = cfg.enable_0rtt; + crypto.enable_early_data = cfg.enable_quic_0rtt; crypto.enable_sni = !cfg.disable_sni; let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); @@ -93,13 +91,13 @@ impl Client { transport.max_idle_timeout(None); match cfg.congestion_controller { - CongestionController::Cubic => { + CongestionControl::Cubic => { transport.congestion_controller_factory(Arc::new(CubicConfig::default())); } - CongestionController::NewReno => { + CongestionControl::NewReno => { transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); } - CongestionController::Bbr => { + CongestionControl::Bbr => { transport.congestion_controller_factory(Arc::new(BbrConfig::default())); } } @@ -107,12 +105,14 @@ impl Client { self.endpoint.set_default_client_config(quinn_config); self.udp_relay_mode = cfg.udp_relay_mode; - self.enable_0rtt = cfg.enable_0rtt; + self.enable_quic_0rtt = cfg.enable_quic_0rtt; + + Ok(()) } - pub fn rebind(&mut self, addr: impl ToSocketAddrs) -> IoResult<()> { - let socket = UdpSocket::bind(addr)?; - self.endpoint.rebind(socket) + pub fn rebind(&mut self, socket: UdpSocket) -> Result<(), ClientError> { + self.endpoint.rebind(socket)?; + Ok(()) } pub async fn connect( @@ -120,22 +120,22 @@ impl Client { addr: SocketAddr, server_name: &str, token: [u8; 32], - ) -> Result<(Connection, IncomingPackets), ConnectError> { + ) -> Result<(Connection, IncomingPackets), ClientError> { let conn = self .endpoint .connect(addr, server_name) - .map_err(ConnectError::from_quinn_connect_error)?; + .map_err(ClientError::from_quinn_connect_error)?; let QuinnNewConnection { connection, datagrams, uni_streams, .. - } = if self.enable_0rtt { + } = if self.enable_quic_0rtt { match conn.into_0rtt() { Ok((conn, _)) => conn, Err(conn) => { - return Err(ConnectError::Convert0Rtt(Connecting::new( + return Err(ClientError::Convert0Rtt(Connecting::new( conn, token, self.udp_relay_mode, @@ -144,7 +144,7 @@ impl Client { } } else { conn.await - .map_err(ConnectError::from_quinn_connection_error)? + .map_err(ClientError::from_quinn_connection_error)? }; Ok(Connection::new( @@ -167,50 +167,40 @@ impl Debug for Client { #[derive(Clone, Debug)] pub struct ClientConfig { - pub certs: RootCertStore, + pub certificates: RootCertStore, pub alpn_protocols: Vec>, pub disable_sni: bool, - pub enable_0rtt: bool, + pub enable_quic_0rtt: bool, pub udp_relay_mode: UdpRelayMode, - pub congestion_controller: CongestionController, + pub congestion_controller: CongestionControl, } #[derive(Error, Debug)] -pub enum ConnectError { +pub enum ClientError { #[error("failed to convert QUIC connection into 0-RTT")] Convert0Rtt(Connecting), - #[error("unsupported QUIC version")] - UnsupportedQUICVersion, + #[error(transparent)] + Io(#[from] IoError), #[error("endpoint stopping")] EndpointStopping, #[error("too many connections")] TooManyConnections, - #[error("invalid domain name: {0}")] - InvalidDomainName(String), + #[error("invalid DNS name: {0}")] + InvalidDnsName(String), #[error("invalid remote address: {0}")] InvalidRemoteAddress(SocketAddr), - #[error(transparent)] - TransportError(#[from] TransportError), - #[error("aborted by peer: {0}")] - ConnectionClosed(ConnectionClose), - #[error("closed by peer: {0}")] - ApplicationClosed(ApplicationClose), - #[error("reset by peer")] - Reset, - #[error("timed out")] - TimedOut, - #[error("closed")] - LocallyClosed, + #[error("unsupported QUIC version")] + UnsupportedQUICVersion, } -impl ConnectError { +impl ClientError { #[inline] fn from_quinn_connect_error(err: QuinnConnectError) -> Self { match err { QuinnConnectError::UnsupportedVersion => Self::UnsupportedQUICVersion, QuinnConnectError::EndpointStopping => Self::EndpointStopping, QuinnConnectError::TooManyConnections => Self::TooManyConnections, - QuinnConnectError::InvalidDnsName(err) => Self::InvalidDomainName(err), + QuinnConnectError::InvalidDnsName(err) => Self::InvalidDnsName(err), QuinnConnectError::InvalidRemoteAddress(err) => Self::InvalidRemoteAddress(err), QuinnConnectError::NoDefaultClientConfig => unreachable!(), } @@ -218,14 +208,6 @@ impl ConnectError { #[inline] fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { - match err { - QuinnConnectionError::VersionMismatch => Self::UnsupportedQUICVersion, - QuinnConnectionError::TransportError(err) => Self::TransportError(err), - QuinnConnectionError::ConnectionClosed(err) => Self::ConnectionClosed(err), - QuinnConnectionError::ApplicationClosed(err) => Self::ApplicationClosed(err), - QuinnConnectionError::Reset => Self::Reset, - QuinnConnectionError::TimedOut => Self::TimedOut, - QuinnConnectionError::LocallyClosed => Self::LocallyClosed, - } + Self::from(IoError::from(err)) } } diff --git a/src/common.rs b/src/common.rs index 923eec4..536592a 100644 --- a/src/common.rs +++ b/src/common.rs @@ -7,7 +7,7 @@ use std::{ use thiserror::Error; #[derive(Clone, Copy, Debug)] -pub enum CongestionController { +pub enum CongestionControl { Cubic, NewReno, Bbr, diff --git a/src/lib.rs b/src/lib.rs index 4d4248a..1370efc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,16 @@ pub mod protocol; mod common; #[cfg(feature = "server")] -mod server; +pub mod server; #[cfg(feature = "client")] pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{CongestionController, Packet, UdpRelayMode}; +pub use crate::common::{CongestionControl, Packet, UdpRelayMode}; #[cfg(feature = "client")] -pub use crate::client::{Client, ClientConfig}; +pub use crate::client::{Client, ClientConfig, ClientError}; + +#[cfg(feature = "server")] +pub use crate::server::{Server, ServerConfig, ServerError}; diff --git a/src/server/connection.rs b/src/server/connection.rs new file mode 100644 index 0000000..810e86f --- /dev/null +++ b/src/server/connection.rs @@ -0,0 +1 @@ +pub struct Connecting; diff --git a/src/server.rs b/src/server/incoming.rs similarity index 100% rename from src/server.rs rename to src/server/incoming.rs diff --git a/src/server/mod.rs b/src/server/mod.rs new file mode 100644 index 0000000..4eed924 --- /dev/null +++ b/src/server/mod.rs @@ -0,0 +1,133 @@ +mod connection; +mod incoming; + +pub use self::connection::Connecting; + +use crate::CongestionControl; +use quinn::{ + congestion::{BbrConfig, CubicConfig, NewRenoConfig}, + Endpoint, EndpointConfig, IdleTimeout, Incoming, ServerConfig as QuinnServerConfig, VarInt, +}; +use rustls::{ + version, Certificate, Error as RustlsError, PrivateKey, ServerConfig as RustlsServerConfig, +}; +use std::{io::Error as IoError, net::UdpSocket, sync::Arc, time::Duration}; +use thiserror::Error; + +#[derive(Debug)] +pub struct Server { + endpoint: Endpoint, + incoming: Incoming, +} + +impl Server { + pub fn bind(cfg: ServerConfig, socket: UdpSocket) -> Result { + let mut crypto = RustlsServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(cfg.certificate_chain, cfg.private_key)?; + + if cfg.allow_quic_0rtt { + crypto.max_early_data_size = u32::MAX; + } + + crypto.alpn_protocols = cfg.alpn_protocols; + + let mut quinn_config = QuinnServerConfig::with_crypto(Arc::new(crypto)); + let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); + + let max_idle_timeout = cfg.max_idle_timeout.map(|timeout| { + IdleTimeout::try_from(timeout).unwrap_or_else(|_| IdleTimeout::from(VarInt::MAX)) + }); + + transport.max_idle_timeout(max_idle_timeout); + + match cfg.congestion_controller { + CongestionControl::Cubic => { + transport.congestion_controller_factory(Arc::new(CubicConfig::default())); + } + CongestionControl::NewReno => { + transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); + } + CongestionControl::Bbr => { + transport.congestion_controller_factory(Arc::new(BbrConfig::default())); + } + } + + let (endpoint, incoming) = + Endpoint::new(EndpointConfig::default(), Some(quinn_config), socket)?; + + Ok(Self { endpoint, incoming }) + } + + pub fn reconfigure(&mut self, cfg: ServerConfig) -> Result<(), ServerError> { + let mut crypto = RustlsServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(cfg.certificate_chain, cfg.private_key)?; + + if cfg.allow_quic_0rtt { + crypto.max_early_data_size = u32::MAX; + } + + crypto.alpn_protocols = cfg.alpn_protocols; + + let mut quinn_config = QuinnServerConfig::with_crypto(Arc::new(crypto)); + let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); + + let max_idle_timeout = cfg.max_idle_timeout.map(|timeout| { + IdleTimeout::try_from(timeout).unwrap_or_else(|_| IdleTimeout::from(VarInt::MAX)) + }); + + transport.max_idle_timeout(max_idle_timeout); + + match cfg.congestion_controller { + CongestionControl::Cubic => { + transport.congestion_controller_factory(Arc::new(CubicConfig::default())); + } + CongestionControl::NewReno => { + transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); + } + CongestionControl::Bbr => { + transport.congestion_controller_factory(Arc::new(BbrConfig::default())); + } + } + + self.endpoint.set_server_config(Some(quinn_config)); + + Ok(()) + } + + pub fn rebind(&mut self, socket: UdpSocket) -> Result<(), ServerError> { + self.endpoint.rebind(socket)?; + Ok(()) + } + + pub async fn accept(&self) -> Connecting { + todo!() + } +} + +#[derive(Clone, Debug)] +pub struct ServerConfig { + pub certificate_chain: Vec, + pub private_key: PrivateKey, + pub alpn_protocols: Vec>, + pub allow_quic_0rtt: bool, + pub max_idle_timeout: Option, + pub congestion_controller: CongestionControl, +} + +#[derive(Error, Debug)] +pub enum ServerError { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Certificate(#[from] RustlsError), +} From 9bf7eaad9640c774b45b927d30018f0762091fb5 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 4 Aug 2022 17:48:16 +0900 Subject: [PATCH 022/103] gc packet frag only when receiving a new packet --- Cargo.toml | 4 ++-- src/client/incoming.rs | 38 ++++++++++++++------------------------ 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d836065..a87af5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,8 +50,8 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["protocol_marshaling", "quinn", "rustls", "thiserror"] -client = ["futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros", "tokio/time"] +server = ["protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] +client = ["futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/incoming.rs b/src/client/incoming.rs index 9dddc50..eb2d03b 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -13,7 +13,7 @@ use std::{ time::{Duration, Instant}, }; use thiserror::Error; -use tokio::{io::AsyncReadExt, time}; +use tokio::io::AsyncReadExt; #[derive(Debug)] pub struct IncomingPackets { @@ -95,31 +95,21 @@ impl IncomingPackets { } } - if self.last_gc_time.elapsed() > gc_interval { - self.pkt_buf.collect_garbage(gc_timeout); - self.last_gc_time = Instant::now(); - } - - let mut gc_interval = time::interval(gc_interval); - loop { - tokio::select! { - dg = self.datagrams.next() => { - if let Some(dg) = dg { - let dg = dg.map_err(IncomingPacketsError::from_quinn_connection_error); - match process_datagram(&mut self.pkt_buf, dg).await { - Ok(Some(pkt)) => break Some(Ok(pkt)), - Ok(None) => {} - Err(err) => break Some(Err(err)), - } - } else { - break None; - } - } - _ = gc_interval.tick() => { - self.pkt_buf.collect_garbage(gc_timeout); - self.last_gc_time = Instant::now(); + if self.last_gc_time.elapsed() > gc_interval { + self.pkt_buf.collect_garbage(gc_timeout); + self.last_gc_time = Instant::now(); + } + + if let Some(dg) = self.datagrams.next().await { + let dg = dg.map_err(IncomingPacketsError::from_quinn_connection_error); + match process_datagram(&mut self.pkt_buf, dg).await { + Ok(Some(pkt)) => break Some(Ok(pkt)), + Ok(None) => {} + Err(err) => break Some(Err(err)), } + } else { + break None; } } } From ef4cc6f700b5cdd46e427b9cd5245a73dcee6d7d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 5 Aug 2022 00:15:04 +0900 Subject: [PATCH 023/103] return `Connecting` when calling `Client::connect` --- src/client/connection.rs | 26 +++++++++++++++++---- src/client/mod.rs | 50 +++++----------------------------------- 2 files changed, 28 insertions(+), 48 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index c4580d2..cb00c98 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -27,6 +27,7 @@ use tokio::io::AsyncWriteExt; pub struct Connecting { conn: QuinnConnecting, token: [u8; 32], + enable_quic_0rtt: bool, udp_relay_mode: UdpRelayMode, } @@ -34,11 +35,13 @@ impl Connecting { pub(super) fn new( conn: QuinnConnecting, token: [u8; 32], + enable_quic_0rtt: bool, udp_relay_mode: UdpRelayMode, ) -> Self { Self { conn, token, + enable_quic_0rtt, udp_relay_mode, } } @@ -49,10 +52,23 @@ impl Connecting { datagrams, uni_streams, .. - } = self - .conn - .await - .map_err(ConnectionError::from_quinn_connection_error)?; + } = if self.enable_quic_0rtt { + match self.conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => { + return Err(ConnectionError::Convert0Rtt(Connecting::new( + conn, + self.token, + false, + self.udp_relay_mode, + ))) + } + } + } else { + self.conn + .await + .map_err(ConnectionError::from_quinn_connection_error)? + }; Ok(Connection::new( connection, @@ -262,6 +278,8 @@ impl Connection { #[derive(Error, Debug)] pub enum ConnectionError { + #[error("failed to convert QUIC connection into 0-RTT")] + Convert0Rtt(Connecting), #[error(transparent)] Io(#[from] IoError), #[error(transparent)] diff --git a/src/client/mod.rs b/src/client/mod.rs index a9304db..b91bff4 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -11,9 +11,7 @@ pub use self::{ use crate::{CongestionControl, UdpRelayMode}; use quinn::{ congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - ClientConfig as QuinnClientConfig, ConnectError as QuinnConnectError, - ConnectionError as QuinnConnectionError, Endpoint, EndpointConfig, - NewConnection as QuinnNewConnection, + ClientConfig as QuinnClientConfig, ConnectError as QuinnConnectError, Endpoint, EndpointConfig, }; use rustls::{version, ClientConfig as RustlsClientConfig, RootCertStore}; use std::{ @@ -25,6 +23,7 @@ use std::{ }; use thiserror::Error; +#[derive(Debug)] pub struct Client { endpoint: Endpoint, enable_quic_0rtt: bool, @@ -120,51 +119,21 @@ impl Client { addr: SocketAddr, server_name: &str, token: [u8; 32], - ) -> Result<(Connection, IncomingPackets), ClientError> { + ) -> Result { let conn = self .endpoint .connect(addr, server_name) .map_err(ClientError::from_quinn_connect_error)?; - let QuinnNewConnection { - connection, - datagrams, - uni_streams, - .. - } = if self.enable_quic_0rtt { - match conn.into_0rtt() { - Ok((conn, _)) => conn, - Err(conn) => { - return Err(ClientError::Convert0Rtt(Connecting::new( - conn, - token, - self.udp_relay_mode, - ))) - } - } - } else { - conn.await - .map_err(ClientError::from_quinn_connection_error)? - }; - - Ok(Connection::new( - connection, - uni_streams, - datagrams, + Ok(Connecting::new( + conn, token, + self.enable_quic_0rtt, self.udp_relay_mode, )) } } -impl Debug for Client { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Client") - .field("endpoint", &self.endpoint) - .finish() - } -} - #[derive(Clone, Debug)] pub struct ClientConfig { pub certificates: RootCertStore, @@ -177,8 +146,6 @@ pub struct ClientConfig { #[derive(Error, Debug)] pub enum ClientError { - #[error("failed to convert QUIC connection into 0-RTT")] - Convert0Rtt(Connecting), #[error(transparent)] Io(#[from] IoError), #[error("endpoint stopping")] @@ -205,9 +172,4 @@ impl ClientError { QuinnConnectError::NoDefaultClientConfig => unreachable!(), } } - - #[inline] - fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { - Self::from(IoError::from(err)) - } } From 36d9a04e15ce43fe8d90319beab70267899b1410 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 9 Aug 2022 19:22:53 +0900 Subject: [PATCH 024/103] move streams and tasks to `common` --- Cargo.toml | 4 +- src/client/connection.rs | 62 +++++++++++------------------ src/client/incoming.rs | 53 +++++++++++++------------ src/client/mod.rs | 6 +-- src/common/mod.rs | 16 ++++++++ src/{client => common}/stream.rs | 12 +++--- src/common/task.rs | 21 ++++++++++ src/{common.rs => common/util.rs} | 37 ++--------------- src/lib.rs | 2 +- src/server/connection.rs | 66 ++++++++++++++++++++++++++++++- src/server/incoming.rs | 23 +++++++++++ src/server/mod.rs | 10 +++-- 12 files changed, 198 insertions(+), 114 deletions(-) create mode 100644 src/common/mod.rs rename src/{client => common}/stream.rs (89%) create mode 100644 src/common/task.rs rename src/{common.rs => common/util.rs} (89%) diff --git a/Cargo.toml b/Cargo.toml index a87af5b..fb33978 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ members = [ # blake3 = "1.3.*" byteorder = { version = "1.4.*", default-features = false, optional = true } bytes = { version = "1.2.*", default-features = false, optional = true } -# crossbeam-utils = { version = "0.8.*", default-features = false } +crossbeam-utils = { version = "0.8.*", default-features = false, optional = true } # env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } futures-util = { version = "0.3.*", default-features = false, optional = true } # getopts = "0.2.*" @@ -50,7 +50,7 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] +server = ["crossbeam-utils", "futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] client = ["futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] [dev-dependencies] diff --git a/src/client/connection.rs b/src/client/connection.rs index cb00c98..0692ea6 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -1,17 +1,17 @@ -use super::{ - stream::{RecvStream, SendStream, StreamReg}, - IncomingPackets, Stream, -}; +use super::Incoming; use crate::{ - common, + common::{ + stream::{RecvStream, SendStream, StreamReg}, + util, + }, protocol::{Address, Command, Error as TuicError}, - UdpRelayMode, + Stream, UdpRelayMode, }; use bytes::{Bytes, BytesMut}; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, - ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, - NewConnection as QuinnNewConnection, SendDatagramError as QuinnSendDatagramError, + ConnectionError as QuinnConnectionError, NewConnection as QuinnNewConnection, + SendDatagramError as QuinnSendDatagramError, }; use std::{ io::{Error as IoError, ErrorKind}, @@ -26,7 +26,6 @@ use tokio::io::AsyncWriteExt; #[derive(Debug)] pub struct Connecting { conn: QuinnConnecting, - token: [u8; 32], enable_quic_0rtt: bool, udp_relay_mode: UdpRelayMode, } @@ -34,19 +33,17 @@ pub struct Connecting { impl Connecting { pub(super) fn new( conn: QuinnConnecting, - token: [u8; 32], enable_quic_0rtt: bool, udp_relay_mode: UdpRelayMode, ) -> Self { Self { conn, - token, enable_quic_0rtt, udp_relay_mode, } } - pub async fn establish(self) -> Result<(Connection, IncomingPackets), ConnectionError> { + pub async fn establish(self) -> Result<(Connection, Incoming), ConnectionError> { let QuinnNewConnection { connection, datagrams, @@ -58,7 +55,6 @@ impl Connecting { Err(conn) => { return Err(ConnectionError::Convert0Rtt(Connecting::new( conn, - self.token, false, self.udp_relay_mode, ))) @@ -70,51 +66,41 @@ impl Connecting { .map_err(ConnectionError::from_quinn_connection_error)? }; - Ok(Connection::new( - connection, - uni_streams, - datagrams, - self.token, - self.udp_relay_mode, - )) + let stream_reg = Arc::new(Arc::new(())); + + let conn = Connection::new(connection, self.udp_relay_mode, stream_reg.clone()); + + let incoming = Incoming::new(uni_streams, datagrams, self.udp_relay_mode, stream_reg); + + Ok((conn, incoming)) } } #[derive(Debug)] pub struct Connection { conn: QuinnConnection, - token: [u8; 32], udp_relay_mode: UdpRelayMode, stream_reg: Arc, next_pkt_id: Arc, } impl Connection { - pub(super) fn new( + fn new( conn: QuinnConnection, - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - token: [u8; 32], udp_relay_mode: UdpRelayMode, - ) -> (Self, IncomingPackets) { - let stream_reg = Arc::new(Arc::new(())); - - let conn = Self { + stream_reg: Arc, + ) -> Self { + Self { conn, - token, udp_relay_mode, stream_reg: stream_reg.clone(), next_pkt_id: Arc::new(AtomicU16::new(0)), - }; - - let incoming = IncomingPackets::new(uni_streams, datagrams, udp_relay_mode, stream_reg); - - (conn, incoming) + } } - pub async fn authenticate(&self) -> Result<(), ConnectionError> { + pub async fn authenticate(&self, token: [u8; 32]) -> Result<(), ConnectionError> { let mut send = self.get_send_stream().await?; - let cmd = Command::Authenticate(self.token); + let cmd = Command::Authenticate(token); cmd.write_to(&mut send).await?; send.finish().await?; Ok(()) @@ -206,7 +192,7 @@ impl Connection { }; let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); - let mut pkts = common::split_packet(pkt, &addr, max_datagram_size); + let mut pkts = util::split_packet(pkt, &addr, max_datagram_size); let frag_total = pkts.len() as u8; let first_pkt = pkts.next().unwrap(); diff --git a/src/client/incoming.rs b/src/client/incoming.rs index eb2d03b..c05c00e 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -1,8 +1,11 @@ -use super::stream::{RecvStream, StreamReg}; use crate::{ - common::{PacketBuffer, PacketBufferError}, + common::{ + stream::{RecvStream, StreamReg}, + util::{PacketBuffer, PacketBufferError}, + }, protocol::{Command, Error as TuicError}, - Packet, UdpRelayMode, + task::Packet, + UdpRelayMode, }; use bytes::Bytes; use futures_util::StreamExt; @@ -16,7 +19,7 @@ use thiserror::Error; use tokio::io::AsyncReadExt; #[derive(Debug)] -pub struct IncomingPackets { +pub struct Incoming { uni_streams: IncomingUniStreams, datagrams: Datagrams, udp_relay_mode: UdpRelayMode, @@ -25,7 +28,7 @@ pub struct IncomingPackets { last_gc_time: Instant, } -impl IncomingPackets { +impl Incoming { pub(super) fn new( uni_streams: IncomingUniStreams, datagrams: Datagrams, @@ -46,7 +49,7 @@ impl IncomingPackets { &mut self, gc_interval: Duration, gc_timeout: Duration, - ) -> Option> { + ) -> Option> { match self.udp_relay_mode { UdpRelayMode::Native => self.accept_from_datagrams(gc_interval, gc_timeout).await, UdpRelayMode::Quic => self.accept_from_uni_streams().await, @@ -57,12 +60,12 @@ impl IncomingPackets { &mut self, gc_interval: Duration, gc_timeout: Duration, - ) -> Option> { + ) -> Option> { #[inline] async fn process_datagram( pkt_buf: &mut PacketBuffer, - dg: Result, - ) -> Result, IncomingPacketsError> { + dg: Result, + ) -> Result, IncomingError> { let dg = dg?; let cmd = Command::read_from(&mut dg.as_ref()).await?; let cmd_len = cmd.serialized_len(); @@ -89,7 +92,7 @@ impl IncomingPackets { Ok(None) } } - cmd => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( + cmd => Err(IncomingError::Tuic(TuicError::InvalidCommand( cmd.as_type_code(), ))), } @@ -102,7 +105,7 @@ impl IncomingPackets { } if let Some(dg) = self.datagrams.next().await { - let dg = dg.map_err(IncomingPacketsError::from_quinn_connection_error); + let dg = dg.map_err(IncomingError::from_quinn_connection_error); match process_datagram(&mut self.pkt_buf, dg).await { Ok(Some(pkt)) => break Some(Ok(pkt)), Ok(None) => {} @@ -114,11 +117,11 @@ impl IncomingPackets { } } - async fn accept_from_uni_streams(&mut self) -> Option> { + async fn accept_from_uni_streams(&mut self) -> Option> { #[inline] async fn process_uni_stream( - recv: Result, - ) -> Result { + recv: Result, + ) -> Result { let mut recv = recv?; let cmd = Command::read_from(&mut recv).await?; @@ -132,11 +135,11 @@ impl IncomingPackets { addr, } => { if frag_id != 0 || frag_total != 1 { - return Err(IncomingPacketsError::BadFragment); + return Err(IncomingError::BadFragment); } if addr.is_none() { - return Err(IncomingPacketsError::NoAddress); + return Err(IncomingError::NoAddress); } let mut buf = vec![0; len as usize]; @@ -145,7 +148,7 @@ impl IncomingPackets { Ok(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt)) } - _ => Err(IncomingPacketsError::Tuic(TuicError::InvalidCommand( + _ => Err(IncomingError::Tuic(TuicError::InvalidCommand( cmd.as_type_code(), ))), } @@ -154,7 +157,7 @@ impl IncomingPackets { if let Some(recv) = self.uni_streams.next().await { let recv = recv .map(|recv| RecvStream::new(recv, self.stream_reg.as_ref().clone())) - .map_err(IncomingPacketsError::from_quinn_connection_error); + .map_err(IncomingError::from_quinn_connection_error); Some(process_uni_stream(recv).await) } else { None @@ -163,7 +166,7 @@ impl IncomingPackets { } #[derive(Error, Debug)] -pub enum IncomingPacketsError { +pub enum IncomingError { #[error(transparent)] Io(#[from] IoError), #[error(transparent)] @@ -176,14 +179,14 @@ pub enum IncomingPacketsError { UnexpectedAddress, } -impl IncomingPacketsError { +impl IncomingError { #[inline] fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { Self::Io(IoError::from(err)) } } -impl From for IncomingPacketsError { +impl From for IncomingError { #[inline] fn from(err: PacketBufferError) -> Self { match err { @@ -194,7 +197,7 @@ impl From for IncomingPacketsError { } } -impl From for IncomingPacketsError { +impl From for IncomingError { #[inline] fn from(err: TuicError) -> Self { match err { @@ -204,11 +207,11 @@ impl From for IncomingPacketsError { } } -impl From for IoError { +impl From for IoError { #[inline] - fn from(err: IncomingPacketsError) -> Self { + fn from(err: IncomingError) -> Self { match err { - IncomingPacketsError::Io(err) => Self::from(err), + IncomingError::Io(err) => Self::from(err), err => Self::new(ErrorKind::Other, err), } } diff --git a/src/client/mod.rs b/src/client/mod.rs index b91bff4..e85df97 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,11 +1,9 @@ mod connection; mod incoming; -mod stream; pub use self::{ connection::{Connecting, Connection, ConnectionError}, - incoming::{IncomingPackets, IncomingPacketsError}, - stream::Stream, + incoming::{Incoming, IncomingError}, }; use crate::{CongestionControl, UdpRelayMode}; @@ -118,7 +116,6 @@ impl Client { &self, addr: SocketAddr, server_name: &str, - token: [u8; 32], ) -> Result { let conn = self .endpoint @@ -127,7 +124,6 @@ impl Client { Ok(Connecting::new( conn, - token, self.enable_quic_0rtt, self.udp_relay_mode, )) diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..fb469ef --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1,16 @@ +pub(crate) mod stream; +pub mod task; +pub(crate) mod util; + +#[derive(Clone, Copy, Debug)] +pub enum CongestionControl { + Cubic, + NewReno, + Bbr, +} + +#[derive(Clone, Copy, Debug)] +pub enum UdpRelayMode { + Native, + Quic, +} diff --git a/src/client/stream.rs b/src/common/stream.rs similarity index 89% rename from src/client/stream.rs rename to src/common/stream.rs index 35fa9c6..2b68400 100644 --- a/src/client/stream.rs +++ b/src/common/stream.rs @@ -7,12 +7,12 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -pub(super) type StreamReg = Arc<()>; +pub(crate) type StreamReg = Arc<()>; -pub(super) struct SendStream(QuinnSendStream, StreamReg); +pub(crate) struct SendStream(QuinnSendStream, StreamReg); impl SendStream { - pub(super) fn new(send: QuinnSendStream, reg: StreamReg) -> Self { + pub(crate) fn new(send: QuinnSendStream, reg: StreamReg) -> Self { Self(send, reg) } @@ -23,10 +23,10 @@ impl SendStream { } } -pub(super) struct RecvStream(QuinnRecvStream, StreamReg); +pub(crate) struct RecvStream(QuinnRecvStream, StreamReg); impl RecvStream { - pub(super) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { + pub(crate) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { Self(recv, reg) } } @@ -34,7 +34,7 @@ impl RecvStream { pub struct Stream(SendStream, RecvStream); impl Stream { - pub(super) fn new(send: SendStream, recv: RecvStream) -> Self { + pub(crate) fn new(send: SendStream, recv: RecvStream) -> Self { Self(send, recv) } diff --git a/src/common/task.rs b/src/common/task.rs new file mode 100644 index 0000000..b726da4 --- /dev/null +++ b/src/common/task.rs @@ -0,0 +1,21 @@ +use crate::protocol::Address; +use bytes::Bytes; + +#[derive(Clone, Debug)] +pub struct Packet { + pub id: u16, + pub associate_id: u32, + pub address: Address, + pub data: Bytes, +} + +impl Packet { + pub(crate) fn new(assoc_id: u32, pkt_id: u16, addr: Address, pkt: Bytes) -> Self { + Self { + id: pkt_id, + associate_id: assoc_id, + address: addr, + data: pkt, + } + } +} diff --git a/src/common.rs b/src/common/util.rs similarity index 89% rename from src/common.rs rename to src/common/util.rs index 536592a..b136294 100644 --- a/src/common.rs +++ b/src/common/util.rs @@ -1,4 +1,7 @@ -use crate::protocol::{Address, Command}; +use crate::{ + protocol::{Address, Command}, + task::Packet, +}; use bytes::{Bytes, BytesMut}; use std::{ collections::{hash_map::Entry, HashMap}, @@ -6,38 +9,6 @@ use std::{ }; use thiserror::Error; -#[derive(Clone, Copy, Debug)] -pub enum CongestionControl { - Cubic, - NewReno, - Bbr, -} - -#[derive(Clone, Copy, Debug)] -pub enum UdpRelayMode { - Native, - Quic, -} - -#[derive(Clone, Debug)] -pub struct Packet { - pub id: u16, - pub associate_id: u32, - pub address: Address, - pub data: Bytes, -} - -impl Packet { - pub(crate) fn new(assoc_id: u32, pkt_id: u16, addr: Address, pkt: Bytes) -> Self { - Self { - id: pkt_id, - associate_id: assoc_id, - address: addr, - data: pkt, - } - } -} - #[derive(Debug)] pub(crate) struct PacketBuffer(HashMap); diff --git a/src/lib.rs b/src/lib.rs index 1370efc..baf714b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ pub mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{CongestionControl, Packet, UdpRelayMode}; +pub use crate::common::{stream::Stream, task, CongestionControl, UdpRelayMode}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig, ClientError}; diff --git a/src/server/connection.rs b/src/server/connection.rs index 810e86f..6a36996 100644 --- a/src/server/connection.rs +++ b/src/server/connection.rs @@ -1 +1,65 @@ -pub struct Connecting; +use super::IncomingTasks; +use crate::UdpRelayMode; +use crossbeam_utils::atomic::AtomicCell; +use quinn::{ + Connecting as QuinnConnecting, Connection as QuinnConnection, + ConnectionError as QuinnConnectionError, NewConnection as QuinnNewConnection, +}; +use std::{io::Error as IoError, sync::Arc}; +use thiserror::Error; + +pub struct Connecting { + conn: QuinnConnecting, +} + +impl Connecting { + pub(super) fn new(conn: QuinnConnecting) -> Self { + Self { conn } + } + + pub async fn establish(self) -> Result<(Connection, IncomingTasks), ConnectionError> { + let QuinnNewConnection { + connection, + datagrams, + uni_streams, + .. + } = self + .conn + .await + .map_err(ConnectionError::from_quinn_connection_error)?; + + let udp_relay_mode = Arc::new(AtomicCell::new(None)); + + let conn = Connection::new(connection, udp_relay_mode.clone()); + let incoming = IncomingTasks::new(uni_streams, datagrams, udp_relay_mode); + + Ok((conn, incoming)) + } +} + +pub struct Connection { + conn: QuinnConnection, + udp_relay_mode: Arc>>, +} + +impl Connection { + fn new(conn: QuinnConnection, udp_relay_mode: Arc>>) -> Self { + Self { + conn, + udp_relay_mode, + } + } +} + +#[derive(Error, Debug)] +pub enum ConnectionError { + #[error(transparent)] + Io(#[from] IoError), +} + +impl ConnectionError { + #[inline] + fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { + Self::Io(IoError::from(err)) + } +} diff --git a/src/server/incoming.rs b/src/server/incoming.rs index 8b13789..cd4aadf 100644 --- a/src/server/incoming.rs +++ b/src/server/incoming.rs @@ -1 +1,24 @@ +use crate::UdpRelayMode; +use crossbeam_utils::atomic::AtomicCell; +use quinn::{Datagrams, IncomingUniStreams}; +use std::sync::Arc; +pub struct IncomingTasks { + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + udp_relay_mode: Arc>>, +} + +impl IncomingTasks { + pub(super) fn new( + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + udp_relay_mode: Arc>>, + ) -> Self { + Self { + uni_streams, + datagrams, + udp_relay_mode, + } + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 4eed924..a044ef8 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,9 +1,13 @@ mod connection; mod incoming; -pub use self::connection::Connecting; +pub use self::{ + connection::{Connecting, Connection, ConnectionError}, + incoming::IncomingTasks, +}; use crate::CongestionControl; +use futures_util::StreamExt; use quinn::{ congestion::{BbrConfig, CubicConfig, NewRenoConfig}, Endpoint, EndpointConfig, IdleTimeout, Incoming, ServerConfig as QuinnServerConfig, VarInt, @@ -109,8 +113,8 @@ impl Server { Ok(()) } - pub async fn accept(&self) -> Connecting { - todo!() + pub async fn accept(&mut self) -> Option { + self.incoming.next().await.map(Connecting::new) } } From eeedd829dcc8ebdfab3d997d1493cec87bebcb07 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 14 Aug 2022 12:29:34 +0900 Subject: [PATCH 025/103] expose packet buffer gc handler --- Cargo.toml | 6 +- src/client/incoming.rs | 255 ++++++++++++++++++++++++----------------- src/common/mod.rs | 5 +- src/common/stream.rs | 3 + src/common/util.rs | 31 +++-- src/lib.rs | 4 +- 6 files changed, 188 insertions(+), 116 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fb33978..cc09329 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,7 @@ futures-util = { version = "0.3.*", default-features = false, optional = true } # getopts = "0.2.*" # log = { version = "0.4.*", features = ["serde", "std"] } # once_cell = { version = "1.13.*", features = ["parking_lot"] } -# parking_lot = "0.12.*" +parking_lot = { version = "0.12.*", default-features = false, optional = true } quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false, optional = true } # rand = "0.8.*" rustls = { version = "0.20.*", default-features = false, optional = true } @@ -50,8 +50,8 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["crossbeam-utils", "futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] -client = ["futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util"] +server = ["crossbeam-utils", "futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] +client = ["futures-util", "parking_lot", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/incoming.rs b/src/client/incoming.rs index c05c00e..e6907a0 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -5,15 +5,17 @@ use crate::{ }, protocol::{Command, Error as TuicError}, task::Packet, - UdpRelayMode, + PacketBufferGcHandle, UdpRelayMode, }; use bytes::Bytes; use futures_util::StreamExt; -use quinn::{ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams}; +use quinn::{ + ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, + RecvStream as QuinnRecvStream, +}; use std::{ io::{Error as IoError, ErrorKind}, sync::Arc, - time::{Duration, Instant}, }; use thiserror::Error; use tokio::io::AsyncReadExt; @@ -25,7 +27,6 @@ pub struct Incoming { udp_relay_mode: UdpRelayMode, stream_reg: Arc, pkt_buf: PacketBuffer, - last_gc_time: Instant, } impl Incoming { @@ -41,136 +42,184 @@ impl Incoming { udp_relay_mode, stream_reg, pkt_buf: PacketBuffer::new(), - last_gc_time: Instant::now(), } } - pub async fn accept( - &mut self, - gc_interval: Duration, - gc_timeout: Duration, - ) -> Option> { - match self.udp_relay_mode { - UdpRelayMode::Native => self.accept_from_datagrams(gc_interval, gc_timeout).await, - UdpRelayMode::Quic => self.accept_from_uni_streams().await, + pub async fn accept(&mut self) -> Option> { + let handle_raw_datagram = |dg: Result| { + dg.map(|dg| { + PendingTask::new( + TaskSource::Datagram(dg), + self.udp_relay_mode, + self.pkt_buf.clone(), + ) + }) + .map_err(IncomingError::from_quinn_connection_error) + }; + + let handle_raw_uni_stream = |recv: Result| { + recv.map(|recv| { + PendingTask::new( + TaskSource::RecvStream(RecvStream::new(recv, self.stream_reg.as_ref().clone())), + self.udp_relay_mode, + self.pkt_buf.clone(), + ) + }) + .map_err(IncomingError::from_quinn_connection_error) + }; + + tokio::select! { + dg = self.datagrams.next() => if let Some(dg) = dg { + Some(handle_raw_datagram(dg)) + } else { + self.uni_streams.next().await.map(handle_raw_uni_stream) + }, + recv = self.uni_streams.next() => if let Some(recv) = recv { + Some(handle_raw_uni_stream(recv)) + } else { + self.datagrams.next().await.map(handle_raw_datagram) + }, } } - async fn accept_from_datagrams( - &mut self, - gc_interval: Duration, - gc_timeout: Duration, - ) -> Option> { - #[inline] - async fn process_datagram( - pkt_buf: &mut PacketBuffer, - dg: Result, - ) -> Result, IncomingError> { - let dg = dg?; - let cmd = Command::read_from(&mut dg.as_ref()).await?; - let cmd_len = cmd.serialized_len(); + pub fn get_packet_buffer_gc_handler(&self) -> PacketBufferGcHandle { + self.pkt_buf.get_gc_handler() + } +} - match cmd { - Command::Packet { +#[derive(Debug)] +pub struct PendingTask { + source: TaskSource, + udp_relay_mode: UdpRelayMode, + pkt_buf: PacketBuffer, +} + +impl PendingTask { + fn new(source: TaskSource, udp_relay_mode: UdpRelayMode, pkt_buf: PacketBuffer) -> Self { + Self { + source, + udp_relay_mode, + pkt_buf, + } + } + + pub async fn parse(self) -> Result { + match self.source { + TaskSource::Datagram(dg) => { + Self::parse_datagram(dg, self.udp_relay_mode, self.pkt_buf).await + } + TaskSource::RecvStream(recv) => { + Self::parse_recv_stream(recv, self.udp_relay_mode).await + } + } + } + + #[inline] + async fn parse_datagram( + dg: Bytes, + udp_relay_mode: UdpRelayMode, + mut pkt_buf: PacketBuffer, + ) -> Result { + let cmd = Command::read_from(&mut dg.as_ref()).await?; + let cmd_len = cmd.serialized_len(); + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => { + if !matches!(udp_relay_mode, UdpRelayMode::Native) { + return Err(IncomingError::BadCommand(Command::TYPE_PACKET, "datagram")); + } + + if let Some(pkt) = pkt_buf.insert( assoc_id, pkt_id, frag_total, frag_id, - len, addr, - } => { - if let Some(pkt) = pkt_buf.insert( - assoc_id, - pkt_id, - frag_total, - frag_id, - addr, - dg.slice(cmd_len..cmd_len + len as usize), - )? { - Ok(Some(pkt)) - } else { - Ok(None) - } + dg.slice(cmd_len..cmd_len + len as usize), + )? { + Ok(Task::Packet(Some(pkt))) + } else { + Ok(Task::Packet(None)) } - cmd => Err(IncomingError::Tuic(TuicError::InvalidCommand( - cmd.as_type_code(), - ))), - } - } - - loop { - if self.last_gc_time.elapsed() > gc_interval { - self.pkt_buf.collect_garbage(gc_timeout); - self.last_gc_time = Instant::now(); - } - - if let Some(dg) = self.datagrams.next().await { - let dg = dg.map_err(IncomingError::from_quinn_connection_error); - match process_datagram(&mut self.pkt_buf, dg).await { - Ok(Some(pkt)) => break Some(Ok(pkt)), - Ok(None) => {} - Err(err) => break Some(Err(err)), - } - } else { - break None; } + cmd => Err(IncomingError::BadCommand(cmd.as_type_code(), "datagram")), } } - async fn accept_from_uni_streams(&mut self) -> Option> { - #[inline] - async fn process_uni_stream( - recv: Result, - ) -> Result { - let mut recv = recv?; - let cmd = Command::read_from(&mut recv).await?; + #[inline] + async fn parse_recv_stream( + mut recv: RecvStream, + udp_relay_mode: UdpRelayMode, + ) -> Result { + let cmd = Command::read_from(&mut recv).await?; - match cmd { - Command::Packet { + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => { + if !matches!(udp_relay_mode, UdpRelayMode::Quic) { + return Err(IncomingError::BadCommand( + Command::TYPE_PACKET, + "uni_stream", + )); + } + + if frag_id != 0 || frag_total != 1 { + return Err(IncomingError::BadFragment); + } + + if addr.is_none() { + return Err(IncomingError::NoAddress); + } + + let mut buf = vec![0; len as usize]; + recv.read_exact(&mut buf).await?; + let pkt = Bytes::from(buf); + + Ok(Task::Packet(Some(Packet::new( assoc_id, pkt_id, - frag_total, - frag_id, - len, - addr, - } => { - if frag_id != 0 || frag_total != 1 { - return Err(IncomingError::BadFragment); - } - - if addr.is_none() { - return Err(IncomingError::NoAddress); - } - - let mut buf = vec![0; len as usize]; - recv.read_exact(&mut buf).await?; - let pkt = Bytes::from(buf); - - Ok(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt)) - } - _ => Err(IncomingError::Tuic(TuicError::InvalidCommand( - cmd.as_type_code(), - ))), + addr.unwrap(), + pkt, + )))) } - } - - if let Some(recv) = self.uni_streams.next().await { - let recv = recv - .map(|recv| RecvStream::new(recv, self.stream_reg.as_ref().clone())) - .map_err(IncomingError::from_quinn_connection_error); - Some(process_uni_stream(recv).await) - } else { - None + _ => Err(IncomingError::BadCommand(cmd.as_type_code(), "uni_stream")), } } } +#[derive(Debug)] +#[non_exhaustive] +pub enum Task { + Packet(Option), +} + +#[derive(Debug)] +enum TaskSource { + Datagram(Bytes), + RecvStream(RecvStream), +} + #[derive(Error, Debug)] pub enum IncomingError { #[error(transparent)] Io(#[from] IoError), #[error(transparent)] Tuic(TuicError), + #[error("received bad command {0:#x} from {1}")] + BadCommand(u8, &'static str), #[error("received bad-fragmented packet")] BadFragment, #[error("missing address in packet with frag_id 0")] diff --git a/src/common/mod.rs b/src/common/mod.rs index fb469ef..178a421 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,7 +1,10 @@ pub(crate) mod stream; -pub mod task; pub(crate) mod util; +pub mod task; + +pub use self::util::PacketBufferGcHandle; + #[derive(Clone, Copy, Debug)] pub enum CongestionControl { Cubic, diff --git a/src/common/stream.rs b/src/common/stream.rs index 2b68400..9f103fa 100644 --- a/src/common/stream.rs +++ b/src/common/stream.rs @@ -9,6 +9,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub(crate) type StreamReg = Arc<()>; +#[derive(Debug)] pub(crate) struct SendStream(QuinnSendStream, StreamReg); impl SendStream { @@ -23,6 +24,7 @@ impl SendStream { } } +#[derive(Debug)] pub(crate) struct RecvStream(QuinnRecvStream, StreamReg); impl RecvStream { @@ -31,6 +33,7 @@ impl RecvStream { } } +#[derive(Debug)] pub struct Stream(SendStream, RecvStream); impl Stream { diff --git a/src/common/util.rs b/src/common/util.rs index b136294..5046bec 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -3,18 +3,20 @@ use crate::{ task::Packet, }; use bytes::{Bytes, BytesMut}; +use parking_lot::Mutex; use std::{ collections::{hash_map::Entry, HashMap}, + sync::Arc, time::{Duration, Instant}, }; use thiserror::Error; -#[derive(Debug)] -pub(crate) struct PacketBuffer(HashMap); +#[derive(Clone, Debug)] +pub(crate) struct PacketBuffer(Arc>>); impl PacketBuffer { pub(crate) fn new() -> Self { - Self(HashMap::new()) + Self(Arc::new(Mutex::new(HashMap::new()))) } pub(crate) fn insert( @@ -26,19 +28,20 @@ impl PacketBuffer { addr: Option
, pkt: Bytes, ) -> Result, PacketBufferError> { + let mut pkt_buf = self.0.lock(); let key = PacketBufferKey { assoc_id, pkt_id }; if frag_id == 0 && addr.is_none() { - self.0.remove(&key); + pkt_buf.remove(&key); return Err(PacketBufferError::NoAddress); } if frag_id != 0 && addr.is_some() { - self.0.remove(&key); + pkt_buf.remove(&key); return Err(PacketBufferError::UnexpectedAddress); } - match self.0.entry(key) { + match pkt_buf.entry(key) { Entry::Occupied(mut entry) => { let v = entry.get_mut(); @@ -99,8 +102,20 @@ impl PacketBuffer { } } - pub(crate) fn collect_garbage(&mut self, timeout: Duration) { - self.0.retain(|_, v| v.c_time.elapsed() < timeout); + pub(crate) fn get_gc_handler(&self) -> PacketBufferGcHandle { + PacketBufferGcHandle(self.clone()) + } + + fn collect_garbage(&self, timeout: Duration) { + self.0.lock().retain(|_, v| v.c_time.elapsed() < timeout); + } +} + +pub struct PacketBufferGcHandle(PacketBuffer); + +impl PacketBufferGcHandle { + pub fn collect_garbage(&self, timeout: Duration) { + self.0.collect_garbage(timeout) } } diff --git a/src/lib.rs b/src/lib.rs index baf714b..724215c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,9 @@ pub mod server; pub mod client; #[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{stream::Stream, task, CongestionControl, UdpRelayMode}; +pub use crate::common::{ + stream::Stream, task, CongestionControl, PacketBufferGcHandle, UdpRelayMode, +}; #[cfg(feature = "client")] pub use crate::client::{Client, ClientConfig, ClientError}; From e53235dd73f051f3511f617e73c2730d7d4d3dea Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 15 Aug 2022 12:48:43 +0900 Subject: [PATCH 026/103] import `futures` instead of `futures_util` --- Cargo.toml | 6 +++--- src/client/incoming.rs | 2 +- src/server/mod.rs | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index cc09329..67a58d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ byteorder = { version = "1.4.*", default-features = false, optional = true } bytes = { version = "1.2.*", default-features = false, optional = true } crossbeam-utils = { version = "0.8.*", default-features = false, optional = true } # env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } -futures-util = { version = "0.3.*", default-features = false, optional = true } +futures = { version = "0.3.*", default-features = false, optional = true } # getopts = "0.2.*" # log = { version = "0.4.*", features = ["serde", "std"] } # once_cell = { version = "1.13.*", features = ["parking_lot"] } @@ -50,8 +50,8 @@ all = ["protocol_marshaling", "server", "client"] protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] -server = ["crossbeam-utils", "futures-util", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] -client = ["futures-util", "parking_lot", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] +server = ["crossbeam-utils", "futures", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] +client = ["futures", "parking_lot", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] [dev-dependencies] tuic = { path = ".", features = ["all"] } diff --git a/src/client/incoming.rs b/src/client/incoming.rs index e6907a0..0d2eb02 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -8,7 +8,7 @@ use crate::{ PacketBufferGcHandle, UdpRelayMode, }; use bytes::Bytes; -use futures_util::StreamExt; +use futures::StreamExt; use quinn::{ ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, RecvStream as QuinnRecvStream, diff --git a/src/server/mod.rs b/src/server/mod.rs index a044ef8..f96f4f5 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -7,7 +7,7 @@ pub use self::{ }; use crate::CongestionControl; -use futures_util::StreamExt; +use futures::StreamExt; use quinn::{ congestion::{BbrConfig, CubicConfig, NewRenoConfig}, Endpoint, EndpointConfig, IdleTimeout, Incoming, ServerConfig as QuinnServerConfig, VarInt, From f0a0e887f533da641f78a5998d680eeab474bc7f Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 15 Aug 2022 13:45:05 +0900 Subject: [PATCH 027/103] generalizing task receiving --- src/common/incoming.rs | 96 ++++++++++++++++++++++++++++++++++++++++++ src/common/mod.rs | 1 + 2 files changed, 97 insertions(+) create mode 100644 src/common/incoming.rs diff --git a/src/common/incoming.rs b/src/common/incoming.rs new file mode 100644 index 0000000..c55d281 --- /dev/null +++ b/src/common/incoming.rs @@ -0,0 +1,96 @@ +use super::stream::{RecvStream, SendStream, Stream as BiStream, StreamReg}; +use bytes::Bytes; +use futures::{stream::SelectAll, Stream}; +use quinn::{ + ConnectionError as QuinnConnectionError, Datagrams, IncomingBiStreams, IncomingUniStreams, + RecvStream as QuinnRecvStream, SendStream as QuinnSendStream, +}; +use std::{ + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +pub(crate) struct IncomingTasks { + incoming: SelectAll, + stream_reg: Arc, +} + +impl IncomingTasks { + pub(crate) fn new( + bi_streams: IncomingBiStreams, + uni_streams: IncomingUniStreams, + datagrams: Datagrams, + stream_reg: Arc, + ) -> Self { + let mut incoming = SelectAll::new(); + + incoming.push(IncomingSource::BiStreams(bi_streams)); + incoming.push(IncomingSource::UniStreams(uni_streams)); + incoming.push(IncomingSource::Datagrams(datagrams)); + + Self { + incoming, + stream_reg, + } + } +} + +impl Stream for IncomingTasks { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.incoming) + .poll_next(cx) + .map_ok(|src| match src { + IncomingItem::BiStream((send, recv)) => PendingTask::BiStream(BiStream::new( + SendStream::new(send, self.stream_reg.as_ref().clone()), + RecvStream::new(recv, self.stream_reg.as_ref().clone()), + )), + IncomingItem::UniStream(recv) => { + PendingTask::UniStream(RecvStream::new(recv, self.stream_reg.as_ref().clone())) + } + IncomingItem::Datagram(datagram) => PendingTask::Datagram(datagram), + }) + } +} + +enum IncomingSource { + BiStreams(IncomingBiStreams), + UniStreams(IncomingUniStreams), + Datagrams(Datagrams), +} + +impl Stream for IncomingSource { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + IncomingSource::BiStreams(bi_streams) => Pin::new(bi_streams) + .poll_next(cx) + .map_ok(IncomingItem::BiStream), + IncomingSource::UniStreams(uni_streams) => Pin::new(uni_streams) + .poll_next(cx) + .map_ok(IncomingItem::UniStream), + IncomingSource::Datagrams(datagrams) => Pin::new(datagrams) + .poll_next(cx) + .map_ok(IncomingItem::Datagram), + } + } + + fn size_hint(&self) -> (usize, Option) { + (0, None) + } +} + +enum IncomingItem { + BiStream((QuinnSendStream, QuinnRecvStream)), + UniStream(QuinnRecvStream), + Datagram(Bytes), +} + +pub(crate) enum PendingTask { + BiStream(BiStream), + UniStream(RecvStream), + Datagram(Bytes), +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 178a421..a8c44c8 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod incoming; pub(crate) mod stream; pub(crate) mod util; From 3d1a0fb4008b9e4c131071aa02c2d9b0586c9f4e Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 15 Aug 2022 23:42:44 +0900 Subject: [PATCH 028/103] move task parsing to mod `common` --- src/common/incoming.rs | 57 ++++++++++++++++++++++++++++++++++-------- src/common/task.rs | 20 ++++++++++++++- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/src/common/incoming.rs b/src/common/incoming.rs index c55d281..7a94554 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -1,4 +1,8 @@ -use super::stream::{RecvStream, SendStream, Stream as BiStream, StreamReg}; +use super::{ + stream::{RecvStream, SendStream, Stream as BiStream, StreamReg}, + task::{RawTask, RawTaskPayload}, +}; +use crate::protocol::{Command, Error as ProtocalError}; use bytes::Bytes; use futures::{stream::SelectAll, Stream}; use quinn::{ @@ -6,17 +10,19 @@ use quinn::{ RecvStream as QuinnRecvStream, SendStream as QuinnSendStream, }; use std::{ + io::Error as IoError, pin::Pin, sync::Arc, task::{Context, Poll}, }; +use thiserror::Error; -pub(crate) struct IncomingTasks { +pub(crate) struct RawIncomingTasks { incoming: SelectAll, stream_reg: Arc, } -impl IncomingTasks { +impl RawIncomingTasks { pub(crate) fn new( bi_streams: IncomingBiStreams, uni_streams: IncomingUniStreams, @@ -36,21 +42,22 @@ impl IncomingTasks { } } -impl Stream for IncomingTasks { - type Item = Result; +impl Stream for RawIncomingTasks { + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.incoming) .poll_next(cx) .map_ok(|src| match src { - IncomingItem::BiStream((send, recv)) => PendingTask::BiStream(BiStream::new( + IncomingItem::BiStream((send, recv)) => RawPendingTask::BiStream(BiStream::new( SendStream::new(send, self.stream_reg.as_ref().clone()), RecvStream::new(recv, self.stream_reg.as_ref().clone()), )), - IncomingItem::UniStream(recv) => { - PendingTask::UniStream(RecvStream::new(recv, self.stream_reg.as_ref().clone())) - } - IncomingItem::Datagram(datagram) => PendingTask::Datagram(datagram), + IncomingItem::UniStream(recv) => RawPendingTask::UniStream(RecvStream::new( + recv, + self.stream_reg.as_ref().clone(), + )), + IncomingItem::Datagram(datagram) => RawPendingTask::Datagram(datagram), }) } } @@ -89,8 +96,36 @@ enum IncomingItem { Datagram(Bytes), } -pub(crate) enum PendingTask { +pub(crate) enum RawPendingTask { BiStream(BiStream), UniStream(RecvStream), Datagram(Bytes), } + +impl RawPendingTask { + pub(crate) async fn accept(self) -> RawTask { + match self { + RawPendingTask::BiStream(mut bi_stream) => RawTask::new( + Command::read_from(&mut bi_stream).await.unwrap(), + RawTaskPayload::BiStream(bi_stream), + ), + RawPendingTask::UniStream(mut uni_stream) => RawTask::new( + Command::read_from(&mut uni_stream).await.unwrap(), + RawTaskPayload::UniStream(uni_stream), + ), + RawPendingTask::Datagram(datagram) => { + let cmd = Command::read_from(&mut datagram.as_ref()).await.unwrap(); + let payload = datagram.slice(cmd.serialized_len()..); + RawTask::new(cmd, RawTaskPayload::Datagram(payload)) + } + } + } +} + +#[derive(Error, Debug)] +pub enum IncomingError { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Protocol(ProtocalError), +} diff --git a/src/common/task.rs b/src/common/task.rs index b726da4..ec91b59 100644 --- a/src/common/task.rs +++ b/src/common/task.rs @@ -1,4 +1,5 @@ -use crate::protocol::Address; +use super::stream::{RecvStream, Stream}; +use crate::protocol::{Address, Command}; use bytes::Bytes; #[derive(Clone, Debug)] @@ -19,3 +20,20 @@ impl Packet { } } } + +pub(crate) struct RawTask { + header: Command, + payload: RawTaskPayload, +} + +impl RawTask { + pub(crate) fn new(header: Command, payload: RawTaskPayload) -> Self { + Self { header, payload } + } +} + +pub(crate) enum RawTaskPayload { + BiStream(Stream), + UniStream(RecvStream), + Datagram(Bytes), +} From f4003deba01043621f5c989353f28f54cbff15f0 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 16 Aug 2022 00:05:19 +0900 Subject: [PATCH 029/103] spliting `protocol::Error` --- src/common/incoming.rs | 47 +++++++++++++++++--------------------- src/protocol/marshaling.rs | 47 ++++++++++++++++---------------------- src/protocol/mod.rs | 15 +++++++++++- 3 files changed, 55 insertions(+), 54 deletions(-) diff --git a/src/common/incoming.rs b/src/common/incoming.rs index 7a94554..73d753b 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -2,12 +2,12 @@ use super::{ stream::{RecvStream, SendStream, Stream as BiStream, StreamReg}, task::{RawTask, RawTaskPayload}, }; -use crate::protocol::{Command, Error as ProtocalError}; +use crate::protocol::{Command, MarshalingError}; use bytes::Bytes; use futures::{stream::SelectAll, Stream}; use quinn::{ - ConnectionError as QuinnConnectionError, Datagrams, IncomingBiStreams, IncomingUniStreams, - RecvStream as QuinnRecvStream, SendStream as QuinnSendStream, + Datagrams, IncomingBiStreams, IncomingUniStreams, RecvStream as QuinnRecvStream, + SendStream as QuinnSendStream, }; use std::{ io::Error as IoError, @@ -15,7 +15,6 @@ use std::{ sync::Arc, task::{Context, Poll}, }; -use thiserror::Error; pub(crate) struct RawIncomingTasks { incoming: SelectAll, @@ -43,7 +42,7 @@ impl RawIncomingTasks { } impl Stream for RawIncomingTasks { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.incoming) @@ -59,6 +58,7 @@ impl Stream for RawIncomingTasks { )), IncomingItem::Datagram(datagram) => RawPendingTask::Datagram(datagram), }) + .map_err(IoError::from) } } @@ -69,19 +69,22 @@ enum IncomingSource { } impl Stream for IncomingSource { - type Item = Result; + type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.get_mut() { IncomingSource::BiStreams(bi_streams) => Pin::new(bi_streams) .poll_next(cx) - .map_ok(IncomingItem::BiStream), + .map_ok(IncomingItem::BiStream) + .map_err(IoError::from), IncomingSource::UniStreams(uni_streams) => Pin::new(uni_streams) .poll_next(cx) - .map_ok(IncomingItem::UniStream), + .map_ok(IncomingItem::UniStream) + .map_err(IoError::from), IncomingSource::Datagrams(datagrams) => Pin::new(datagrams) .poll_next(cx) - .map_ok(IncomingItem::Datagram), + .map_ok(IncomingItem::Datagram) + .map_err(IoError::from), } } @@ -103,29 +106,21 @@ pub(crate) enum RawPendingTask { } impl RawPendingTask { - pub(crate) async fn accept(self) -> RawTask { + pub(crate) async fn accept(self) -> Result { match self { - RawPendingTask::BiStream(mut bi_stream) => RawTask::new( - Command::read_from(&mut bi_stream).await.unwrap(), + RawPendingTask::BiStream(mut bi_stream) => Ok(RawTask::new( + Command::read_from(&mut bi_stream).await?, RawTaskPayload::BiStream(bi_stream), - ), - RawPendingTask::UniStream(mut uni_stream) => RawTask::new( - Command::read_from(&mut uni_stream).await.unwrap(), + )), + RawPendingTask::UniStream(mut uni_stream) => Ok(RawTask::new( + Command::read_from(&mut uni_stream).await?, RawTaskPayload::UniStream(uni_stream), - ), + )), RawPendingTask::Datagram(datagram) => { - let cmd = Command::read_from(&mut datagram.as_ref()).await.unwrap(); + let cmd = Command::read_from(&mut datagram.as_ref()).await?; let payload = datagram.slice(cmd.serialized_len()..); - RawTask::new(cmd, RawTaskPayload::Datagram(payload)) + Ok(RawTask::new(cmd, RawTaskPayload::Datagram(payload))) } } } } - -#[derive(Error, Debug)] -pub enum IncomingError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Protocol(ProtocalError), -} diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs index c86b12e..8151a01 100644 --- a/src/protocol/marshaling.rs +++ b/src/protocol/marshaling.rs @@ -1,4 +1,4 @@ -use super::{Address, Command, TUIC_PROTOCOL_VERSION}; +use super::{Address, Command, ProtocolError, TUIC_PROTOCOL_VERSION}; use byteorder::{BigEndian, ReadBytesExt}; use bytes::BufMut; use std::{ @@ -10,14 +10,16 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; impl Command { - pub async fn read_from(r: &mut R) -> Result + pub async fn read_from(r: &mut R) -> Result where R: AsyncRead + Unpin, { let ver = r.read_u8().await?; if ver != TUIC_PROTOCOL_VERSION { - return Err(Error::UnsupportedVersion(ver)); + return Err(MarshalingError::from(ProtocolError::UnsupportedVersion( + ver, + ))); } let cmd = r.read_u8().await?; @@ -28,7 +30,7 @@ impl Command { match resp { Self::RESPONSE_SUCCEEDED => Ok(Self::Response(true)), Self::RESPONSE_FAILED => Ok(Self::Response(false)), - _ => Err(Error::InvalidResponse(resp)), + _ => Err(MarshalingError::from(ProtocolError::InvalidResponse(resp))), } } Self::TYPE_AUTHENTICATE => { @@ -71,11 +73,11 @@ impl Command { Ok(Self::Dissociate { assoc_id }) } Self::TYPE_HEARTBEAT => Ok(Self::Heartbeat), - _ => Err(Error::InvalidCommand(cmd)), + _ => Err(MarshalingError::from(ProtocolError::InvalidCommand(cmd))), } } - pub async fn write_to(&self, w: &mut W) -> Result<(), Error> + pub async fn write_to(&self, w: &mut W) -> Result<(), MarshalingError> where W: AsyncWrite + Unpin, { @@ -136,7 +138,7 @@ impl Command { } impl Address { - pub async fn read_from(stream: &mut R) -> Result + pub async fn read_from(stream: &mut R) -> Result where R: AsyncRead + Unpin, { @@ -192,11 +194,13 @@ impl Address { Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) } - _ => Err(Error::InvalidAddressType(addr_type)), + _ => Err(MarshalingError::from(ProtocolError::InvalidAddressType( + addr_type, + ))), } } - pub async fn write_to(&self, writer: &mut W) -> Result<(), Error> + pub async fn write_to(&self, writer: &mut W) -> Result<(), MarshalingError> where W: AsyncWrite + Unpin, { @@ -233,28 +237,17 @@ impl Address { } #[derive(Error, Debug)] -pub enum Error { +pub enum MarshalingError { #[error(transparent)] Io(#[from] IoError), - #[error("unsupported TUIC version: {0:#x}")] - UnsupportedVersion(u8), - #[error("invalid response: {0:#x}")] - InvalidResponse(u8), - #[error("invalid command: {0:#x}")] - InvalidCommand(u8), - #[error("invalid address type: {0:#x}")] - InvalidAddressType(u8), + #[error(transparent)] + Protocol(#[from] ProtocolError), #[error("invalid address encoding: {0}")] - InvalidAddressEncoding(#[from] FromUtf8Error), + InvalidEncoding(#[from] FromUtf8Error), } -impl From for IoError { - fn from(err: Error) -> Self { - let kind = match err { - Error::Io(err) => return err, - _ => IoErrorKind::Other, - }; - - Self::new(kind, err) +impl From for IoError { + fn from(err: MarshalingError) -> Self { + Self::new(IoErrorKind::Other, err) } } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index df1323f..edb3299 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -7,11 +7,12 @@ use std::{ fmt::{Display, Formatter, Result as FmtResult}, net::SocketAddr, }; +use thiserror::Error; pub const TUIC_PROTOCOL_VERSION: u8 = 0x05; #[cfg(feature = "protocol_marshaling")] -pub use self::marshaling::Error; +pub use self::marshaling::MarshalingError; /// Command /// @@ -157,3 +158,15 @@ impl Display for Address { } } } + +#[derive(Error, Debug)] +pub enum ProtocolError { + #[error("unsupported TUIC version: {0:#x}")] + UnsupportedVersion(u8), + #[error("invalid command: {0:#x}")] + InvalidCommand(u8), + #[error("invalid response: {0:#x}")] + InvalidResponse(u8), + #[error("invalid address type: {0:#x}")] + InvalidAddressType(u8), +} From ab1d0ea5993a44acb4057602f33cd3bf09285918 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 16 Aug 2022 01:50:18 +0900 Subject: [PATCH 030/103] refactoring client `IncomingTasks` --- src/client/connection.rs | 105 +++++++-------- src/client/incoming.rs | 266 ++++++------------------------------- src/client/mod.rs | 2 +- src/common/incoming.rs | 46 +------ src/common/mod.rs | 5 +- src/common/stream.rs | 4 +- src/common/task.rs | 28 +++- src/common/util.rs | 2 +- src/lib.rs | 5 +- src/protocol/marshaling.rs | 22 ++- src/protocol/mod.rs | 17 +-- 11 files changed, 141 insertions(+), 361 deletions(-) diff --git a/src/client/connection.rs b/src/client/connection.rs index 0692ea6..56c1aec 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -1,20 +1,20 @@ -use super::Incoming; +use super::IncomingTasks; use crate::{ common::{ stream::{RecvStream, SendStream, StreamReg}, util, }, - protocol::{Address, Command, Error as TuicError}, + protocol::{Address, Command, MarshalingError, ProtocolError}, Stream, UdpRelayMode, }; use bytes::{Bytes, BytesMut}; use quinn::{ Connecting as QuinnConnecting, Connection as QuinnConnection, - ConnectionError as QuinnConnectionError, NewConnection as QuinnNewConnection, - SendDatagramError as QuinnSendDatagramError, + NewConnection as QuinnNewConnection, SendDatagramError as QuinnSendDatagramError, }; use std::{ - io::{Error as IoError, ErrorKind}, + io::Error as IoError, + string::FromUtf8Error, sync::{ atomic::{AtomicU16, Ordering}, Arc, @@ -23,7 +23,6 @@ use std::{ use thiserror::Error; use tokio::io::AsyncWriteExt; -#[derive(Debug)] pub struct Connecting { conn: QuinnConnecting, enable_quic_0rtt: bool, @@ -43,40 +42,47 @@ impl Connecting { } } - pub async fn establish(self) -> Result<(Connection, Incoming), ConnectionError> { + pub async fn establish( + self, + ) -> Result, Self> { let QuinnNewConnection { connection, - datagrams, + bi_streams, uni_streams, + datagrams, .. } = if self.enable_quic_0rtt { match self.conn.into_0rtt() { Ok((conn, _)) => conn, Err(conn) => { - return Err(ConnectionError::Convert0Rtt(Connecting::new( + return Err(Self { conn, - false, - self.udp_relay_mode, - ))) + enable_quic_0rtt: false, + udp_relay_mode: self.udp_relay_mode, + }); } } } else { - self.conn - .await - .map_err(ConnectionError::from_quinn_connection_error)? + match self.conn.await { + Ok(conn) => conn, + Err(err) => return Ok(Err(ConnectionError::from(IoError::from(err)))), + } }; let stream_reg = Arc::new(Arc::new(())); - let conn = Connection::new(connection, self.udp_relay_mode, stream_reg.clone()); + let incoming = IncomingTasks::new( + bi_streams, + uni_streams, + datagrams, + self.udp_relay_mode, + stream_reg, + ); - let incoming = Incoming::new(uni_streams, datagrams, self.udp_relay_mode, stream_reg); - - Ok((conn, incoming)) + Ok(Ok((conn, incoming))) } } -#[derive(Debug)] pub struct Connection { conn: QuinnConnection, udp_relay_mode: UdpRelayMode, @@ -121,15 +127,15 @@ impl Connection { cmd.write_to(&mut stream).await?; let resp = match Command::read_from(&mut stream).await { - Ok(Command::Response(resp)) => Ok(resp), - Ok(cmd) => Err(TuicError::InvalidCommand(cmd.as_type_code())), - Err(err) => Err(err), + Ok(Command::Respond(resp)) => Ok(resp), + Ok(cmd) => Err(ConnectionError::ShouldBeRespond(cmd)), + Err(err) => Err(ConnectionError::from(err)), }; let res = match resp { Ok(true) => return Ok(Some(stream)), Ok(false) => Ok(None), - Err(err) => Err(ConnectionError::from(err)), + Err(err) => Err(err), }; stream.finish().await?; @@ -157,21 +163,13 @@ impl Connection { } async fn get_send_stream(&self) -> Result { - let send = self - .conn - .open_uni() - .await - .map_err(ConnectionError::from_quinn_connection_error)?; + let send = self.conn.open_uni().await.map_err(IoError::from)?; Ok(SendStream::new(send, self.stream_reg.as_ref().clone())) } async fn get_bi_stream(&self) -> Result { - let (send, recv) = self - .conn - .open_bi() - .await - .map_err(ConnectionError::from_quinn_connection_error)?; + let (send, recv) = self.conn.open_bi().await.map_err(IoError::from)?; let send = SendStream::new(send, self.stream_reg.as_ref().clone()); let recv = RecvStream::new(recv, self.stream_reg.as_ref().clone()); @@ -264,12 +262,20 @@ impl Connection { #[derive(Error, Debug)] pub enum ConnectionError { - #[error("failed to convert QUIC connection into 0-RTT")] - Convert0Rtt(Connecting), #[error(transparent)] Io(#[from] IoError), #[error(transparent)] - Tuic(TuicError), + Protocol(#[from] ProtocolError), + #[error("invalid address encoding: {0}")] + InvalidEncoding(#[from] FromUtf8Error), + #[error("expecting a `Respond`, got a command")] + ShouldBeRespond(Command), + #[error("unexpected incoming bi_stream")] + UnexpectedIncomingBiStream(Stream), + #[error("unexpected incoming uni_stream")] + UnexpectedIncomingUniStream(RecvStream), + #[error("unexpected incoming datagram")] + UnexpectedIncomingDatagram(Bytes), #[error("datagrams not supported by peer")] DatagramUnsupportedByPeer, #[error("datagram support disabled")] @@ -279,11 +285,6 @@ pub enum ConnectionError { } impl ConnectionError { - #[inline] - fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { - Self::Io(IoError::from(err)) - } - #[inline] fn from_quinn_send_datagram_error(err: QuinnSendDatagramError) -> Self { match err { @@ -295,22 +296,12 @@ impl ConnectionError { } } -impl From for ConnectionError { - #[inline] - fn from(err: TuicError) -> Self { +impl From for ConnectionError { + fn from(err: MarshalingError) -> Self { match err { - TuicError::Io(err) => Self::Io(err), - err => Self::Tuic(err), - } - } -} - -impl From for IoError { - #[inline] - fn from(err: ConnectionError) -> Self { - match err { - ConnectionError::Io(err) => Self::from(err), - err => Self::new(ErrorKind::Other, err), + MarshalingError::Io(err) => Self::Io(err), + MarshalingError::Protocol(err) => Self::Protocol(err), + MarshalingError::InvalidEncoding(err) => Self::InvalidEncoding(err), } } } diff --git a/src/client/incoming.rs b/src/client/incoming.rs index 0d2eb02..6d6a0e2 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -1,202 +1,74 @@ +use super::ConnectionError; use crate::{ - common::{ - stream::{RecvStream, StreamReg}, - util::{PacketBuffer, PacketBufferError}, - }, - protocol::{Command, Error as TuicError}, - task::Packet, - PacketBufferGcHandle, UdpRelayMode, -}; -use bytes::Bytes; -use futures::StreamExt; -use quinn::{ - ConnectionError as QuinnConnectionError, Datagrams, IncomingUniStreams, - RecvStream as QuinnRecvStream, + common::{incoming::RawIncomingTasks, stream::StreamReg, task::TaskSource, util::PacketBuffer}, + Packet, PacketBufferGcHandle, UdpRelayMode, }; +use futures::Stream; +use quinn::{Datagrams, IncomingBiStreams, IncomingUniStreams}; use std::{ - io::{Error as IoError, ErrorKind}, + pin::Pin, sync::Arc, + task::{Context, Poll}, }; -use thiserror::Error; -use tokio::io::AsyncReadExt; -#[derive(Debug)] -pub struct Incoming { - uni_streams: IncomingUniStreams, - datagrams: Datagrams, +pub struct IncomingTasks { + inner: RawIncomingTasks, udp_relay_mode: UdpRelayMode, - stream_reg: Arc, pkt_buf: PacketBuffer, } -impl Incoming { +impl IncomingTasks { pub(super) fn new( + bi_streams: IncomingBiStreams, uni_streams: IncomingUniStreams, datagrams: Datagrams, udp_relay_mode: UdpRelayMode, stream_reg: Arc, ) -> Self { Self { - uni_streams, - datagrams, + inner: RawIncomingTasks::new(bi_streams, uni_streams, datagrams, stream_reg), udp_relay_mode, - stream_reg, pkt_buf: PacketBuffer::new(), } } - pub async fn accept(&mut self) -> Option> { - let handle_raw_datagram = |dg: Result| { - dg.map(|dg| { - PendingTask::new( - TaskSource::Datagram(dg), - self.udp_relay_mode, - self.pkt_buf.clone(), - ) - }) - .map_err(IncomingError::from_quinn_connection_error) - }; - - let handle_raw_uni_stream = |recv: Result| { - recv.map(|recv| { - PendingTask::new( - TaskSource::RecvStream(RecvStream::new(recv, self.stream_reg.as_ref().clone())), - self.udp_relay_mode, - self.pkt_buf.clone(), - ) - }) - .map_err(IncomingError::from_quinn_connection_error) - }; - - tokio::select! { - dg = self.datagrams.next() => if let Some(dg) = dg { - Some(handle_raw_datagram(dg)) - } else { - self.uni_streams.next().await.map(handle_raw_uni_stream) - }, - recv = self.uni_streams.next() => if let Some(recv) = recv { - Some(handle_raw_uni_stream(recv)) - } else { - self.datagrams.next().await.map(handle_raw_datagram) - }, - } - } - pub fn get_packet_buffer_gc_handler(&self) -> PacketBufferGcHandle { self.pkt_buf.get_gc_handler() } } -#[derive(Debug)] +impl Stream for IncomingTasks { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_next(cx).map(|poll| { + poll.map(|res| match res { + Ok(source) => match (source, self.udp_relay_mode) { + (TaskSource::BiStream(stream), _) => { + Err(ConnectionError::UnexpectedIncomingBiStream(stream)) + } + (TaskSource::UniStream(stream), UdpRelayMode::Native) => { + Err(ConnectionError::UnexpectedIncomingUniStream(stream)) + } + (TaskSource::Datagram(datagram), UdpRelayMode::Quic) => { + Err(ConnectionError::UnexpectedIncomingDatagram(datagram)) + } + (source, _) => Ok(PendingTask::new(source, self.pkt_buf.clone())), + }, + Err(err) => Err(ConnectionError::from(err)), + }) + }) + } +} + pub struct PendingTask { - source: TaskSource, - udp_relay_mode: UdpRelayMode, + inner: TaskSource, pkt_buf: PacketBuffer, } impl PendingTask { - fn new(source: TaskSource, udp_relay_mode: UdpRelayMode, pkt_buf: PacketBuffer) -> Self { - Self { - source, - udp_relay_mode, - pkt_buf, - } - } - - pub async fn parse(self) -> Result { - match self.source { - TaskSource::Datagram(dg) => { - Self::parse_datagram(dg, self.udp_relay_mode, self.pkt_buf).await - } - TaskSource::RecvStream(recv) => { - Self::parse_recv_stream(recv, self.udp_relay_mode).await - } - } - } - - #[inline] - async fn parse_datagram( - dg: Bytes, - udp_relay_mode: UdpRelayMode, - mut pkt_buf: PacketBuffer, - ) -> Result { - let cmd = Command::read_from(&mut dg.as_ref()).await?; - let cmd_len = cmd.serialized_len(); - - match cmd { - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => { - if !matches!(udp_relay_mode, UdpRelayMode::Native) { - return Err(IncomingError::BadCommand(Command::TYPE_PACKET, "datagram")); - } - - if let Some(pkt) = pkt_buf.insert( - assoc_id, - pkt_id, - frag_total, - frag_id, - addr, - dg.slice(cmd_len..cmd_len + len as usize), - )? { - Ok(Task::Packet(Some(pkt))) - } else { - Ok(Task::Packet(None)) - } - } - cmd => Err(IncomingError::BadCommand(cmd.as_type_code(), "datagram")), - } - } - - #[inline] - async fn parse_recv_stream( - mut recv: RecvStream, - udp_relay_mode: UdpRelayMode, - ) -> Result { - let cmd = Command::read_from(&mut recv).await?; - - match cmd { - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => { - if !matches!(udp_relay_mode, UdpRelayMode::Quic) { - return Err(IncomingError::BadCommand( - Command::TYPE_PACKET, - "uni_stream", - )); - } - - if frag_id != 0 || frag_total != 1 { - return Err(IncomingError::BadFragment); - } - - if addr.is_none() { - return Err(IncomingError::NoAddress); - } - - let mut buf = vec![0; len as usize]; - recv.read_exact(&mut buf).await?; - let pkt = Bytes::from(buf); - - Ok(Task::Packet(Some(Packet::new( - assoc_id, - pkt_id, - addr.unwrap(), - pkt, - )))) - } - _ => Err(IncomingError::BadCommand(cmd.as_type_code(), "uni_stream")), - } + fn new(inner: TaskSource, pkt_buf: PacketBuffer) -> Self { + Self { inner, pkt_buf } } } @@ -205,63 +77,3 @@ impl PendingTask { pub enum Task { Packet(Option), } - -#[derive(Debug)] -enum TaskSource { - Datagram(Bytes), - RecvStream(RecvStream), -} - -#[derive(Error, Debug)] -pub enum IncomingError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Tuic(TuicError), - #[error("received bad command {0:#x} from {1}")] - BadCommand(u8, &'static str), - #[error("received bad-fragmented packet")] - BadFragment, - #[error("missing address in packet with frag_id 0")] - NoAddress, - #[error("unexpected address in packet")] - UnexpectedAddress, -} - -impl IncomingError { - #[inline] - fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { - Self::Io(IoError::from(err)) - } -} - -impl From for IncomingError { - #[inline] - fn from(err: PacketBufferError) -> Self { - match err { - PacketBufferError::NoAddress => Self::NoAddress, - PacketBufferError::UnexpectedAddress => Self::UnexpectedAddress, - PacketBufferError::BadFragment => Self::BadFragment, - } - } -} - -impl From for IncomingError { - #[inline] - fn from(err: TuicError) -> Self { - match err { - TuicError::Io(err) => Self::Io(err), - err => Self::Tuic(err), - } - } -} - -impl From for IoError { - #[inline] - fn from(err: IncomingError) -> Self { - match err { - IncomingError::Io(err) => Self::from(err), - err => Self::new(ErrorKind::Other, err), - } - } -} diff --git a/src/client/mod.rs b/src/client/mod.rs index e85df97..a6acbdd 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,7 +3,7 @@ mod incoming; pub use self::{ connection::{Connecting, Connection, ConnectionError}, - incoming::{Incoming, IncomingError}, + incoming::IncomingTasks, }; use crate::{CongestionControl, UdpRelayMode}; diff --git a/src/common/incoming.rs b/src/common/incoming.rs index 73d753b..ac7c5d3 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -1,8 +1,7 @@ use super::{ stream::{RecvStream, SendStream, Stream as BiStream, StreamReg}, - task::{RawTask, RawTaskPayload}, + task::TaskSource, }; -use crate::protocol::{Command, MarshalingError}; use bytes::Bytes; use futures::{stream::SelectAll, Stream}; use quinn::{ @@ -42,21 +41,20 @@ impl RawIncomingTasks { } impl Stream for RawIncomingTasks { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.incoming) .poll_next(cx) .map_ok(|src| match src { - IncomingItem::BiStream((send, recv)) => RawPendingTask::BiStream(BiStream::new( + IncomingItem::BiStream((send, recv)) => TaskSource::BiStream(BiStream::new( SendStream::new(send, self.stream_reg.as_ref().clone()), RecvStream::new(recv, self.stream_reg.as_ref().clone()), )), - IncomingItem::UniStream(recv) => RawPendingTask::UniStream(RecvStream::new( - recv, - self.stream_reg.as_ref().clone(), - )), - IncomingItem::Datagram(datagram) => RawPendingTask::Datagram(datagram), + IncomingItem::UniStream(recv) => { + TaskSource::UniStream(RecvStream::new(recv, self.stream_reg.as_ref().clone())) + } + IncomingItem::Datagram(datagram) => TaskSource::Datagram(datagram), }) .map_err(IoError::from) } @@ -87,10 +85,6 @@ impl Stream for IncomingSource { .map_err(IoError::from), } } - - fn size_hint(&self) -> (usize, Option) { - (0, None) - } } enum IncomingItem { @@ -98,29 +92,3 @@ enum IncomingItem { UniStream(QuinnRecvStream), Datagram(Bytes), } - -pub(crate) enum RawPendingTask { - BiStream(BiStream), - UniStream(RecvStream), - Datagram(Bytes), -} - -impl RawPendingTask { - pub(crate) async fn accept(self) -> Result { - match self { - RawPendingTask::BiStream(mut bi_stream) => Ok(RawTask::new( - Command::read_from(&mut bi_stream).await?, - RawTaskPayload::BiStream(bi_stream), - )), - RawPendingTask::UniStream(mut uni_stream) => Ok(RawTask::new( - Command::read_from(&mut uni_stream).await?, - RawTaskPayload::UniStream(uni_stream), - )), - RawPendingTask::Datagram(datagram) => { - let cmd = Command::read_from(&mut datagram.as_ref()).await?; - let payload = datagram.slice(cmd.serialized_len()..); - Ok(RawTask::new(cmd, RawTaskPayload::Datagram(payload))) - } - } - } -} diff --git a/src/common/mod.rs b/src/common/mod.rs index a8c44c8..9d2dcd3 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,11 +1,8 @@ pub(crate) mod incoming; pub(crate) mod stream; +pub(crate) mod task; pub(crate) mod util; -pub mod task; - -pub use self::util::PacketBufferGcHandle; - #[derive(Clone, Copy, Debug)] pub enum CongestionControl { Cubic, diff --git a/src/common/stream.rs b/src/common/stream.rs index 9f103fa..8298cda 100644 --- a/src/common/stream.rs +++ b/src/common/stream.rs @@ -10,7 +10,7 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; pub(crate) type StreamReg = Arc<()>; #[derive(Debug)] -pub(crate) struct SendStream(QuinnSendStream, StreamReg); +pub struct SendStream(QuinnSendStream, StreamReg); impl SendStream { pub(crate) fn new(send: QuinnSendStream, reg: StreamReg) -> Self { @@ -25,7 +25,7 @@ impl SendStream { } #[derive(Debug)] -pub(crate) struct RecvStream(QuinnRecvStream, StreamReg); +pub struct RecvStream(QuinnRecvStream, StreamReg); impl RecvStream { pub(crate) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { diff --git a/src/common/task.rs b/src/common/task.rs index ec91b59..ba96855 100644 --- a/src/common/task.rs +++ b/src/common/task.rs @@ -1,5 +1,5 @@ use super::stream::{RecvStream, Stream}; -use crate::protocol::{Address, Command}; +use crate::protocol::{Address, Command, MarshalingError}; use bytes::Bytes; #[derive(Clone, Debug)] @@ -21,6 +21,32 @@ impl Packet { } } +pub(crate) enum TaskSource { + BiStream(Stream), + UniStream(RecvStream), + Datagram(Bytes), +} + +impl TaskSource { + pub(crate) async fn accept(self) -> Result { + match self { + TaskSource::BiStream(mut bi_stream) => Ok(RawTask::new( + Command::read_from(&mut bi_stream).await?, + RawTaskPayload::BiStream(bi_stream), + )), + TaskSource::UniStream(mut uni_stream) => Ok(RawTask::new( + Command::read_from(&mut uni_stream).await?, + RawTaskPayload::UniStream(uni_stream), + )), + TaskSource::Datagram(datagram) => { + let cmd = Command::read_from(&mut datagram.as_ref()).await?; + let payload = datagram.slice(cmd.serialized_len()..); + Ok(RawTask::new(cmd, RawTaskPayload::Datagram(payload))) + } + } + } +} + pub(crate) struct RawTask { header: Command, payload: RawTaskPayload, diff --git a/src/common/util.rs b/src/common/util.rs index 5046bec..f67a00b 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -1,6 +1,6 @@ use crate::{ protocol::{Address, Command}, - task::Packet, + Packet, }; use bytes::{Bytes, BytesMut}; use parking_lot::Mutex; diff --git a/src/lib.rs b/src/lib.rs index 724215c..9bb9898 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,10 @@ pub mod client; #[cfg(any(feature = "server", feature = "client"))] pub use crate::common::{ - stream::Stream, task, CongestionControl, PacketBufferGcHandle, UdpRelayMode, + stream::{RecvStream, SendStream, Stream}, + task::Packet, + util::PacketBufferGcHandle, + CongestionControl, UdpRelayMode, }; #[cfg(feature = "client")] diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs index 8151a01..bd39c28 100644 --- a/src/protocol/marshaling.rs +++ b/src/protocol/marshaling.rs @@ -2,7 +2,7 @@ use super::{Address, Command, ProtocolError, TUIC_PROTOCOL_VERSION}; use byteorder::{BigEndian, ReadBytesExt}; use bytes::BufMut; use std::{ - io::{Cursor, Error as IoError, ErrorKind as IoErrorKind}, + io::{Cursor, Error as IoError}, net::{Ipv4Addr, Ipv6Addr, SocketAddr}, string::FromUtf8Error, }; @@ -25,11 +25,11 @@ impl Command { let cmd = r.read_u8().await?; match cmd { - Self::TYPE_RESPONSE => { + Self::TYPE_RESPOND => { let resp = r.read_u8().await?; match resp { - Self::RESPONSE_SUCCEEDED => Ok(Self::Response(true)), - Self::RESPONSE_FAILED => Ok(Self::Response(false)), + Self::RESPONSE_SUCCEEDED => Ok(Self::Respond(true)), + Self::RESPONSE_FAILED => Ok(Self::Respond(false)), _ => Err(MarshalingError::from(ProtocolError::InvalidResponse(resp))), } } @@ -77,7 +77,7 @@ impl Command { } } - pub async fn write_to(&self, w: &mut W) -> Result<(), MarshalingError> + pub async fn write_to(&self, w: &mut W) -> Result<(), IoError> where W: AsyncWrite + Unpin, { @@ -91,8 +91,8 @@ impl Command { buf.put_u8(TUIC_PROTOCOL_VERSION); match self { - Self::Response(is_succeeded) => { - buf.put_u8(Self::TYPE_RESPONSE); + Self::Respond(is_succeeded) => { + buf.put_u8(Self::TYPE_RESPOND); if *is_succeeded { buf.put_u8(Self::RESPONSE_SUCCEEDED); } else { @@ -200,7 +200,7 @@ impl Address { } } - pub async fn write_to(&self, writer: &mut W) -> Result<(), MarshalingError> + pub async fn write_to(&self, writer: &mut W) -> Result<(), IoError> where W: AsyncWrite + Unpin, { @@ -245,9 +245,3 @@ pub enum MarshalingError { #[error("invalid address encoding: {0}")] InvalidEncoding(#[from] FromUtf8Error), } - -impl From for IoError { - fn from(err: MarshalingError) -> Self { - Self::new(IoErrorKind::Other, err) - } -} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index edb3299..e179bf1 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -31,7 +31,7 @@ pub enum Command { // +-----+ // | 1 | // +-----+ - Response(bool), + Respond(bool), // +-----+ // | TKN | @@ -81,7 +81,7 @@ pub enum Command { } impl Command { - pub const TYPE_RESPONSE: u8 = 0xff; + pub const TYPE_RESPOND: u8 = 0xff; pub const TYPE_AUTHENTICATE: u8 = 0x00; pub const TYPE_CONNECT: u8 = 0x01; pub const TYPE_PACKET: u8 = 0x02; @@ -91,20 +91,9 @@ impl Command { pub const RESPONSE_SUCCEEDED: u8 = 0x00; pub const RESPONSE_FAILED: u8 = 0xff; - pub const fn as_type_code(&self) -> u8 { - match self { - Command::Response(_) => Self::TYPE_RESPONSE, - Command::Authenticate(_) => Self::TYPE_AUTHENTICATE, - Command::Connect { .. } => Self::TYPE_CONNECT, - Command::Packet { .. } => Self::TYPE_PACKET, - Command::Dissociate { .. } => Self::TYPE_DISSOCIATE, - Command::Heartbeat => Self::TYPE_HEARTBEAT, - } - } - pub fn serialized_len(&self) -> usize { 2 + match self { - Self::Response(_) => 1, + Self::Respond(_) => 1, Self::Authenticate { .. } => 32, Self::Connect { addr } => addr.serialized_len(), Self::Packet { addr, .. } => 10 + addr.as_ref().map_or(0, |addr| addr.serialized_len()), From e3b6a318d69248f78da800ed6fbe8baf05e42aa8 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 28 Aug 2022 13:11:01 +0900 Subject: [PATCH 031/103] moving `IncomingError` to mod `common` --- src/client/connection.rs | 14 +--- src/client/incoming.rs | 41 +++++----- src/client/mod.rs | 6 +- src/common/incoming.rs | 173 ++++++++++++++++++++++++++++++++++++--- src/common/mod.rs | 1 - src/common/stream.rs | 53 ++++++------ src/common/task.rs | 65 --------------- src/common/util.rs | 16 +--- src/lib.rs | 3 +- 9 files changed, 220 insertions(+), 152 deletions(-) delete mode 100644 src/common/task.rs diff --git a/src/client/connection.rs b/src/client/connection.rs index 56c1aec..14413e2 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -5,7 +5,7 @@ use crate::{ util, }, protocol::{Address, Command, MarshalingError, ProtocolError}, - Stream, UdpRelayMode, + BiStream, UdpRelayMode, }; use bytes::{Bytes, BytesMut}; use quinn::{ @@ -120,7 +120,7 @@ impl Connection { Ok(()) } - pub async fn connect(&self, addr: Address) -> Result, ConnectionError> { + pub async fn connect(&self, addr: Address) -> Result, ConnectionError> { let mut stream = self.get_bi_stream().await?; let cmd = Command::Connect { addr }; @@ -168,13 +168,13 @@ impl Connection { Ok(SendStream::new(send, self.stream_reg.as_ref().clone())) } - async fn get_bi_stream(&self) -> Result { + async fn get_bi_stream(&self) -> Result { let (send, recv) = self.conn.open_bi().await.map_err(IoError::from)?; let send = SendStream::new(send, self.stream_reg.as_ref().clone()); let recv = RecvStream::new(recv, self.stream_reg.as_ref().clone()); - Ok(Stream::new(send, recv)) + Ok(BiStream::new(send, recv)) } fn send_packet_to_datagram( @@ -270,12 +270,6 @@ pub enum ConnectionError { InvalidEncoding(#[from] FromUtf8Error), #[error("expecting a `Respond`, got a command")] ShouldBeRespond(Command), - #[error("unexpected incoming bi_stream")] - UnexpectedIncomingBiStream(Stream), - #[error("unexpected incoming uni_stream")] - UnexpectedIncomingUniStream(RecvStream), - #[error("unexpected incoming datagram")] - UnexpectedIncomingDatagram(Bytes), #[error("datagrams not supported by peer")] DatagramUnsupportedByPeer, #[error("datagram support disabled")] diff --git a/src/client/incoming.rs b/src/client/incoming.rs index 6d6a0e2..499b4d5 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -1,7 +1,10 @@ -use super::ConnectionError; use crate::{ - common::{incoming::RawIncomingTasks, stream::StreamReg, task::TaskSource, util::PacketBuffer}, - Packet, PacketBufferGcHandle, UdpRelayMode, + common::{ + incoming::{IncomingError, RawIncomingTasks, RawPendingIncomingTask}, + stream::StreamReg, + util::PacketBuffer, + }, + PacketBufferGcHandle, UdpRelayMode, }; use futures::Stream; use quinn::{Datagrams, IncomingBiStreams, IncomingUniStreams}; @@ -38,42 +41,36 @@ impl IncomingTasks { } impl Stream for IncomingTasks { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_next(cx).map(|poll| { poll.map(|res| match res { Ok(source) => match (source, self.udp_relay_mode) { - (TaskSource::BiStream(stream), _) => { - Err(ConnectionError::UnexpectedIncomingBiStream(stream)) + (RawPendingIncomingTask::BiStream(stream), _) => { + Err(IncomingError::UnexpectedIncomingBiStream(stream)) } - (TaskSource::UniStream(stream), UdpRelayMode::Native) => { - Err(ConnectionError::UnexpectedIncomingUniStream(stream)) + (RawPendingIncomingTask::UniStream(stream), UdpRelayMode::Native) => { + Err(IncomingError::UnexpectedIncomingUniStream(stream)) } - (TaskSource::Datagram(datagram), UdpRelayMode::Quic) => { - Err(ConnectionError::UnexpectedIncomingDatagram(datagram)) + (RawPendingIncomingTask::Datagram(datagram), UdpRelayMode::Quic) => { + Err(IncomingError::UnexpectedIncomingDatagram(datagram)) } - (source, _) => Ok(PendingTask::new(source, self.pkt_buf.clone())), + (source, _) => Ok(PendingIncomingTask::new(source, self.pkt_buf.clone())), }, - Err(err) => Err(ConnectionError::from(err)), + Err(err) => Err(IncomingError::from(err)), }) }) } } -pub struct PendingTask { - inner: TaskSource, +pub struct PendingIncomingTask { + inner: RawPendingIncomingTask, pkt_buf: PacketBuffer, } -impl PendingTask { - fn new(inner: TaskSource, pkt_buf: PacketBuffer) -> Self { +impl PendingIncomingTask { + fn new(inner: RawPendingIncomingTask, pkt_buf: PacketBuffer) -> Self { Self { inner, pkt_buf } } } - -#[derive(Debug)] -#[non_exhaustive] -pub enum Task { - Packet(Option), -} diff --git a/src/client/mod.rs b/src/client/mod.rs index a6acbdd..cdcb18b 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -3,7 +3,7 @@ mod incoming; pub use self::{ connection::{Connecting, Connection, ConnectionError}, - incoming::IncomingTasks, + incoming::{IncomingTasks, PendingIncomingTask}, }; use crate::{CongestionControl, UdpRelayMode}; @@ -142,8 +142,8 @@ pub struct ClientConfig { #[derive(Error, Debug)] pub enum ClientError { - #[error(transparent)] - Io(#[from] IoError), + #[error("socket binding error: {0}")] + Socket(#[from] IoError), #[error("endpoint stopping")] EndpointStopping, #[error("too many connections")] diff --git a/src/common/incoming.rs b/src/common/incoming.rs index ac7c5d3..a97dff6 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -1,7 +1,5 @@ -use super::{ - stream::{RecvStream, SendStream, Stream as BiStream, StreamReg}, - task::TaskSource, -}; +use super::stream::{BiStream, RecvStream, SendStream, StreamReg}; +use crate::protocol::{Address, Command, MarshalingError, ProtocolError}; use bytes::Bytes; use futures::{stream::SelectAll, Stream}; use quinn::{ @@ -11,9 +9,11 @@ use quinn::{ use std::{ io::Error as IoError, pin::Pin, + string::FromUtf8Error, sync::Arc, task::{Context, Poll}, }; +use thiserror::Error; pub(crate) struct RawIncomingTasks { incoming: SelectAll, @@ -41,20 +41,22 @@ impl RawIncomingTasks { } impl Stream for RawIncomingTasks { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.incoming) .poll_next(cx) .map_ok(|src| match src { - IncomingItem::BiStream((send, recv)) => TaskSource::BiStream(BiStream::new( - SendStream::new(send, self.stream_reg.as_ref().clone()), - RecvStream::new(recv, self.stream_reg.as_ref().clone()), - )), - IncomingItem::UniStream(recv) => { - TaskSource::UniStream(RecvStream::new(recv, self.stream_reg.as_ref().clone())) + IncomingItem::BiStream((send, recv)) => { + RawPendingIncomingTask::BiStream(BiStream::new( + SendStream::new(send, self.stream_reg.as_ref().clone()), + RecvStream::new(recv, self.stream_reg.as_ref().clone()), + )) } - IncomingItem::Datagram(datagram) => TaskSource::Datagram(datagram), + IncomingItem::UniStream(recv) => RawPendingIncomingTask::UniStream( + RecvStream::new(recv, self.stream_reg.as_ref().clone()), + ), + IncomingItem::Datagram(datagram) => RawPendingIncomingTask::Datagram(datagram), }) .map_err(IoError::from) } @@ -92,3 +94,150 @@ enum IncomingItem { UniStream(QuinnRecvStream), Datagram(Bytes), } + +pub(crate) enum RawPendingIncomingTask { + BiStream(BiStream), + UniStream(RecvStream), + Datagram(Bytes), +} + +impl RawPendingIncomingTask { + pub(crate) async fn accept(self) -> Result { + match self { + Self::BiStream(stream) => Self::accept_from_bi_stream(stream).await, + Self::UniStream(stream) => Self::accept_from_uni_stream(stream).await, + Self::Datagram(datagram) => Self::accept_from_datagram(datagram).await, + } + } + + async fn accept_from_bi_stream(mut stream: BiStream) -> Result { + let cmd = Command::read_from(&mut stream) + .await + .map_err(IncomingError::from_marshaling_error)?; + + match cmd { + Command::Connect { addr } => Ok(RawIncomingTask::Connect { addr, stream }), + cmd => Err(IncomingError::UnexpectedCommand("bi_stream", cmd)), + } + } + + async fn accept_from_uni_stream( + mut stream: RecvStream, + ) -> Result { + let cmd = Command::read_from(&mut stream) + .await + .map_err(IncomingError::from_marshaling_error)?; + + match cmd { + Command::Authenticate(token) => Ok(RawIncomingTask::Authenticate { token }), + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => Ok(RawIncomingTask::PacketFromUniStream { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + payload: stream, + }), + Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate { assoc_id }), + Command::Heartbeat => Ok(RawIncomingTask::Heartbeat), + cmd => Err(IncomingError::UnexpectedCommand("uni_stream", cmd)), + } + } + + async fn accept_from_datagram(datagram: Bytes) -> Result { + let cmd = Command::read_from(&mut datagram.as_ref()) + .await + .map_err(IncomingError::from_marshaling_error)?; + let payload = datagram.slice(cmd.serialized_len()..); + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => Ok(RawIncomingTask::PacketFromDatagram { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + payload: datagram, + }), + cmd => Err(IncomingError::UnexpectedCommand("datagram", cmd)), + } + } +} + +#[non_exhaustive] +pub(crate) enum RawIncomingTask { + Authenticate { + token: [u8; 32], + }, + Connect { + addr: Address, + stream: BiStream, + }, + PacketFromDatagram { + assoc_id: u32, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + len: u16, + addr: Option
, + payload: Bytes, + }, + PacketFromUniStream { + assoc_id: u32, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + len: u16, + addr: Option
, + payload: RecvStream, + }, + Dissociate { + assoc_id: u32, + }, + Heartbeat, +} + +#[derive(Error, Debug)] +pub enum IncomingError { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Protocol(#[from] ProtocolError), + #[error("invalid address encoding: {0}")] + InvalidEncoding(#[from] FromUtf8Error), + #[error("unexpected incoming bi_stream")] + UnexpectedIncomingBiStream(BiStream), + #[error("unexpected incoming uni_stream")] + UnexpectedIncomingUniStream(RecvStream), + #[error("unexpected incoming datagram")] + UnexpectedIncomingDatagram(Bytes), + #[error("unexpected command from {0}: {1:?}")] + UnexpectedCommand(&'static str, Command), +} + +impl IncomingError { + #[inline] + pub(super) fn from_marshaling_error(err: MarshalingError) -> Self { + match err { + MarshalingError::Io(err) => Self::Io(err), + MarshalingError::Protocol(err) => Self::Protocol(err), + MarshalingError::InvalidEncoding(err) => Self::InvalidEncoding(err), + } + } +} diff --git a/src/common/mod.rs b/src/common/mod.rs index 9d2dcd3..6a3bb90 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,6 +1,5 @@ pub(crate) mod incoming; pub(crate) mod stream; -pub(crate) mod task; pub(crate) mod util; #[derive(Clone, Copy, Debug)] diff --git a/src/common/stream.rs b/src/common/stream.rs index 8298cda..49d413b 100644 --- a/src/common/stream.rs +++ b/src/common/stream.rs @@ -13,6 +13,7 @@ pub(crate) type StreamReg = Arc<()>; pub struct SendStream(QuinnSendStream, StreamReg); impl SendStream { + #[inline] pub(crate) fn new(send: QuinnSendStream, reg: StreamReg) -> Self { Self(send, reg) } @@ -24,29 +25,6 @@ impl SendStream { } } -#[derive(Debug)] -pub struct RecvStream(QuinnRecvStream, StreamReg); - -impl RecvStream { - pub(crate) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { - Self(recv, reg) - } -} - -#[derive(Debug)] -pub struct Stream(SendStream, RecvStream); - -impl Stream { - pub(crate) fn new(send: SendStream, recv: RecvStream) -> Self { - Self(send, recv) - } - - #[inline] - pub async fn finish(&mut self) -> Result<()> { - self.0.finish().await - } -} - impl AsyncWrite for SendStream { #[inline] fn poll_write( @@ -82,6 +60,16 @@ impl AsyncWrite for SendStream { } } +#[derive(Debug)] +pub struct RecvStream(QuinnRecvStream, StreamReg); + +impl RecvStream { + #[inline] + pub(crate) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { + Self(recv, reg) + } +} + impl AsyncRead for RecvStream { #[inline] fn poll_read( @@ -93,7 +81,22 @@ impl AsyncRead for RecvStream { } } -impl AsyncWrite for Stream { +#[derive(Debug)] +pub struct BiStream(SendStream, RecvStream); + +impl BiStream { + #[inline] + pub(crate) fn new(send: SendStream, recv: RecvStream) -> Self { + Self(send, recv) + } + + #[inline] + pub async fn finish(&mut self) -> Result<()> { + self.0.finish().await + } +} + +impl AsyncWrite for BiStream { #[inline] fn poll_write( mut self: Pin<&mut Self>, @@ -128,7 +131,7 @@ impl AsyncWrite for Stream { } } -impl AsyncRead for Stream { +impl AsyncRead for BiStream { #[inline] fn poll_read( mut self: Pin<&mut Self>, diff --git a/src/common/task.rs b/src/common/task.rs deleted file mode 100644 index ba96855..0000000 --- a/src/common/task.rs +++ /dev/null @@ -1,65 +0,0 @@ -use super::stream::{RecvStream, Stream}; -use crate::protocol::{Address, Command, MarshalingError}; -use bytes::Bytes; - -#[derive(Clone, Debug)] -pub struct Packet { - pub id: u16, - pub associate_id: u32, - pub address: Address, - pub data: Bytes, -} - -impl Packet { - pub(crate) fn new(assoc_id: u32, pkt_id: u16, addr: Address, pkt: Bytes) -> Self { - Self { - id: pkt_id, - associate_id: assoc_id, - address: addr, - data: pkt, - } - } -} - -pub(crate) enum TaskSource { - BiStream(Stream), - UniStream(RecvStream), - Datagram(Bytes), -} - -impl TaskSource { - pub(crate) async fn accept(self) -> Result { - match self { - TaskSource::BiStream(mut bi_stream) => Ok(RawTask::new( - Command::read_from(&mut bi_stream).await?, - RawTaskPayload::BiStream(bi_stream), - )), - TaskSource::UniStream(mut uni_stream) => Ok(RawTask::new( - Command::read_from(&mut uni_stream).await?, - RawTaskPayload::UniStream(uni_stream), - )), - TaskSource::Datagram(datagram) => { - let cmd = Command::read_from(&mut datagram.as_ref()).await?; - let payload = datagram.slice(cmd.serialized_len()..); - Ok(RawTask::new(cmd, RawTaskPayload::Datagram(payload))) - } - } - } -} - -pub(crate) struct RawTask { - header: Command, - payload: RawTaskPayload, -} - -impl RawTask { - pub(crate) fn new(header: Command, payload: RawTaskPayload) -> Self { - Self { header, payload } - } -} - -pub(crate) enum RawTaskPayload { - BiStream(Stream), - UniStream(RecvStream), - Datagram(Bytes), -} diff --git a/src/common/util.rs b/src/common/util.rs index f67a00b..e20adf0 100644 --- a/src/common/util.rs +++ b/src/common/util.rs @@ -1,7 +1,4 @@ -use crate::{ - protocol::{Address, Command}, - Packet, -}; +use crate::protocol::{Address, Command}; use bytes::{Bytes, BytesMut}; use parking_lot::Mutex; use std::{ @@ -27,7 +24,7 @@ impl PacketBuffer { frag_id: u8, addr: Option
, pkt: Bytes, - ) -> Result, PacketBufferError> { + ) -> Result, PacketBufferError> { let mut pkt_buf = self.0.lock(); let key = PacketBufferKey { assoc_id, pkt_id }; @@ -65,12 +62,7 @@ impl PacketBuffer { res.extend_from_slice(&pkt.unwrap()); } - Ok(Some(Packet::new( - assoc_id, - pkt_id, - v.addr.unwrap(), - res.freeze(), - ))) + Ok(Some((assoc_id, pkt_id, v.addr.unwrap(), res.freeze()))) } else { Ok(None) } @@ -81,7 +73,7 @@ impl PacketBuffer { } if frag_total == 1 { - return Ok(Some(Packet::new(assoc_id, pkt_id, addr.unwrap(), pkt))); + return Ok(Some((assoc_id, pkt_id, addr.unwrap(), pkt))); } let mut v = PacketBufferValue { diff --git a/src/lib.rs b/src/lib.rs index 9bb9898..555cd69 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,7 @@ pub mod client; #[cfg(any(feature = "server", feature = "client"))] pub use crate::common::{ - stream::{RecvStream, SendStream, Stream}, - task::Packet, + stream::{BiStream, RecvStream, SendStream}, util::PacketBufferGcHandle, CongestionControl, UdpRelayMode, }; From 8f492cc4a12ca7b5114dc44c8ea1ade1c63180a3 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 18 Sep 2022 16:00:22 +0900 Subject: [PATCH 032/103] various refactors --- src/client/connection.rs | 4 +- src/client/incoming.rs | 33 ++++++- src/common/incoming.rs | 21 +++-- src/common/mod.rs | 2 +- src/common/{util.rs => packet.rs} | 139 ++++++++++++++++++++++++------ src/lib.rs | 2 +- 6 files changed, 162 insertions(+), 39 deletions(-) rename src/common/{util.rs => packet.rs} (59%) diff --git a/src/client/connection.rs b/src/client/connection.rs index 14413e2..b0c0d33 100644 --- a/src/client/connection.rs +++ b/src/client/connection.rs @@ -1,8 +1,8 @@ use super::IncomingTasks; use crate::{ common::{ + packet, stream::{RecvStream, SendStream, StreamReg}, - util, }, protocol::{Address, Command, MarshalingError, ProtocolError}, BiStream, UdpRelayMode, @@ -190,7 +190,7 @@ impl Connection { }; let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); - let mut pkts = util::split_packet(pkt, &addr, max_datagram_size); + let mut pkts = packet::split_packet(pkt, &addr, max_datagram_size); let frag_total = pkts.len() as u8; let first_pkt = pkts.next().unwrap(); diff --git a/src/client/incoming.rs b/src/client/incoming.rs index 499b4d5..0b86754 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -1,8 +1,8 @@ use crate::{ common::{ - incoming::{IncomingError, RawIncomingTasks, RawPendingIncomingTask}, + incoming::{IncomingError, RawIncomingTask, RawIncomingTasks, RawPendingIncomingTask}, + packet::PacketBuffer, stream::StreamReg, - util::PacketBuffer, }, PacketBufferGcHandle, UdpRelayMode, }; @@ -73,4 +73,33 @@ impl PendingIncomingTask { fn new(inner: RawPendingIncomingTask, pkt_buf: PacketBuffer) -> Self { Self { inner, pkt_buf } } + + pub async fn accept(self) -> Result { + match self.inner.accept().await? { + RawIncomingTask::Authenticate { token } => todo!(), + RawIncomingTask::Connect { addr, payload } => todo!(), + RawIncomingTask::PacketFromDatagram { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + payload, + } => todo!(), + RawIncomingTask::PacketFromUniStream { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + payload, + } => todo!(), + RawIncomingTask::Dissociate { assoc_id } => todo!(), + RawIncomingTask::Heartbeat => todo!(), + } + } } + +pub enum IncomingTask {} diff --git a/src/common/incoming.rs b/src/common/incoming.rs index a97dff6..6586b5c 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -116,8 +116,11 @@ impl RawPendingIncomingTask { .map_err(IncomingError::from_marshaling_error)?; match cmd { - Command::Connect { addr } => Ok(RawIncomingTask::Connect { addr, stream }), - cmd => Err(IncomingError::UnexpectedCommand("bi_stream", cmd)), + Command::Connect { addr } => Ok(RawIncomingTask::Connect { + addr, + payload: stream, + }), + cmd => Err(IncomingError::UnexpectedCommandFromBiStream(stream, cmd)), } } @@ -148,7 +151,7 @@ impl RawPendingIncomingTask { }), Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate { assoc_id }), Command::Heartbeat => Ok(RawIncomingTask::Heartbeat), - cmd => Err(IncomingError::UnexpectedCommand("uni_stream", cmd)), + cmd => Err(IncomingError::UnexpectedCommandFromUniStream(stream, cmd)), } } @@ -175,7 +178,7 @@ impl RawPendingIncomingTask { addr, payload: datagram, }), - cmd => Err(IncomingError::UnexpectedCommand("datagram", cmd)), + cmd => Err(IncomingError::UnexpectedCommandFromDatagram(payload, cmd)), } } } @@ -187,7 +190,7 @@ pub(crate) enum RawIncomingTask { }, Connect { addr: Address, - stream: BiStream, + payload: BiStream, }, PacketFromDatagram { assoc_id: u32, @@ -227,8 +230,12 @@ pub enum IncomingError { UnexpectedIncomingUniStream(RecvStream), #[error("unexpected incoming datagram")] UnexpectedIncomingDatagram(Bytes), - #[error("unexpected command from {0}: {1:?}")] - UnexpectedCommand(&'static str, Command), + #[error("unexpected command from bi_stream: {1:?}")] + UnexpectedCommandFromBiStream(BiStream, Command), + #[error("unexpected command from uni_stream: {1:?}")] + UnexpectedCommandFromUniStream(RecvStream, Command), + #[error("unexpected command from datagram: {1:?}")] + UnexpectedCommandFromDatagram(Bytes, Command), } impl IncomingError { diff --git a/src/common/mod.rs b/src/common/mod.rs index 6a3bb90..e074f73 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -1,6 +1,6 @@ pub(crate) mod incoming; +pub(crate) mod packet; pub(crate) mod stream; -pub(crate) mod util; #[derive(Clone, Copy, Debug)] pub enum CongestionControl { diff --git a/src/common/util.rs b/src/common/packet.rs similarity index 59% rename from src/common/util.rs rename to src/common/packet.rs index e20adf0..2403514 100644 --- a/src/common/util.rs +++ b/src/common/packet.rs @@ -1,4 +1,7 @@ -use crate::protocol::{Address, Command}; +use crate::{ + protocol::{Address, Command}, + RecvStream, +}; use bytes::{Bytes, BytesMut}; use parking_lot::Mutex; use std::{ @@ -8,6 +11,80 @@ use std::{ }; use thiserror::Error; +pub struct NeedAccept; +pub struct NeedAssembly; +pub struct Ready; + +pub struct Packet { + assoc_id: u32, + pkt_id: u16, + frag_id: Option, + frag_total: u8, + addr: Option
, + src: Option, + inner: Option, + _state: S, +} + +impl Packet { + fn new( + assoc_id: u32, + pkt_id: u16, + frag_id: u8, + frag_total: u8, + addr: Address, + stream: RecvStream, + ) -> Self { + Self { + assoc_id, + pkt_id, + frag_id: Some(frag_id), + frag_total, + addr: Some(addr), + inner: None, + src: Some(stream), + _state: NeedAccept, + } + } +} + +impl Packet { + fn new( + assoc_id: u32, + pkt_id: u16, + frag_id: u8, + frag_total: u8, + addr: Option
, + pkt: Bytes, + ) -> Self { + Self { + assoc_id, + pkt_id, + frag_id: Some(frag_id), + frag_total, + addr, + inner: Some(pkt), + src: None, + _state: NeedAssembly, + } + } +} + +impl Packet { + fn new(assoc_id: u32, pkt_id: u16, frag_total: u8, addr: Address, pkt: Bytes) -> Self { + Self { + assoc_id, + pkt_id, + frag_id: None, + frag_total, + addr: Some(addr), + inner: Some(pkt), + src: None, + _state: Ready, + } + } +} + #[derive(Clone, Debug)] pub(crate) struct PacketBuffer(Arc>>); @@ -18,22 +95,20 @@ impl PacketBuffer { pub(crate) fn insert( &mut self, - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - addr: Option
, - pkt: Bytes, - ) -> Result, PacketBufferError> { + pkt: Packet, + ) -> Result>, PacketBufferError> { let mut pkt_buf = self.0.lock(); - let key = PacketBufferKey { assoc_id, pkt_id }; + let key = PacketBufferKey { + assoc_id: pkt.assoc_id, + pkt_id: pkt.pkt_id, + }; - if frag_id == 0 && addr.is_none() { + if pkt.frag_id.unwrap() == 0 && pkt.addr.is_none() { pkt_buf.remove(&key); return Err(PacketBufferError::NoAddress); } - if frag_id != 0 && addr.is_some() { + if pkt.frag_id.unwrap() != 0 && pkt.addr.is_some() { pkt_buf.remove(&key); return Err(PacketBufferError::UnexpectedAddress); } @@ -42,19 +117,19 @@ impl PacketBuffer { Entry::Occupied(mut entry) => { let v = entry.get_mut(); - if frag_total == 0 - || frag_id >= frag_total - || v.buf.len() != frag_total as usize - || v.buf[frag_id as usize].is_some() + if pkt.frag_total == 0 + || pkt.frag_id.unwrap() >= pkt.frag_total + || v.buf.len() != pkt.frag_total as usize + || v.buf[pkt.frag_id.unwrap() as usize].is_some() { return Err(PacketBufferError::BadFragment); } - v.total_len += pkt.len(); - v.buf[frag_id as usize] = Some(pkt); + v.total_len += pkt.inner.as_ref().unwrap().len(); + v.buf[pkt.frag_id.unwrap() as usize] = Some(pkt.inner.unwrap()); v.recv_count += 1; - if v.recv_count == frag_total as usize { + if v.recv_count == pkt.frag_total as usize { let v = entry.remove(); let mut res = BytesMut::with_capacity(v.total_len); @@ -62,30 +137,42 @@ impl PacketBuffer { res.extend_from_slice(&pkt.unwrap()); } - Ok(Some((assoc_id, pkt_id, v.addr.unwrap(), res.freeze()))) + Ok(Some(Packet::::new( + pkt.assoc_id, + pkt.pkt_id, + pkt.frag_total, + v.addr.unwrap(), + res.freeze(), + ))) } else { Ok(None) } } Entry::Vacant(entry) => { - if frag_total == 0 || frag_id >= frag_total { + if pkt.frag_total == 0 || pkt.frag_id.unwrap() >= pkt.frag_total { return Err(PacketBufferError::BadFragment); } - if frag_total == 1 { - return Ok(Some((assoc_id, pkt_id, addr.unwrap(), pkt))); + if pkt.frag_total == 1 { + return Ok(Some(Packet::::new( + pkt.assoc_id, + pkt.pkt_id, + pkt.frag_total, + pkt.addr.unwrap(), + pkt.inner.unwrap(), + ))); } let mut v = PacketBufferValue { - buf: vec![None; frag_total as usize], - addr, + buf: vec![None; pkt.frag_total as usize], + addr: pkt.addr, recv_count: 0, total_len: 0, c_time: Instant::now(), }; - v.total_len += pkt.len(); - v.buf[frag_id as usize] = Some(pkt); + v.total_len += pkt.inner.as_ref().unwrap().len(); + v.buf[pkt.frag_id.unwrap() as usize] = Some(pkt.inner.unwrap()); v.recv_count += 1; entry.insert(v); diff --git a/src/lib.rs b/src/lib.rs index 555cd69..be7a429 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,8 +11,8 @@ pub mod client; #[cfg(any(feature = "server", feature = "client"))] pub use crate::common::{ + packet::PacketBufferGcHandle, stream::{BiStream, RecvStream, SendStream}, - util::PacketBufferGcHandle, CongestionControl, UdpRelayMode, }; From 68c5de875734312851901da195683d2b5b7b7c50 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 18 Sep 2022 16:31:05 +0900 Subject: [PATCH 033/103] adding `PacketBuffer` in `RawIncomingTasks` --- src/client/incoming.rs | 44 ++------------------ src/common/incoming.rs | 94 ++++++++++++++++-------------------------- src/common/packet.rs | 26 +++++++----- 3 files changed, 55 insertions(+), 109 deletions(-) diff --git a/src/client/incoming.rs b/src/client/incoming.rs index 0b86754..b07638d 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -17,7 +17,6 @@ use std::{ pub struct IncomingTasks { inner: RawIncomingTasks, udp_relay_mode: UdpRelayMode, - pkt_buf: PacketBuffer, } impl IncomingTasks { @@ -31,13 +30,8 @@ impl IncomingTasks { Self { inner: RawIncomingTasks::new(bi_streams, uni_streams, datagrams, stream_reg), udp_relay_mode, - pkt_buf: PacketBuffer::new(), } } - - pub fn get_packet_buffer_gc_handler(&self) -> PacketBufferGcHandle { - self.pkt_buf.get_gc_handler() - } } impl Stream for IncomingTasks { @@ -53,10 +47,10 @@ impl Stream for IncomingTasks { (RawPendingIncomingTask::UniStream(stream), UdpRelayMode::Native) => { Err(IncomingError::UnexpectedIncomingUniStream(stream)) } - (RawPendingIncomingTask::Datagram(datagram), UdpRelayMode::Quic) => { + (RawPendingIncomingTask::Datagram(datagram, ..), UdpRelayMode::Quic) => { Err(IncomingError::UnexpectedIncomingDatagram(datagram)) } - (source, _) => Ok(PendingIncomingTask::new(source, self.pkt_buf.clone())), + (source, _) => Ok(PendingIncomingTask(source)), }, Err(err) => Err(IncomingError::from(err)), }) @@ -64,41 +58,11 @@ impl Stream for IncomingTasks { } } -pub struct PendingIncomingTask { - inner: RawPendingIncomingTask, - pkt_buf: PacketBuffer, -} +pub struct PendingIncomingTask(RawPendingIncomingTask); impl PendingIncomingTask { - fn new(inner: RawPendingIncomingTask, pkt_buf: PacketBuffer) -> Self { - Self { inner, pkt_buf } - } - pub async fn accept(self) -> Result { - match self.inner.accept().await? { - RawIncomingTask::Authenticate { token } => todo!(), - RawIncomingTask::Connect { addr, payload } => todo!(), - RawIncomingTask::PacketFromDatagram { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - payload, - } => todo!(), - RawIncomingTask::PacketFromUniStream { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - payload, - } => todo!(), - RawIncomingTask::Dissociate { assoc_id } => todo!(), - RawIncomingTask::Heartbeat => todo!(), - } + todo!() } } diff --git a/src/common/incoming.rs b/src/common/incoming.rs index 6586b5c..cfa6672 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -1,4 +1,7 @@ -use super::stream::{BiStream, RecvStream, SendStream, StreamReg}; +use super::{ + packet::{NeedAccept, NeedAssembly, Packet, PacketBuffer}, + stream::{BiStream, RecvStream, SendStream, StreamReg}, +}; use crate::protocol::{Address, Command, MarshalingError, ProtocolError}; use bytes::Bytes; use futures::{stream::SelectAll, Stream}; @@ -18,6 +21,7 @@ use thiserror::Error; pub(crate) struct RawIncomingTasks { incoming: SelectAll, stream_reg: Arc, + pkt_buf: PacketBuffer, } impl RawIncomingTasks { @@ -36,6 +40,7 @@ impl RawIncomingTasks { Self { incoming, stream_reg, + pkt_buf: PacketBuffer::new(), } } } @@ -56,7 +61,9 @@ impl Stream for RawIncomingTasks { IncomingItem::UniStream(recv) => RawPendingIncomingTask::UniStream( RecvStream::new(recv, self.stream_reg.as_ref().clone()), ), - IncomingItem::Datagram(datagram) => RawPendingIncomingTask::Datagram(datagram), + IncomingItem::Datagram(datagram) => { + RawPendingIncomingTask::Datagram(datagram, self.pkt_buf.clone()) + } }) .map_err(IoError::from) } @@ -98,7 +105,7 @@ enum IncomingItem { pub(crate) enum RawPendingIncomingTask { BiStream(BiStream), UniStream(RecvStream), - Datagram(Bytes), + Datagram(Bytes, PacketBuffer), } impl RawPendingIncomingTask { @@ -106,7 +113,9 @@ impl RawPendingIncomingTask { match self { Self::BiStream(stream) => Self::accept_from_bi_stream(stream).await, Self::UniStream(stream) => Self::accept_from_uni_stream(stream).await, - Self::Datagram(datagram) => Self::accept_from_datagram(datagram).await, + Self::Datagram(datagram, pkt_buf) => { + Self::accept_from_datagram(datagram, pkt_buf).await + } } } @@ -116,10 +125,7 @@ impl RawPendingIncomingTask { .map_err(IncomingError::from_marshaling_error)?; match cmd { - Command::Connect { addr } => Ok(RawIncomingTask::Connect { - addr, - payload: stream, - }), + Command::Connect { addr } => Ok(RawIncomingTask::Connect(addr, stream)), cmd => Err(IncomingError::UnexpectedCommandFromBiStream(stream, cmd)), } } @@ -132,7 +138,7 @@ impl RawPendingIncomingTask { .map_err(IncomingError::from_marshaling_error)?; match cmd { - Command::Authenticate(token) => Ok(RawIncomingTask::Authenticate { token }), + Command::Authenticate(token) => Ok(RawIncomingTask::Authenticate(token)), Command::Packet { assoc_id, pkt_id, @@ -140,26 +146,23 @@ impl RawPendingIncomingTask { frag_id, len, addr, - } => Ok(RawIncomingTask::PacketFromUniStream { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - payload: stream, - }), - Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate { assoc_id }), + } => Ok(RawIncomingTask::PacketFromUniStream( + Packet::::new(assoc_id, pkt_id, frag_total, frag_id, addr, stream, len), + )), + Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate(assoc_id)), Command::Heartbeat => Ok(RawIncomingTask::Heartbeat), cmd => Err(IncomingError::UnexpectedCommandFromUniStream(stream, cmd)), } } - async fn accept_from_datagram(datagram: Bytes) -> Result { + async fn accept_from_datagram( + datagram: Bytes, + pkt_buf: PacketBuffer, + ) -> Result { let cmd = Command::read_from(&mut datagram.as_ref()) .await .map_err(IncomingError::from_marshaling_error)?; - let payload = datagram.slice(cmd.serialized_len()..); + let pkt = datagram.slice(cmd.serialized_len()..); match cmd { Command::Packet { @@ -169,50 +172,23 @@ impl RawPendingIncomingTask { frag_id, len, addr, - } => Ok(RawIncomingTask::PacketFromDatagram { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - payload: datagram, - }), - cmd => Err(IncomingError::UnexpectedCommandFromDatagram(payload, cmd)), + } => Ok(RawIncomingTask::PacketFromDatagram( + Packet::::new( + assoc_id, pkt_id, frag_total, frag_id, addr, pkt_buf, pkt, + ), + )), + cmd => Err(IncomingError::UnexpectedCommandFromDatagram(datagram, cmd)), } } } #[non_exhaustive] pub(crate) enum RawIncomingTask { - Authenticate { - token: [u8; 32], - }, - Connect { - addr: Address, - payload: BiStream, - }, - PacketFromDatagram { - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - payload: Bytes, - }, - PacketFromUniStream { - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - payload: RecvStream, - }, - Dissociate { - assoc_id: u32, - }, + Authenticate([u8; 32]), + Connect(Address, BiStream), + PacketFromDatagram(Packet), + PacketFromUniStream(Packet), + Dissociate(u32), Heartbeat, } diff --git a/src/common/packet.rs b/src/common/packet.rs index 2403514..1e3bdcb 100644 --- a/src/common/packet.rs +++ b/src/common/packet.rs @@ -21,40 +21,44 @@ pub struct Packet { frag_id: Option, frag_total: u8, addr: Option
, - src: Option, + src: Option<(RecvStream, u16)>, + pkt_buf: Option, inner: Option, _state: S, } impl Packet { - fn new( + pub(super) fn new( assoc_id: u32, pkt_id: u16, - frag_id: u8, frag_total: u8, - addr: Address, + frag_id: u8, + addr: Option
, stream: RecvStream, + len: u16, ) -> Self { Self { assoc_id, pkt_id, frag_id: Some(frag_id), frag_total, - addr: Some(addr), + addr: addr, + src: Some((stream, len)), + pkt_buf: None, inner: None, - src: Some(stream), _state: NeedAccept, } } } impl Packet { - fn new( + pub(super) fn new( assoc_id: u32, pkt_id: u16, - frag_id: u8, frag_total: u8, + frag_id: u8, addr: Option
, + pkt_buf: PacketBuffer, pkt: Bytes, ) -> Self { Self { @@ -63,8 +67,9 @@ impl Packet { frag_id: Some(frag_id), frag_total, addr, - inner: Some(pkt), src: None, + pkt_buf: Some(pkt_buf), + inner: Some(pkt), _state: NeedAssembly, } } @@ -78,8 +83,9 @@ impl Packet { frag_id: None, frag_total, addr: Some(addr), - inner: Some(pkt), src: None, + pkt_buf: None, + inner: Some(pkt), _state: Ready, } } From 21bc7280ead45de024da24313b8a917cbc0fe2f5 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 18 Sep 2022 16:41:34 +0900 Subject: [PATCH 034/103] remove the internal use of `Arc` in `PacketBuffer` --- src/client/incoming.rs | 2 +- src/common/incoming.rs | 8 ++++---- src/common/packet.rs | 24 +++++++++++++----------- src/lib.rs | 2 +- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/client/incoming.rs b/src/client/incoming.rs index b07638d..728f9ea 100644 --- a/src/client/incoming.rs +++ b/src/client/incoming.rs @@ -4,7 +4,7 @@ use crate::{ packet::PacketBuffer, stream::StreamReg, }, - PacketBufferGcHandle, UdpRelayMode, + PacketBufferHandle, UdpRelayMode, }; use futures::Stream; use quinn::{Datagrams, IncomingBiStreams, IncomingUniStreams}; diff --git a/src/common/incoming.rs b/src/common/incoming.rs index cfa6672..3897ed4 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -21,7 +21,7 @@ use thiserror::Error; pub(crate) struct RawIncomingTasks { incoming: SelectAll, stream_reg: Arc, - pkt_buf: PacketBuffer, + pkt_buf: Arc, } impl RawIncomingTasks { @@ -40,7 +40,7 @@ impl RawIncomingTasks { Self { incoming, stream_reg, - pkt_buf: PacketBuffer::new(), + pkt_buf: Arc::new(PacketBuffer::new()), } } } @@ -105,7 +105,7 @@ enum IncomingItem { pub(crate) enum RawPendingIncomingTask { BiStream(BiStream), UniStream(RecvStream), - Datagram(Bytes, PacketBuffer), + Datagram(Bytes, Arc), } impl RawPendingIncomingTask { @@ -157,7 +157,7 @@ impl RawPendingIncomingTask { async fn accept_from_datagram( datagram: Bytes, - pkt_buf: PacketBuffer, + pkt_buf: Arc, ) -> Result { let cmd = Command::read_from(&mut datagram.as_ref()) .await diff --git a/src/common/packet.rs b/src/common/packet.rs index 1e3bdcb..42c249d 100644 --- a/src/common/packet.rs +++ b/src/common/packet.rs @@ -22,7 +22,7 @@ pub struct Packet { frag_total: u8, addr: Option
, src: Option<(RecvStream, u16)>, - pkt_buf: Option, + pkt_buf: Option>, inner: Option, _state: S, } @@ -58,7 +58,7 @@ impl Packet { frag_total: u8, frag_id: u8, addr: Option
, - pkt_buf: PacketBuffer, + pkt_buf: Arc, pkt: Bytes, ) -> Self { Self { @@ -91,12 +91,11 @@ impl Packet { } } -#[derive(Clone, Debug)] -pub(crate) struct PacketBuffer(Arc>>); +pub(crate) struct PacketBuffer(Mutex>); impl PacketBuffer { pub(crate) fn new() -> Self { - Self(Arc::new(Mutex::new(HashMap::new()))) + Self(Mutex::new(HashMap::new())) } pub(crate) fn insert( @@ -187,18 +186,21 @@ impl PacketBuffer { } } - pub(crate) fn get_gc_handler(&self) -> PacketBufferGcHandle { - PacketBufferGcHandle(self.clone()) - } - fn collect_garbage(&self, timeout: Duration) { self.0.lock().retain(|_, v| v.c_time.elapsed() < timeout); } + + pub(crate) fn get_handler(self: Arc) -> PacketBufferHandle { + PacketBufferHandle(self.clone()) + } } -pub struct PacketBufferGcHandle(PacketBuffer); +pub struct PacketBufferHandle(Arc); -impl PacketBufferGcHandle { +impl PacketBufferHandle { + fn new(pkt_buf: Arc) -> Self { + Self(pkt_buf) + } pub fn collect_garbage(&self, timeout: Duration) { self.0.collect_garbage(timeout) } diff --git a/src/lib.rs b/src/lib.rs index be7a429..9b44173 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ pub mod client; #[cfg(any(feature = "server", feature = "client"))] pub use crate::common::{ - packet::PacketBufferGcHandle, + packet::PacketBufferHandle, stream::{BiStream, RecvStream, SendStream}, CongestionControl, UdpRelayMode, }; From b800b3f904c4c227ab6fb671f12bc112a967d34d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 18 Sep 2022 17:03:05 +0900 Subject: [PATCH 035/103] add `len` to all `Packet` for validating --- src/common/incoming.rs | 4 ++-- src/common/packet.rs | 48 ++++++++++++++++++++++++++++++------------ src/lib.rs | 2 +- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/common/incoming.rs b/src/common/incoming.rs index 3897ed4..0bdb3c5 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -147,7 +147,7 @@ impl RawPendingIncomingTask { len, addr, } => Ok(RawIncomingTask::PacketFromUniStream( - Packet::::new(assoc_id, pkt_id, frag_total, frag_id, addr, stream, len), + Packet::::new(assoc_id, pkt_id, frag_total, frag_id, len, addr, stream), )), Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate(assoc_id)), Command::Heartbeat => Ok(RawIncomingTask::Heartbeat), @@ -174,7 +174,7 @@ impl RawPendingIncomingTask { addr, } => Ok(RawIncomingTask::PacketFromDatagram( Packet::::new( - assoc_id, pkt_id, frag_total, frag_id, addr, pkt_buf, pkt, + assoc_id, pkt_id, frag_total, frag_id, len, addr, pkt_buf, pkt, ), )), cmd => Err(IncomingError::UnexpectedCommandFromDatagram(datagram, cmd)), diff --git a/src/common/packet.rs b/src/common/packet.rs index 42c249d..a9da20c 100644 --- a/src/common/packet.rs +++ b/src/common/packet.rs @@ -20,8 +20,9 @@ pub struct Packet { pkt_id: u16, frag_id: Option, frag_total: u8, + len: u16, addr: Option
, - src: Option<(RecvStream, u16)>, + src: Option, pkt_buf: Option>, inner: Option, _state: S, @@ -33,22 +34,27 @@ impl Packet { pkt_id: u16, frag_total: u8, frag_id: u8, + len: u16, addr: Option
, stream: RecvStream, - len: u16, ) -> Self { Self { assoc_id, pkt_id, frag_id: Some(frag_id), frag_total, - addr: addr, - src: Some((stream, len)), + len, + addr, + src: Some(stream), pkt_buf: None, inner: None, _state: NeedAccept, } } + + pub async fn accept(self) -> Result, PacketError> { + todo!() + } } impl Packet { @@ -57,6 +63,7 @@ impl Packet { pkt_id: u16, frag_total: u8, frag_id: u8, + len: u16, addr: Option
, pkt_buf: Arc, pkt: Bytes, @@ -66,6 +73,7 @@ impl Packet { pkt_id, frag_id: Some(frag_id), frag_total, + len, addr, src: None, pkt_buf: Some(pkt_buf), @@ -73,15 +81,27 @@ impl Packet { _state: NeedAssembly, } } + + pub fn assemble(self) -> Result, PacketError> { + todo!() + } } impl Packet { - fn new(assoc_id: u32, pkt_id: u16, frag_total: u8, addr: Address, pkt: Bytes) -> Self { + fn new( + assoc_id: u32, + pkt_id: u16, + frag_total: u8, + len: u16, + addr: Address, + pkt: Bytes, + ) -> Self { Self { assoc_id, pkt_id, frag_id: None, frag_total, + len, addr: Some(addr), src: None, pkt_buf: None, @@ -101,7 +121,7 @@ impl PacketBuffer { pub(crate) fn insert( &mut self, pkt: Packet, - ) -> Result>, PacketBufferError> { + ) -> Result>, PacketError> { let mut pkt_buf = self.0.lock(); let key = PacketBufferKey { assoc_id: pkt.assoc_id, @@ -110,12 +130,12 @@ impl PacketBuffer { if pkt.frag_id.unwrap() == 0 && pkt.addr.is_none() { pkt_buf.remove(&key); - return Err(PacketBufferError::NoAddress); + return Err(PacketError::NoAddress); } if pkt.frag_id.unwrap() != 0 && pkt.addr.is_some() { pkt_buf.remove(&key); - return Err(PacketBufferError::UnexpectedAddress); + return Err(PacketError::UnexpectedAddress); } match pkt_buf.entry(key) { @@ -127,10 +147,10 @@ impl PacketBuffer { || v.buf.len() != pkt.frag_total as usize || v.buf[pkt.frag_id.unwrap() as usize].is_some() { - return Err(PacketBufferError::BadFragment); + return Err(PacketError::BadFragment); } - v.total_len += pkt.inner.as_ref().unwrap().len(); + v.total_len += pkt.len as usize; v.buf[pkt.frag_id.unwrap() as usize] = Some(pkt.inner.unwrap()); v.recv_count += 1; @@ -146,6 +166,7 @@ impl PacketBuffer { pkt.assoc_id, pkt.pkt_id, pkt.frag_total, + pkt.len, v.addr.unwrap(), res.freeze(), ))) @@ -155,7 +176,7 @@ impl PacketBuffer { } Entry::Vacant(entry) => { if pkt.frag_total == 0 || pkt.frag_id.unwrap() >= pkt.frag_total { - return Err(PacketBufferError::BadFragment); + return Err(PacketError::BadFragment); } if pkt.frag_total == 1 { @@ -163,6 +184,7 @@ impl PacketBuffer { pkt.assoc_id, pkt.pkt_id, pkt.frag_total, + pkt.len, pkt.addr.unwrap(), pkt.inner.unwrap(), ))); @@ -176,7 +198,7 @@ impl PacketBuffer { c_time: Instant::now(), }; - v.total_len += pkt.inner.as_ref().unwrap().len(); + v.total_len += pkt.len as usize; v.buf[pkt.frag_id.unwrap() as usize] = Some(pkt.inner.unwrap()); v.recv_count += 1; entry.insert(v); @@ -222,7 +244,7 @@ struct PacketBufferValue { } #[derive(Error, Debug)] -pub(crate) enum PacketBufferError { +pub enum PacketError { #[error("missing address in packet with frag_id 0")] NoAddress, #[error("unexpected address in packet")] diff --git a/src/lib.rs b/src/lib.rs index 9b44173..b6f585a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ pub mod client; #[cfg(any(feature = "server", feature = "client"))] pub use crate::common::{ - packet::PacketBufferHandle, + packet::{Packet, PacketBufferHandle}, stream::{BiStream, RecvStream, SendStream}, CongestionControl, UdpRelayMode, }; From 42f4c959770a42ec0463ef130945c84c71d05732 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 18 Sep 2022 17:24:03 +0900 Subject: [PATCH 036/103] simplifying `Packet` --- src/common/incoming.rs | 63 +++++++++-------- src/common/packet.rs | 156 +++++++++++++++++++++-------------------- src/lib.rs | 2 +- 3 files changed, 112 insertions(+), 109 deletions(-) diff --git a/src/common/incoming.rs b/src/common/incoming.rs index 0bdb3c5..cd57515 100644 --- a/src/common/incoming.rs +++ b/src/common/incoming.rs @@ -1,5 +1,5 @@ use super::{ - packet::{NeedAccept, NeedAssembly, Packet, PacketBuffer}, + packet::{state::NeedAccept, Packet, PacketBuffer}, stream::{BiStream, RecvStream, SendStream, StreamReg}, }; use crate::protocol::{Address, Command, MarshalingError, ProtocolError}; @@ -130,6 +130,32 @@ impl RawPendingIncomingTask { } } + async fn accept_from_datagram( + datagram: Bytes, + pkt_buf: Arc, + ) -> Result { + let cmd = Command::read_from(&mut datagram.as_ref()) + .await + .map_err(IncomingError::from_marshaling_error)?; + let payload = datagram.slice(cmd.serialized_len()..); + + match cmd { + Command::Packet { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } => Ok(RawIncomingTask::Packet( + Packet::::new_from_datagram( + assoc_id, pkt_id, frag_total, frag_id, len, addr, pkt_buf, payload, + ), + )), + cmd => Err(IncomingError::UnexpectedCommandFromDatagram(datagram, cmd)), + } + } + async fn accept_from_uni_stream( mut stream: RecvStream, ) -> Result { @@ -146,48 +172,23 @@ impl RawPendingIncomingTask { frag_id, len, addr, - } => Ok(RawIncomingTask::PacketFromUniStream( - Packet::::new(assoc_id, pkt_id, frag_total, frag_id, len, addr, stream), + } => Ok(RawIncomingTask::Packet( + Packet::::new_from_uni_stream( + assoc_id, pkt_id, frag_total, frag_id, len, addr, stream, + ), )), Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate(assoc_id)), Command::Heartbeat => Ok(RawIncomingTask::Heartbeat), cmd => Err(IncomingError::UnexpectedCommandFromUniStream(stream, cmd)), } } - - async fn accept_from_datagram( - datagram: Bytes, - pkt_buf: Arc, - ) -> Result { - let cmd = Command::read_from(&mut datagram.as_ref()) - .await - .map_err(IncomingError::from_marshaling_error)?; - let pkt = datagram.slice(cmd.serialized_len()..); - - match cmd { - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => Ok(RawIncomingTask::PacketFromDatagram( - Packet::::new( - assoc_id, pkt_id, frag_total, frag_id, len, addr, pkt_buf, pkt, - ), - )), - cmd => Err(IncomingError::UnexpectedCommandFromDatagram(datagram, cmd)), - } - } } #[non_exhaustive] pub(crate) enum RawIncomingTask { Authenticate([u8; 32]), Connect(Address, BiStream), - PacketFromDatagram(Packet), - PacketFromUniStream(Packet), + Packet(Packet), Dissociate(u32), Heartbeat, } diff --git a/src/common/packet.rs b/src/common/packet.rs index a9da20c..0f99527 100644 --- a/src/common/packet.rs +++ b/src/common/packet.rs @@ -1,3 +1,4 @@ +use self::state::{NeedAccept, Ready, StateInner}; use crate::{ protocol::{Address, Command}, RecvStream, @@ -11,25 +12,57 @@ use std::{ }; use thiserror::Error; -pub struct NeedAccept; -pub struct NeedAssembly; -pub struct Ready; +pub mod state { + use super::PacketBuffer; + use crate::RecvStream; + use bytes::Bytes; + use std::sync::Arc; + + pub struct NeedAccept; + pub struct Ready; + + pub(super) enum StateInner { + FromDatagram(Bytes, Arc), + FromUniStream(RecvStream), + Ready(Bytes), + } +} pub struct Packet { assoc_id: u32, pkt_id: u16, - frag_id: Option, + frag_id: u8, frag_total: u8, len: u16, addr: Option
, - src: Option, - pkt_buf: Option>, - inner: Option, + inner: StateInner, _state: S, } impl Packet { - pub(super) fn new( + pub(super) fn new_from_datagram( + assoc_id: u32, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + len: u16, + addr: Option
, + pkt_buf: Arc, + payload: Bytes, + ) -> Self { + Self { + assoc_id, + pkt_id, + frag_id, + frag_total, + len, + addr, + inner: StateInner::FromDatagram(payload, pkt_buf), + _state: NeedAccept, + } + } + + pub(super) fn new_from_uni_stream( assoc_id: u32, pkt_id: u16, frag_total: u8, @@ -41,48 +74,16 @@ impl Packet { Self { assoc_id, pkt_id, - frag_id: Some(frag_id), + frag_id, frag_total, len, addr, - src: Some(stream), - pkt_buf: None, - inner: None, + inner: StateInner::FromUniStream(stream), _state: NeedAccept, } } - pub async fn accept(self) -> Result, PacketError> { - todo!() - } -} - -impl Packet { - pub(super) fn new( - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - pkt_buf: Arc, - pkt: Bytes, - ) -> Self { - Self { - assoc_id, - pkt_id, - frag_id: Some(frag_id), - frag_total, - len, - addr, - src: None, - pkt_buf: Some(pkt_buf), - inner: Some(pkt), - _state: NeedAssembly, - } - } - - pub fn assemble(self) -> Result, PacketError> { + pub async fn accept(self) -> Result>, PacketError> { todo!() } } @@ -99,13 +100,11 @@ impl Packet { Self { assoc_id, pkt_id, - frag_id: None, + frag_id: 0, frag_total, len, addr: Some(addr), - src: None, - pkt_buf: None, - inner: Some(pkt), + inner: StateInner::Ready(pkt), _state: Ready, } } @@ -120,20 +119,23 @@ impl PacketBuffer { pub(crate) fn insert( &mut self, - pkt: Packet, + assoc_id: u32, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + len: u16, + addr: Option
, + pkt: Bytes, ) -> Result>, PacketError> { let mut pkt_buf = self.0.lock(); - let key = PacketBufferKey { - assoc_id: pkt.assoc_id, - pkt_id: pkt.pkt_id, - }; + let key = PacketBufferKey { assoc_id, pkt_id }; - if pkt.frag_id.unwrap() == 0 && pkt.addr.is_none() { + if frag_id == 0 && addr.is_none() { pkt_buf.remove(&key); return Err(PacketError::NoAddress); } - if pkt.frag_id.unwrap() != 0 && pkt.addr.is_some() { + if frag_id != 0 && addr.is_some() { pkt_buf.remove(&key); return Err(PacketError::UnexpectedAddress); } @@ -142,19 +144,19 @@ impl PacketBuffer { Entry::Occupied(mut entry) => { let v = entry.get_mut(); - if pkt.frag_total == 0 - || pkt.frag_id.unwrap() >= pkt.frag_total - || v.buf.len() != pkt.frag_total as usize - || v.buf[pkt.frag_id.unwrap() as usize].is_some() + if frag_total == 0 + || frag_id >= frag_total + || v.buf.len() != frag_total as usize + || v.buf[frag_id as usize].is_some() { return Err(PacketError::BadFragment); } - v.total_len += pkt.len as usize; - v.buf[pkt.frag_id.unwrap() as usize] = Some(pkt.inner.unwrap()); + v.total_len += len as usize; + v.buf[frag_id as usize] = Some(pkt); v.recv_count += 1; - if v.recv_count == pkt.frag_total as usize { + if v.recv_count == frag_total as usize { let v = entry.remove(); let mut res = BytesMut::with_capacity(v.total_len); @@ -163,10 +165,10 @@ impl PacketBuffer { } Ok(Some(Packet::::new( - pkt.assoc_id, - pkt.pkt_id, - pkt.frag_total, - pkt.len, + assoc_id, + pkt_id, + frag_total, + len, v.addr.unwrap(), res.freeze(), ))) @@ -175,31 +177,31 @@ impl PacketBuffer { } } Entry::Vacant(entry) => { - if pkt.frag_total == 0 || pkt.frag_id.unwrap() >= pkt.frag_total { + if frag_total == 0 || frag_id >= frag_total { return Err(PacketError::BadFragment); } - if pkt.frag_total == 1 { + if frag_total == 1 { return Ok(Some(Packet::::new( - pkt.assoc_id, - pkt.pkt_id, - pkt.frag_total, - pkt.len, - pkt.addr.unwrap(), - pkt.inner.unwrap(), + assoc_id, + pkt_id, + frag_total, + len, + addr.unwrap(), + pkt, ))); } let mut v = PacketBufferValue { - buf: vec![None; pkt.frag_total as usize], - addr: pkt.addr, + buf: vec![None; frag_total as usize], + addr, recv_count: 0, total_len: 0, c_time: Instant::now(), }; - v.total_len += pkt.len as usize; - v.buf[pkt.frag_id.unwrap() as usize] = Some(pkt.inner.unwrap()); + v.total_len += len as usize; + v.buf[frag_id as usize] = Some(pkt); v.recv_count += 1; entry.insert(v); diff --git a/src/lib.rs b/src/lib.rs index b6f585a..4b174ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ pub mod client; #[cfg(any(feature = "server", feature = "client"))] pub use crate::common::{ - packet::{Packet, PacketBufferHandle}, + packet::{state as packet_state, Packet, PacketBufferHandle}, stream::{BiStream, RecvStream, SendStream}, CongestionControl, UdpRelayMode, }; From 1e51a221a5a504b44483e9096ef507de03edecc2 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 23 Jan 2023 16:18:33 +0900 Subject: [PATCH 037/103] new project structure --- .github/workflows/build-and-release.yml | 175 ------------- .gitignore | 5 +- Cargo.toml | 67 +---- README.md | 275 -------------------- src/client/connection.rs | 301 ---------------------- src/client/incoming.rs | 69 ----- src/client/mod.rs | 171 ------------- src/common/incoming.rs | 227 ---------------- src/common/mod.rs | 16 -- src/common/packet.rs | 327 ------------------------ src/common/stream.rs | 143 ----------- src/lib.rs | 23 -- src/protocol/marshaling.rs | 247 ------------------ src/protocol/mod.rs | 161 ------------ src/server/connection.rs | 65 ----- src/server/incoming.rs | 24 -- src/server/mod.rs | 137 ---------- {client => tuic-client}/Cargo.toml | 0 {client => tuic-client}/src/main.rs | 0 tuic-quinn/Cargo.toml | 8 + tuic-quinn/src/lib.rs | 14 + {server => tuic-server}/Cargo.toml | 0 {server => tuic-server}/src/main.rs | 0 tuic/Cargo.toml | 8 + tuic/src/lib.rs | 14 + 25 files changed, 48 insertions(+), 2429 deletions(-) delete mode 100644 .github/workflows/build-and-release.yml delete mode 100644 README.md delete mode 100644 src/client/connection.rs delete mode 100644 src/client/incoming.rs delete mode 100644 src/client/mod.rs delete mode 100644 src/common/incoming.rs delete mode 100644 src/common/mod.rs delete mode 100644 src/common/packet.rs delete mode 100644 src/common/stream.rs delete mode 100644 src/lib.rs delete mode 100644 src/protocol/marshaling.rs delete mode 100644 src/protocol/mod.rs delete mode 100644 src/server/connection.rs delete mode 100644 src/server/incoming.rs delete mode 100644 src/server/mod.rs rename {client => tuic-client}/Cargo.toml (100%) rename {client => tuic-client}/src/main.rs (100%) create mode 100644 tuic-quinn/Cargo.toml create mode 100644 tuic-quinn/src/lib.rs rename {server => tuic-server}/Cargo.toml (100%) rename {server => tuic-server}/src/main.rs (100%) create mode 100644 tuic/Cargo.toml create mode 100644 tuic/src/lib.rs diff --git a/.github/workflows/build-and-release.yml b/.github/workflows/build-and-release.yml deleted file mode 100644 index 68234c9..0000000 --- a/.github/workflows/build-and-release.yml +++ /dev/null @@ -1,175 +0,0 @@ -on: - workflow_dispatch: - release: - types: [published] - -name: build-and-release - -jobs: - build-and-release: - strategy: - fail-fast: false - matrix: - include: - # x86_64-linux-gnu - - arch-name: x86_64-linux-gnu - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - cross: false - file-ext: - # x86_64-linux-musl - - arch-name: x86_64-linux-musl - os: ubuntu-latest - target: x86_64-unknown-linux-musl - cross: true - file-ext: - # x86_64-windows-msvc - - arch-name: x86_64-windows-msvc - os: windows-latest - target: x86_64-pc-windows-msvc - cross: false - file-ext: .exe - # x86_64-windows-gnu - - arch-name: x86_64-windows-gnu - os: ubuntu-latest - target: x86_64-pc-windows-gnu - cross: true - file-ext: .exe - # x86_64-macos - - arch-name: x86_64-macos - os: macos-latest - target: x86_64-apple-darwin - cross: false - file-ext: - # x86_64-android - - arch-name: x86_64-android - os: ubuntu-latest - target: x86_64-linux-android - cross: true - file-ext: - # aarch64-linux-gnu - - arch-name: aarch64-linux-gnu - os: ubuntu-latest - target: aarch64-unknown-linux-gnu - cross: true - file-ext: - # aarch64-linux-musl - - arch-name: aarch64-linux-musl - os: ubuntu-latest - target: aarch64-unknown-linux-musl - cross: true - file-ext: - # aarch64-macos - - arch-name: aarch64-macos - os: macos-latest - target: aarch64-apple-darwin - cross: true - file-ext: - # aarch64-android - - arch-name: aarch64-android - os: ubuntu-latest - target: aarch64-linux-android - cross: true - file-ext: - # aarch64-ios - - arch-name: aarch64-ios - os: macos-latest - target: aarch64-apple-ios - cross: true - file-ext: - # i686-linux-gnu - - arch-name: i686-linux-gnu - os: ubuntu-latest - target: i686-unknown-linux-gnu - cross: true - file-ext: - # i686-linux-musl - - arch-name: i686-linux-musl - os: ubuntu-latest - target: i686-unknown-linux-musl - cross: true - file-ext: - # i686-windows-msvc - - arch-name: i686-windows-msvc - os: windows-latest - target: i686-pc-windows-msvc - cross: true - file-ext: .exe - # i686-android - - arch-name: i686-android - os: ubuntu-latest - target: i686-linux-android - cross: true - file-ext: - # arm-linux-gnueabihf - - arch-name: arm-linux-gnueabihf - os: ubuntu-latest - target: arm-unknown-linux-gnueabihf - cross: true - file-ext: - # armv7-linux-musleabihf - - arch-name: armv7-linux-musleabihf - os: ubuntu-latest - target: armv7-unknown-linux-musleabihf - cross: true - file-ext: - # armv7-android - - arch-name: armv7-android - os: ubuntu-latest - target: armv7-linux-androideabi - cross: true - file-ext: - - runs-on: ${{ matrix.os }} - - steps: - - name: Checkout repository - uses: actions/checkout@v2 - with: - fetch-depth: 0 - - - name: Get the latest tag - id: tag - uses: "WyriHaximus/github-action-get-previous-tag@v1" - - - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - toolchain: stable - target: ${{ matrix.target }} - override: true - - - name: Build server - uses: actions-rs/cargo@v1 - with: - use-cross: ${{ matrix.cross }} - command: build - args: --release -p tuic-server --target ${{ matrix.target }} - - - name: Build client - uses: actions-rs/cargo@v1 - with: - use-cross: ${{ matrix.cross }} - command: build - args: --release -p tuic-client --target ${{ matrix.target }} - - - name: Move binaries - run: | - mkdir artifacts/ - mv target/${{ matrix.target }}/release/tuic-server${{ matrix.file-ext }} artifacts/tuic-server-${{ steps.tag.outputs.tag }}-${{ matrix.arch-name }}${{ matrix.file-ext }} - mv target/${{ matrix.target }}/release/tuic-client${{ matrix.file-ext }} artifacts/tuic-client-${{ steps.tag.outputs.tag }}-${{ matrix.arch-name }}${{ matrix.file-ext }} - - - name: Calculate SHA256 - run: | - cd artifacts/ - openssl dgst -sha256 -r tuic-server-${{ steps.tag.outputs.tag }}-${{ matrix.arch-name }}${{ matrix.file-ext }} > tuic-server-${{ steps.tag.outputs.tag }}-${{ matrix.arch-name }}${{ matrix.file-ext }}.sha256sum - openssl dgst -sha256 -r tuic-client-${{ steps.tag.outputs.tag }}-${{ matrix.arch-name }}${{ matrix.file-ext }} > tuic-client-${{ steps.tag.outputs.tag }}-${{ matrix.arch-name }}${{ matrix.file-ext }}.sha256sum - - - name: Release binaries - uses: ncipollo/release-action@v1 - with: - artifacts: "artifacts/*" - tag: ${{ steps.tag.outputs.tag }} - name: ${{ steps.tag.outputs.tag }} - allowUpdates: true - token: ${{ secrets.PERSONAL_ACCESS_TOKEN }} diff --git a/.gitignore b/.gitignore index 16d5636..9029b80 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ -/target -Cargo.lock +.vscode/ +target/ .DS_Store +Cargo.lock \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 67a58d8..05881e9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,67 +1,2 @@ -[package] -name = "tuic" -version = "5.0.0" -authors = ["EAimTY "] -description = "Delicately-TUICed high-performance proxy built on top of the QUIC protocol" -categories = ["network-programming", "command-line-utilities"] -keywords = ["tuic", "proxy", "quic"] -edition = "2021" -rust-version = "1.59" -readme = "README.md" -license = "GPL-3.0-or-later" -repository = "https://github.com/EAimTY/tuic" - [workspace] -members = [ - "client", - "server", -] - -[dependencies] -# blake3 = "1.3.*" -byteorder = { version = "1.4.*", default-features = false, optional = true } -bytes = { version = "1.2.*", default-features = false, optional = true } -crossbeam-utils = { version = "0.8.*", default-features = false, optional = true } -# env_logger = { version = "0.9.*", features = ["humantime"], default-features = false } -futures = { version = "0.3.*", default-features = false, optional = true } -# getopts = "0.2.*" -# log = { version = "0.4.*", features = ["serde", "std"] } -# once_cell = { version = "1.13.*", features = ["parking_lot"] } -parking_lot = { version = "0.12.*", default-features = false, optional = true } -quinn = { version = "0.8.*", features = ["tls-rustls"], default-features = false, optional = true } -# rand = "0.8.*" -rustls = { version = "0.20.*", default-features = false, optional = true } -# rustls-native-certs = "0.6.*" -# rustls-pemfile = "1.0.*" -# serde = { version = "1.0.*", features = ["derive", "std"], default-features = false } -# serde_json = { version = "1.0.*", features = ["std"], default-features = false } -# socket2 = "0.4.*" -# socks5-proto = "0.3.*" -# socks5-server = "0.8.*" -thiserror = { version = "1.0.*", default-features = false, optional = true } -tokio = { version = "1.20.*", default-features = false, optional = true } -# tokio = { version = "1.20.*", features = ["io-util", "macros", "net", "parking_lot", "rt-multi-thread", "sync", "time"] } -# webpki = { version = "0.22.*", default-features = false } - -[features] -default = [] - -all = ["protocol_marshaling", "server", "client"] - -protocol_marshaling = ["byteorder/std", "bytes", "thiserror", "tokio/io-util"] - -server = ["crossbeam-utils", "futures", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] -client = ["futures", "parking_lot", "protocol_marshaling", "quinn", "rustls", "thiserror", "tokio/io-util", "tokio/macros"] - -[dev-dependencies] -tuic = { path = ".", features = ["all"] } - -[profile.release] -lto = true -strip = true -incremental = false -codegen-units = 1 -panic = "abort" - -[package.metadata.docs.rs] -all-features = true \ No newline at end of file +members = ["tuic", "tuic-quinn", "tuic-server", "tuic-client"] \ No newline at end of file diff --git a/README.md b/README.md deleted file mode 100644 index 3b7dc15..0000000 --- a/README.md +++ /dev/null @@ -1,275 +0,0 @@ -# TUIC - -Delicately-TUICed high-performance proxy built on top of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol - -**TUIC's goal is to minimize the handshake latency as much as possible** - -## Features - -- 1-RTT TCP relaying -- 0-RTT UDP relaying with [Full Cone NAT](https://www.rfc-editor.org/rfc/rfc3489#section-5) -- Two UDP relay modes: `native` (native UDP mechanisms) and `quic` (100% delivery rate) -- Bidirectional user-space congestion control (BBR, New Reno and CUBIC) -- Multiplexing all tasks into a single QUIC connection (tasks are separately flow controlled) -- Smooth session transfer on network switching -- Paralleled 0-RTT authentication -- Optional QUIC 0-RTT handshake - -## Design - -TUIC was designed on the basis of the QUIC protocol from the very beginning. It can make full use of the advantages brought by QUIC. You can find more information about the TUIC protocol [here](https://github.com/EAimTY/tuic/tree/master/protocol). - -### Multiplexing - -TUIC multiplexes all tasks into a single QUIC connection using QUIC's multi-streams mechanism. This means that unless the QUIC connection is forcibly interrupted or no task within the maximum idle time, negotiating new relay task does not need to go through the process of QUIC handshake and TUIC authentication. - -### UDP Relaying - -TUIC has 2 UDP relay modes: - -- `native` - using QUIC's datagram to transmit UDP packets. As with native UDP, packets may be lost, but the overhead of the acknowledgment mechanism is omitted. Relayed packets are still encrypted by QUIC. - -- `quic` - transporting UDP packets as QUIC streams. Because of the acknowledgment and retransmission mechanism, UDP packets can guarantee a 100% delivery rate, but have additional transmission overhead as a result. Note that each UDP data packet is transmitted as a separate stream, and the flow controlled separately, so the loss and retransmission of one packet will not cause other packets to be blocked. This mode can be used to transmit UDP packets larger than the MTU of the underlying network. - -### Bidirectional User-space Congestion Control - -Since QUIC is implemented over UDP, its congestion control implementation is not limited by platform and operating system. For poor quality network, [BBR algorithm](https://en.wikipedia.org/wiki/TCP_congestion_control#TCP_BBR) can be used on both the server and the client to achieve better transmission performance. - -### Security - -As mentioned above, TUIC is based on the QUIC protocol, which uses TLS to encrypt data. TUIC protocol itself does not provide any security, but the QUIC protocol provides a strong security guarantee. TUIC also supports QUIC's 0-RTT handshake, but it came with a cost of weakened security, [read more about QUIC 0-RTT handshake](https://blog.cloudflare.com/even-faster-connection-establishment-with-quic-0-rtt-resumption/#attack-of-the-clones). - -## Usage - -TUIC depends on [rustls](https://github.com/rustls/rustls), which uses [ring](https://github.com/briansmith/ring) for implementing the cryptography in TLS. As a result, TUIC only runs on platforms supported by ring. At the time of writing this means x86, x86-64, armv7, and aarch64. - -You can find pre-compiled binaries in the latest [releases](https://github.com/EAimTY/tuic/releases). - -### Server - -``` -tuic-server - -Options: - -c, --config CONFIG_FILE - Read configuration from a file. Note that command line - arguments will override the configuration file - --port SERVER_PORT - Set the server listening port - --token TOKEN Set the token for TUIC authentication. This option can - be used multiple times to set multiple tokens. - --certificate CERTIFICATE - Set the X.509 certificate. This must be an end-entity - certificate - --private-key PRIVATE_KEY - Set the certificate private key - --ip IP Set the server listening IP. Default: 0.0.0.0 - --congestion-controller CONGESTION_CONTROLLER - Set the congestion control algorithm. Available: - "cubic", "new_reno", "bbr". Default: "cubic" - --max-idle-time MAX_IDLE_TIME - Set the maximum idle time for QUIC connections, in - milliseconds. Default: 15000 - --authentication-timeout AUTHENTICATION_TIMEOUT - Set the maximum time allowed between a QUIC connection - established and the TUIC authentication packet - received, in milliseconds. Default: 1000 - --alpn ALPN_PROTOCOL - Set ALPN protocols that the server accepts. This - option can be used multiple times to set multiple ALPN - protocols. If not set, the server will not check ALPN - at all - --max-udp-relay-packet-size MAX_UDP_RELAY_PACKET_SIZE - UDP relay mode QUIC can transmit UDP packets larger - than the MTU. Set this to a higher value allows - outbound to receive larger UDP packet. Default: 1500 - --log-level LOG_LEVEL - Set the log level. Available: "off", "error", "warn", - "info", "debug", "trace". Default: "info" - -v, --version Print the version - -h, --help Print this help menu -``` - -The configuration file is in JSON format: - -```json -{ - "port": 443, - "token": ["TOKEN0", "TOKEN1"], - "certificate": "/PATH/TO/CERT", - "private_key": "/PATH/TO/PRIV_KEY", - - "ip": "0.0.0.0", - "congestion_controller": "cubic", - "max_idle_time": 15000, - "authentication_timeout": 1000, - "alpn": ["h3"], - "max_udp_relay_packet_size": 1500, - "log_level": "info" -} -``` - -Fields `port`, `token`, `certificate`, `private_key` are required. Other fields are optional and can be deleted to fall-back the default value. - -Note that command line arguments can override the configuration file. - -### Client - -``` -tuic-client - -Options: - -c, --config CONFIG_FILE - Read configuration from a file. Note that command line - arguments will override the configuration file - --server SERVER Set the server address. This address must be included - in the certificate - --server-port SERVER_PORT - Set the server port - --token TOKEN Set the token for TUIC authentication - --server-ip SERVER_IP - Set the server IP, for overwriting the DNS lookup - result of the server address set in option 'server' - --certificate CERTIFICATE - Set custom X.509 certificate alongside native CA roots - for the QUIC handshake. This option can be used - multiple times to set multiple certificates - --udp-relay-mode UDP_MODE - Set the UDP relay mode. Available: "native", "quic". - Default: "native" - --congestion-controller CONGESTION_CONTROLLER - Set the congestion control algorithm. Available: - "cubic", "new_reno", "bbr". Default: "cubic" - --heartbeat-interval HEARTBEAT_INTERVAL - Set the heartbeat interval to ensures that the QUIC - connection is not closed when there are relay tasks - but no data transfer, in milliseconds. This value - needs to be smaller than the maximum idle time set at - the server side. Default: 10000 - --alpn ALPN_PROTOCOL - Set ALPN protocols included in the TLS client hello. - This option can be used multiple times to set multiple - ALPN protocols. If not set, no ALPN extension will be - sent - --disable-sni Not sending the Server Name Indication (SNI) extension - during the client TLS handshake - --reduce-rtt Enable 0-RTT QUIC handshake - --request-timeout REQUEST_TIMEOUT - Set the timeout for negotiating tasks between client - and the server, in milliseconds. Default: 8000 - --max-udp-relay-packet-size MAX_UDP_RELAY_PACKET_SIZE - UDP relay mode QUIC can transmit UDP packets larger - than the MTU. Set this to a higher value allows - inbound to receive larger UDP packet. Default: 1500 - --local-port LOCAL_PORT - Set the listening port for the local socks5 server - --local-ip LOCAL_IP - Set the listening IP for the local socks5 server. Note - that the sock5 server socket will be a dual-stack - socket if it is IPv6. Default: "127.0.0.1" - --local-username LOCAL_USERNAME - Set the username for the local socks5 server - authentication - --local-password LOCAL_PASSWORD - Set the password for the local socks5 server - authentication - --log-level LOG_LEVEL - Set the log level. Available: "off", "error", "warn", - "info", "debug", "trace". Default: "info" - -v, --version Print the version - -h, --help Print this help menu -``` - -The configuration file is in JSON format: - -```json -{ - "relay": { - "server": "SERVER", - "port": 443, - "token": "TOKEN", - - "ip": "SERVER_IP", - "certificates": ["/PATH/TO/CERT"], - "udp_relay_mode": "native", - "congestion_controller": "cubic", - "heartbeat_interval": 10000, - "alpn": ["h3"], - "disable_sni": false, - "reduce_rtt": false, - "request_timeout": 8000, - "max_udp_relay_packet_size": 1500 - }, - "local": { - "port": 1080, - - "ip": "127.0.0.1", - "username": "SOCKS5_USERNAME", - "password": "SOCKS5_PASSWORD" - }, - "log_level": "info" -} -``` - -Fields `server`, `token` and `port` in both sections are required. Other fields are optional and can be deleted to fall-back the default value. - -Note that command line arguments can override the configuration file. - -## GUI Clients - -### Android - -- [SagerNet](https://sagernet.org/) - -### iOS - -- [Stash](https://stash.ws/) * - -*[Stash](https://stash.ws/) re-implemented the TUIC protocol from scratch, so it didn't preserve the GPL License. - -### Windows - -- [v2rayN](https://github.com/2dust/v2rayN) - -## FAQ - -### What are the advantages of TUIC over other proxy protocols / implementions? - -As mentioned before, TUIC's goal is to minimize the handshake latency as much as possible. Thus, the core of the TUIC protocol is to reduce the additional round trip time added by the relay. For TCP relaying, TUIC only adds a single round trip between the TUIC server and the TUIC client - half of a typical TCP-based proxy would require. TUIC also has a unique UDP relaying mechanism. It achieves 0-RTT UDP relaying by syncing UDP relay sessions implicitly between the server and the client. - -Low handshake latency means faster connection establishment and UDP packet delay time. TUIC also supports both UDP over streams and UDP over datagrams for UDP relaying. All of these makes TUIC one of the most efficient proxy protocol for UDP relaying. - -### Why my TUIC is slower than other proxy protocols / implementions? - -For an Internet connection, fast / slow is defined by both: - -- Handshake latency -- Bandwidth - -They are equally important. For the first case, TUIC can be one of the best solution right now. You can directly feel it from things like the speed of opening a web page in your browser. For the second case, TUIC may be a bit slower than other TCP-based proxy protocols due to ISPs' QoS, but TUIC's bandwidth can still be competitive in most scenario. - -### How can I listen both IPv4 and IPv6 on TUIC server / TUIC client's socks5 server? - -TUIC always constructs an IPv6 listener as a dual-stack socket. If you need to listen on both IPv4 and IPv6, you can set the bind IP to the unspecified IPv6 address `::`. - -### Why TUIC client doesn't support other inbound / advanced route settings? - -Since there are already many great proxy convert / distribute solutions, there really is no need for me to reimplement those again. If you need those functions, the best choice to chain a V2Ray layer in front of the TUIC client. For a typical network program, there is basically no performance cost for local relaying. - -### Why TUIC client is not able to convert the first connection into 0-RTT - -It is totally fine and designed to be like that. - -> The basic idea behind 0-RTT connection resumption is that if the client and server had previously established a TLS connection between each other, they can use information cached from that session to establish a new one without having to negotiate the connection’s parameters from scratch. Notably this allows the client to compute the private encryption keys required to protect application data before even talking to the server. -> -> *--[Even faster connection establishment with QUIC 0-RTT resumption](https://blog.cloudflare.com/even-faster-connection-establishment-with-quic-0-rtt-resumption) - Cloudflare* - -When the client program starts, trying to convert the very first connection to 0-RTT will always fail because the client has no server-related information yet. This connection handshake will fall-back to the regular 1-RTT one. - -Once the client caches server information from the first connection, any subsequent connection will be convert into a 0-RTT one. That is why you only see this warning message once just after starting the client. - -Therefore, you can safely ignore this warn. - -## License - -GNU General Public License v3.0 diff --git a/src/client/connection.rs b/src/client/connection.rs deleted file mode 100644 index b0c0d33..0000000 --- a/src/client/connection.rs +++ /dev/null @@ -1,301 +0,0 @@ -use super::IncomingTasks; -use crate::{ - common::{ - packet, - stream::{RecvStream, SendStream, StreamReg}, - }, - protocol::{Address, Command, MarshalingError, ProtocolError}, - BiStream, UdpRelayMode, -}; -use bytes::{Bytes, BytesMut}; -use quinn::{ - Connecting as QuinnConnecting, Connection as QuinnConnection, - NewConnection as QuinnNewConnection, SendDatagramError as QuinnSendDatagramError, -}; -use std::{ - io::Error as IoError, - string::FromUtf8Error, - sync::{ - atomic::{AtomicU16, Ordering}, - Arc, - }, -}; -use thiserror::Error; -use tokio::io::AsyncWriteExt; - -pub struct Connecting { - conn: QuinnConnecting, - enable_quic_0rtt: bool, - udp_relay_mode: UdpRelayMode, -} - -impl Connecting { - pub(super) fn new( - conn: QuinnConnecting, - enable_quic_0rtt: bool, - udp_relay_mode: UdpRelayMode, - ) -> Self { - Self { - conn, - enable_quic_0rtt, - udp_relay_mode, - } - } - - pub async fn establish( - self, - ) -> Result, Self> { - let QuinnNewConnection { - connection, - bi_streams, - uni_streams, - datagrams, - .. - } = if self.enable_quic_0rtt { - match self.conn.into_0rtt() { - Ok((conn, _)) => conn, - Err(conn) => { - return Err(Self { - conn, - enable_quic_0rtt: false, - udp_relay_mode: self.udp_relay_mode, - }); - } - } - } else { - match self.conn.await { - Ok(conn) => conn, - Err(err) => return Ok(Err(ConnectionError::from(IoError::from(err)))), - } - }; - - let stream_reg = Arc::new(Arc::new(())); - let conn = Connection::new(connection, self.udp_relay_mode, stream_reg.clone()); - let incoming = IncomingTasks::new( - bi_streams, - uni_streams, - datagrams, - self.udp_relay_mode, - stream_reg, - ); - - Ok(Ok((conn, incoming))) - } -} - -pub struct Connection { - conn: QuinnConnection, - udp_relay_mode: UdpRelayMode, - stream_reg: Arc, - next_pkt_id: Arc, -} - -impl Connection { - fn new( - conn: QuinnConnection, - udp_relay_mode: UdpRelayMode, - stream_reg: Arc, - ) -> Self { - Self { - conn, - udp_relay_mode, - stream_reg: stream_reg.clone(), - next_pkt_id: Arc::new(AtomicU16::new(0)), - } - } - - pub async fn authenticate(&self, token: [u8; 32]) -> Result<(), ConnectionError> { - let mut send = self.get_send_stream().await?; - let cmd = Command::Authenticate(token); - cmd.write_to(&mut send).await?; - send.finish().await?; - Ok(()) - } - - pub async fn heartbeat(&self) -> Result<(), ConnectionError> { - let mut send = self.get_send_stream().await?; - let cmd = Command::Heartbeat; - cmd.write_to(&mut send).await?; - send.finish().await?; - Ok(()) - } - - pub async fn connect(&self, addr: Address) -> Result, ConnectionError> { - let mut stream = self.get_bi_stream().await?; - - let cmd = Command::Connect { addr }; - cmd.write_to(&mut stream).await?; - - let resp = match Command::read_from(&mut stream).await { - Ok(Command::Respond(resp)) => Ok(resp), - Ok(cmd) => Err(ConnectionError::ShouldBeRespond(cmd)), - Err(err) => Err(ConnectionError::from(err)), - }; - - let res = match resp { - Ok(true) => return Ok(Some(stream)), - Ok(false) => Ok(None), - Err(err) => Err(err), - }; - - stream.finish().await?; - res - } - - pub async fn packet( - &self, - assoc_id: u32, - addr: Address, - pkt: Bytes, - ) -> Result<(), ConnectionError> { - match self.udp_relay_mode { - UdpRelayMode::Native => self.send_packet_to_datagram(assoc_id, addr, pkt), - UdpRelayMode::Quic => self.send_packet_to_uni_stream(assoc_id, addr, pkt).await, - } - } - - pub async fn dissociate(&self, assoc_id: u32) -> Result<(), ConnectionError> { - let mut send = self.get_send_stream().await?; - let cmd = Command::Dissociate { assoc_id }; - cmd.write_to(&mut send).await?; - send.finish().await?; - Ok(()) - } - - async fn get_send_stream(&self) -> Result { - let send = self.conn.open_uni().await.map_err(IoError::from)?; - - Ok(SendStream::new(send, self.stream_reg.as_ref().clone())) - } - - async fn get_bi_stream(&self) -> Result { - let (send, recv) = self.conn.open_bi().await.map_err(IoError::from)?; - - let send = SendStream::new(send, self.stream_reg.as_ref().clone()); - let recv = RecvStream::new(recv, self.stream_reg.as_ref().clone()); - - Ok(BiStream::new(send, recv)) - } - - fn send_packet_to_datagram( - &self, - assoc_id: u32, - addr: Address, - pkt: Bytes, - ) -> Result<(), ConnectionError> { - let max_datagram_size = if let Some(size) = self.conn.max_datagram_size() { - size - } else { - return Err(ConnectionError::DatagramDisabled); - }; - - let pkt_id = self.next_pkt_id.fetch_add(1, Ordering::SeqCst); - let mut pkts = packet::split_packet(pkt, &addr, max_datagram_size); - let frag_total = pkts.len() as u8; - - let first_pkt = pkts.next().unwrap(); - let first_pkt_header = Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id: 0, - len: first_pkt.len() as u16, - addr: Some(addr), - }; - - let mut buf = BytesMut::with_capacity(first_pkt_header.serialized_len() + first_pkt.len()); - first_pkt_header.write_to_buf(&mut buf); - buf.extend_from_slice(&first_pkt); - let buf = buf.freeze(); - - self.conn - .send_datagram(buf) - .map_err(ConnectionError::from_quinn_send_datagram_error)?; - - for (id, pkt) in pkts.enumerate() { - let pkt_header = Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id: id as u8 + 1, - len: pkt.len() as u16, - addr: None, - }; - - let mut buf = BytesMut::with_capacity(pkt_header.serialized_len() + pkt.len()); - pkt_header.write_to_buf(&mut buf); - buf.extend_from_slice(&pkt); - let buf = buf.freeze(); - - self.conn - .send_datagram(buf) - .map_err(ConnectionError::from_quinn_send_datagram_error)?; - } - - Ok(()) - } - - async fn send_packet_to_uni_stream( - &self, - assoc_id: u32, - addr: Address, - pkt: Bytes, - ) -> Result<(), ConnectionError> { - let mut send = self.get_send_stream().await?; - - let cmd = Command::Packet { - assoc_id, - pkt_id: self.next_pkt_id.fetch_add(1, Ordering::SeqCst), - frag_total: 1, - frag_id: 0, - len: pkt.len() as u16, - addr: Some(addr), - }; - - cmd.write_to(&mut send).await?; - send.write_all(&pkt).await?; - send.finish().await?; - - Ok(()) - } -} - -#[derive(Error, Debug)] -pub enum ConnectionError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Protocol(#[from] ProtocolError), - #[error("invalid address encoding: {0}")] - InvalidEncoding(#[from] FromUtf8Error), - #[error("expecting a `Respond`, got a command")] - ShouldBeRespond(Command), - #[error("datagrams not supported by peer")] - DatagramUnsupportedByPeer, - #[error("datagram support disabled")] - DatagramDisabled, - #[error("datagram too large")] - DatagramTooLarge, -} - -impl ConnectionError { - #[inline] - fn from_quinn_send_datagram_error(err: QuinnSendDatagramError) -> Self { - match err { - QuinnSendDatagramError::UnsupportedByPeer => Self::DatagramUnsupportedByPeer, - QuinnSendDatagramError::Disabled => Self::DatagramDisabled, - QuinnSendDatagramError::TooLarge => Self::DatagramTooLarge, - QuinnSendDatagramError::ConnectionLost(err) => Self::Io(IoError::from(err)), - } - } -} - -impl From for ConnectionError { - fn from(err: MarshalingError) -> Self { - match err { - MarshalingError::Io(err) => Self::Io(err), - MarshalingError::Protocol(err) => Self::Protocol(err), - MarshalingError::InvalidEncoding(err) => Self::InvalidEncoding(err), - } - } -} diff --git a/src/client/incoming.rs b/src/client/incoming.rs deleted file mode 100644 index 728f9ea..0000000 --- a/src/client/incoming.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::{ - common::{ - incoming::{IncomingError, RawIncomingTask, RawIncomingTasks, RawPendingIncomingTask}, - packet::PacketBuffer, - stream::StreamReg, - }, - PacketBufferHandle, UdpRelayMode, -}; -use futures::Stream; -use quinn::{Datagrams, IncomingBiStreams, IncomingUniStreams}; -use std::{ - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -pub struct IncomingTasks { - inner: RawIncomingTasks, - udp_relay_mode: UdpRelayMode, -} - -impl IncomingTasks { - pub(super) fn new( - bi_streams: IncomingBiStreams, - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - udp_relay_mode: UdpRelayMode, - stream_reg: Arc, - ) -> Self { - Self { - inner: RawIncomingTasks::new(bi_streams, uni_streams, datagrams, stream_reg), - udp_relay_mode, - } - } -} - -impl Stream for IncomingTasks { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_next(cx).map(|poll| { - poll.map(|res| match res { - Ok(source) => match (source, self.udp_relay_mode) { - (RawPendingIncomingTask::BiStream(stream), _) => { - Err(IncomingError::UnexpectedIncomingBiStream(stream)) - } - (RawPendingIncomingTask::UniStream(stream), UdpRelayMode::Native) => { - Err(IncomingError::UnexpectedIncomingUniStream(stream)) - } - (RawPendingIncomingTask::Datagram(datagram, ..), UdpRelayMode::Quic) => { - Err(IncomingError::UnexpectedIncomingDatagram(datagram)) - } - (source, _) => Ok(PendingIncomingTask(source)), - }, - Err(err) => Err(IncomingError::from(err)), - }) - }) - } -} - -pub struct PendingIncomingTask(RawPendingIncomingTask); - -impl PendingIncomingTask { - pub async fn accept(self) -> Result { - todo!() - } -} - -pub enum IncomingTask {} diff --git a/src/client/mod.rs b/src/client/mod.rs deleted file mode 100644 index cdcb18b..0000000 --- a/src/client/mod.rs +++ /dev/null @@ -1,171 +0,0 @@ -mod connection; -mod incoming; - -pub use self::{ - connection::{Connecting, Connection, ConnectionError}, - incoming::{IncomingTasks, PendingIncomingTask}, -}; - -use crate::{CongestionControl, UdpRelayMode}; -use quinn::{ - congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - ClientConfig as QuinnClientConfig, ConnectError as QuinnConnectError, Endpoint, EndpointConfig, -}; -use rustls::{version, ClientConfig as RustlsClientConfig, RootCertStore}; -use std::{ - convert::Infallible, - fmt::Debug, - io::Error as IoError, - net::{SocketAddr, UdpSocket}, - sync::Arc, -}; -use thiserror::Error; - -#[derive(Debug)] -pub struct Client { - endpoint: Endpoint, - enable_quic_0rtt: bool, - udp_relay_mode: UdpRelayMode, -} - -impl Client { - pub fn bind(cfg: ClientConfig, socket: UdpSocket) -> Result { - let mut crypto = RustlsClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&version::TLS13]) - .unwrap() - .with_root_certificates(cfg.certificates) - .with_no_client_auth(); - - crypto.alpn_protocols = cfg.alpn_protocols; - crypto.enable_early_data = cfg.enable_quic_0rtt; - crypto.enable_sni = !cfg.disable_sni; - - let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); - - let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); - transport.max_idle_timeout(None); - - match cfg.congestion_controller { - CongestionControl::Cubic => { - transport.congestion_controller_factory(Arc::new(CubicConfig::default())); - } - CongestionControl::NewReno => { - transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); - } - CongestionControl::Bbr => { - transport.congestion_controller_factory(Arc::new(BbrConfig::default())); - } - } - - let (mut ep, _) = Endpoint::new(EndpointConfig::default(), None, socket)?; - ep.set_default_client_config(quinn_config); - - Ok(Self { - endpoint: ep, - udp_relay_mode: cfg.udp_relay_mode, - enable_quic_0rtt: cfg.enable_quic_0rtt, - }) - } - - pub fn reconfigure(&mut self, cfg: ClientConfig) -> Result<(), Infallible> { - let mut crypto = RustlsClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&version::TLS13]) - .unwrap() - .with_root_certificates(cfg.certificates) - .with_no_client_auth(); - - crypto.alpn_protocols = cfg.alpn_protocols; - crypto.enable_early_data = cfg.enable_quic_0rtt; - crypto.enable_sni = !cfg.disable_sni; - - let mut quinn_config = QuinnClientConfig::new(Arc::new(crypto)); - - let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); - transport.max_idle_timeout(None); - - match cfg.congestion_controller { - CongestionControl::Cubic => { - transport.congestion_controller_factory(Arc::new(CubicConfig::default())); - } - CongestionControl::NewReno => { - transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); - } - CongestionControl::Bbr => { - transport.congestion_controller_factory(Arc::new(BbrConfig::default())); - } - } - - self.endpoint.set_default_client_config(quinn_config); - - self.udp_relay_mode = cfg.udp_relay_mode; - self.enable_quic_0rtt = cfg.enable_quic_0rtt; - - Ok(()) - } - - pub fn rebind(&mut self, socket: UdpSocket) -> Result<(), ClientError> { - self.endpoint.rebind(socket)?; - Ok(()) - } - - pub async fn connect( - &self, - addr: SocketAddr, - server_name: &str, - ) -> Result { - let conn = self - .endpoint - .connect(addr, server_name) - .map_err(ClientError::from_quinn_connect_error)?; - - Ok(Connecting::new( - conn, - self.enable_quic_0rtt, - self.udp_relay_mode, - )) - } -} - -#[derive(Clone, Debug)] -pub struct ClientConfig { - pub certificates: RootCertStore, - pub alpn_protocols: Vec>, - pub disable_sni: bool, - pub enable_quic_0rtt: bool, - pub udp_relay_mode: UdpRelayMode, - pub congestion_controller: CongestionControl, -} - -#[derive(Error, Debug)] -pub enum ClientError { - #[error("socket binding error: {0}")] - Socket(#[from] IoError), - #[error("endpoint stopping")] - EndpointStopping, - #[error("too many connections")] - TooManyConnections, - #[error("invalid DNS name: {0}")] - InvalidDnsName(String), - #[error("invalid remote address: {0}")] - InvalidRemoteAddress(SocketAddr), - #[error("unsupported QUIC version")] - UnsupportedQUICVersion, -} - -impl ClientError { - #[inline] - fn from_quinn_connect_error(err: QuinnConnectError) -> Self { - match err { - QuinnConnectError::UnsupportedVersion => Self::UnsupportedQUICVersion, - QuinnConnectError::EndpointStopping => Self::EndpointStopping, - QuinnConnectError::TooManyConnections => Self::TooManyConnections, - QuinnConnectError::InvalidDnsName(err) => Self::InvalidDnsName(err), - QuinnConnectError::InvalidRemoteAddress(err) => Self::InvalidRemoteAddress(err), - QuinnConnectError::NoDefaultClientConfig => unreachable!(), - } - } -} diff --git a/src/common/incoming.rs b/src/common/incoming.rs deleted file mode 100644 index cd57515..0000000 --- a/src/common/incoming.rs +++ /dev/null @@ -1,227 +0,0 @@ -use super::{ - packet::{state::NeedAccept, Packet, PacketBuffer}, - stream::{BiStream, RecvStream, SendStream, StreamReg}, -}; -use crate::protocol::{Address, Command, MarshalingError, ProtocolError}; -use bytes::Bytes; -use futures::{stream::SelectAll, Stream}; -use quinn::{ - Datagrams, IncomingBiStreams, IncomingUniStreams, RecvStream as QuinnRecvStream, - SendStream as QuinnSendStream, -}; -use std::{ - io::Error as IoError, - pin::Pin, - string::FromUtf8Error, - sync::Arc, - task::{Context, Poll}, -}; -use thiserror::Error; - -pub(crate) struct RawIncomingTasks { - incoming: SelectAll, - stream_reg: Arc, - pkt_buf: Arc, -} - -impl RawIncomingTasks { - pub(crate) fn new( - bi_streams: IncomingBiStreams, - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - stream_reg: Arc, - ) -> Self { - let mut incoming = SelectAll::new(); - - incoming.push(IncomingSource::BiStreams(bi_streams)); - incoming.push(IncomingSource::UniStreams(uni_streams)); - incoming.push(IncomingSource::Datagrams(datagrams)); - - Self { - incoming, - stream_reg, - pkt_buf: Arc::new(PacketBuffer::new()), - } - } -} - -impl Stream for RawIncomingTasks { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.incoming) - .poll_next(cx) - .map_ok(|src| match src { - IncomingItem::BiStream((send, recv)) => { - RawPendingIncomingTask::BiStream(BiStream::new( - SendStream::new(send, self.stream_reg.as_ref().clone()), - RecvStream::new(recv, self.stream_reg.as_ref().clone()), - )) - } - IncomingItem::UniStream(recv) => RawPendingIncomingTask::UniStream( - RecvStream::new(recv, self.stream_reg.as_ref().clone()), - ), - IncomingItem::Datagram(datagram) => { - RawPendingIncomingTask::Datagram(datagram, self.pkt_buf.clone()) - } - }) - .map_err(IoError::from) - } -} - -enum IncomingSource { - BiStreams(IncomingBiStreams), - UniStreams(IncomingUniStreams), - Datagrams(Datagrams), -} - -impl Stream for IncomingSource { - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.get_mut() { - IncomingSource::BiStreams(bi_streams) => Pin::new(bi_streams) - .poll_next(cx) - .map_ok(IncomingItem::BiStream) - .map_err(IoError::from), - IncomingSource::UniStreams(uni_streams) => Pin::new(uni_streams) - .poll_next(cx) - .map_ok(IncomingItem::UniStream) - .map_err(IoError::from), - IncomingSource::Datagrams(datagrams) => Pin::new(datagrams) - .poll_next(cx) - .map_ok(IncomingItem::Datagram) - .map_err(IoError::from), - } - } -} - -enum IncomingItem { - BiStream((QuinnSendStream, QuinnRecvStream)), - UniStream(QuinnRecvStream), - Datagram(Bytes), -} - -pub(crate) enum RawPendingIncomingTask { - BiStream(BiStream), - UniStream(RecvStream), - Datagram(Bytes, Arc), -} - -impl RawPendingIncomingTask { - pub(crate) async fn accept(self) -> Result { - match self { - Self::BiStream(stream) => Self::accept_from_bi_stream(stream).await, - Self::UniStream(stream) => Self::accept_from_uni_stream(stream).await, - Self::Datagram(datagram, pkt_buf) => { - Self::accept_from_datagram(datagram, pkt_buf).await - } - } - } - - async fn accept_from_bi_stream(mut stream: BiStream) -> Result { - let cmd = Command::read_from(&mut stream) - .await - .map_err(IncomingError::from_marshaling_error)?; - - match cmd { - Command::Connect { addr } => Ok(RawIncomingTask::Connect(addr, stream)), - cmd => Err(IncomingError::UnexpectedCommandFromBiStream(stream, cmd)), - } - } - - async fn accept_from_datagram( - datagram: Bytes, - pkt_buf: Arc, - ) -> Result { - let cmd = Command::read_from(&mut datagram.as_ref()) - .await - .map_err(IncomingError::from_marshaling_error)?; - let payload = datagram.slice(cmd.serialized_len()..); - - match cmd { - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => Ok(RawIncomingTask::Packet( - Packet::::new_from_datagram( - assoc_id, pkt_id, frag_total, frag_id, len, addr, pkt_buf, payload, - ), - )), - cmd => Err(IncomingError::UnexpectedCommandFromDatagram(datagram, cmd)), - } - } - - async fn accept_from_uni_stream( - mut stream: RecvStream, - ) -> Result { - let cmd = Command::read_from(&mut stream) - .await - .map_err(IncomingError::from_marshaling_error)?; - - match cmd { - Command::Authenticate(token) => Ok(RawIncomingTask::Authenticate(token)), - Command::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => Ok(RawIncomingTask::Packet( - Packet::::new_from_uni_stream( - assoc_id, pkt_id, frag_total, frag_id, len, addr, stream, - ), - )), - Command::Dissociate { assoc_id } => Ok(RawIncomingTask::Dissociate(assoc_id)), - Command::Heartbeat => Ok(RawIncomingTask::Heartbeat), - cmd => Err(IncomingError::UnexpectedCommandFromUniStream(stream, cmd)), - } - } -} - -#[non_exhaustive] -pub(crate) enum RawIncomingTask { - Authenticate([u8; 32]), - Connect(Address, BiStream), - Packet(Packet), - Dissociate(u32), - Heartbeat, -} - -#[derive(Error, Debug)] -pub enum IncomingError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Protocol(#[from] ProtocolError), - #[error("invalid address encoding: {0}")] - InvalidEncoding(#[from] FromUtf8Error), - #[error("unexpected incoming bi_stream")] - UnexpectedIncomingBiStream(BiStream), - #[error("unexpected incoming uni_stream")] - UnexpectedIncomingUniStream(RecvStream), - #[error("unexpected incoming datagram")] - UnexpectedIncomingDatagram(Bytes), - #[error("unexpected command from bi_stream: {1:?}")] - UnexpectedCommandFromBiStream(BiStream, Command), - #[error("unexpected command from uni_stream: {1:?}")] - UnexpectedCommandFromUniStream(RecvStream, Command), - #[error("unexpected command from datagram: {1:?}")] - UnexpectedCommandFromDatagram(Bytes, Command), -} - -impl IncomingError { - #[inline] - pub(super) fn from_marshaling_error(err: MarshalingError) -> Self { - match err { - MarshalingError::Io(err) => Self::Io(err), - MarshalingError::Protocol(err) => Self::Protocol(err), - MarshalingError::InvalidEncoding(err) => Self::InvalidEncoding(err), - } - } -} diff --git a/src/common/mod.rs b/src/common/mod.rs deleted file mode 100644 index e074f73..0000000 --- a/src/common/mod.rs +++ /dev/null @@ -1,16 +0,0 @@ -pub(crate) mod incoming; -pub(crate) mod packet; -pub(crate) mod stream; - -#[derive(Clone, Copy, Debug)] -pub enum CongestionControl { - Cubic, - NewReno, - Bbr, -} - -#[derive(Clone, Copy, Debug)] -pub enum UdpRelayMode { - Native, - Quic, -} diff --git a/src/common/packet.rs b/src/common/packet.rs deleted file mode 100644 index 0f99527..0000000 --- a/src/common/packet.rs +++ /dev/null @@ -1,327 +0,0 @@ -use self::state::{NeedAccept, Ready, StateInner}; -use crate::{ - protocol::{Address, Command}, - RecvStream, -}; -use bytes::{Bytes, BytesMut}; -use parking_lot::Mutex; -use std::{ - collections::{hash_map::Entry, HashMap}, - sync::Arc, - time::{Duration, Instant}, -}; -use thiserror::Error; - -pub mod state { - use super::PacketBuffer; - use crate::RecvStream; - use bytes::Bytes; - use std::sync::Arc; - - pub struct NeedAccept; - pub struct Ready; - - pub(super) enum StateInner { - FromDatagram(Bytes, Arc), - FromUniStream(RecvStream), - Ready(Bytes), - } -} - -pub struct Packet { - assoc_id: u32, - pkt_id: u16, - frag_id: u8, - frag_total: u8, - len: u16, - addr: Option
, - inner: StateInner, - _state: S, -} - -impl Packet { - pub(super) fn new_from_datagram( - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - pkt_buf: Arc, - payload: Bytes, - ) -> Self { - Self { - assoc_id, - pkt_id, - frag_id, - frag_total, - len, - addr, - inner: StateInner::FromDatagram(payload, pkt_buf), - _state: NeedAccept, - } - } - - pub(super) fn new_from_uni_stream( - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - stream: RecvStream, - ) -> Self { - Self { - assoc_id, - pkt_id, - frag_id, - frag_total, - len, - addr, - inner: StateInner::FromUniStream(stream), - _state: NeedAccept, - } - } - - pub async fn accept(self) -> Result>, PacketError> { - todo!() - } -} - -impl Packet { - fn new( - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - len: u16, - addr: Address, - pkt: Bytes, - ) -> Self { - Self { - assoc_id, - pkt_id, - frag_id: 0, - frag_total, - len, - addr: Some(addr), - inner: StateInner::Ready(pkt), - _state: Ready, - } - } -} - -pub(crate) struct PacketBuffer(Mutex>); - -impl PacketBuffer { - pub(crate) fn new() -> Self { - Self(Mutex::new(HashMap::new())) - } - - pub(crate) fn insert( - &mut self, - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - pkt: Bytes, - ) -> Result>, PacketError> { - let mut pkt_buf = self.0.lock(); - let key = PacketBufferKey { assoc_id, pkt_id }; - - if frag_id == 0 && addr.is_none() { - pkt_buf.remove(&key); - return Err(PacketError::NoAddress); - } - - if frag_id != 0 && addr.is_some() { - pkt_buf.remove(&key); - return Err(PacketError::UnexpectedAddress); - } - - match pkt_buf.entry(key) { - Entry::Occupied(mut entry) => { - let v = entry.get_mut(); - - if frag_total == 0 - || frag_id >= frag_total - || v.buf.len() != frag_total as usize - || v.buf[frag_id as usize].is_some() - { - return Err(PacketError::BadFragment); - } - - v.total_len += len as usize; - v.buf[frag_id as usize] = Some(pkt); - v.recv_count += 1; - - if v.recv_count == frag_total as usize { - let v = entry.remove(); - let mut res = BytesMut::with_capacity(v.total_len); - - for pkt in v.buf { - res.extend_from_slice(&pkt.unwrap()); - } - - Ok(Some(Packet::::new( - assoc_id, - pkt_id, - frag_total, - len, - v.addr.unwrap(), - res.freeze(), - ))) - } else { - Ok(None) - } - } - Entry::Vacant(entry) => { - if frag_total == 0 || frag_id >= frag_total { - return Err(PacketError::BadFragment); - } - - if frag_total == 1 { - return Ok(Some(Packet::::new( - assoc_id, - pkt_id, - frag_total, - len, - addr.unwrap(), - pkt, - ))); - } - - let mut v = PacketBufferValue { - buf: vec![None; frag_total as usize], - addr, - recv_count: 0, - total_len: 0, - c_time: Instant::now(), - }; - - v.total_len += len as usize; - v.buf[frag_id as usize] = Some(pkt); - v.recv_count += 1; - entry.insert(v); - - Ok(None) - } - } - } - - fn collect_garbage(&self, timeout: Duration) { - self.0.lock().retain(|_, v| v.c_time.elapsed() < timeout); - } - - pub(crate) fn get_handler(self: Arc) -> PacketBufferHandle { - PacketBufferHandle(self.clone()) - } -} - -pub struct PacketBufferHandle(Arc); - -impl PacketBufferHandle { - fn new(pkt_buf: Arc) -> Self { - Self(pkt_buf) - } - pub fn collect_garbage(&self, timeout: Duration) { - self.0.collect_garbage(timeout) - } -} - -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] -struct PacketBufferKey { - assoc_id: u32, - pkt_id: u16, -} - -#[derive(Debug)] -struct PacketBufferValue { - buf: Vec>, - addr: Option
, - recv_count: usize, - total_len: usize, - c_time: Instant, -} - -#[derive(Error, Debug)] -pub enum PacketError { - #[error("missing address in packet with frag_id 0")] - NoAddress, - #[error("unexpected address in packet")] - UnexpectedAddress, - #[error("received bad-fragmented packet")] - BadFragment, -} - -#[inline] -pub(crate) fn split_packet(pkt: Bytes, addr: &Address, max_datagram_size: usize) -> SplitPacket { - SplitPacket::new(pkt, addr, max_datagram_size) -} - -#[derive(Debug)] -pub(crate) struct SplitPacket { - pkt: Bytes, - max_pkt_size: usize, - next_start: usize, - next_end: usize, - len: usize, -} - -impl SplitPacket { - #[inline] - fn new(pkt: Bytes, addr: &Address, max_datagram_size: usize) -> Self { - const DEFAULT_HEADER: Command = Command::Packet { - assoc_id: 0, - pkt_id: 0, - frag_total: 0, - frag_id: 0, - len: 0, - addr: None, - }; - - let first_pkt_size = - max_datagram_size - DEFAULT_HEADER.serialized_len() - addr.serialized_len(); - let max_pkt_size = max_datagram_size - DEFAULT_HEADER.serialized_len(); - let len = if first_pkt_size > pkt.len() { - 1 + (pkt.len() - first_pkt_size) / max_pkt_size + 1 - } else { - 1 - }; - - Self { - pkt, - max_pkt_size, - next_start: 0, - next_end: first_pkt_size, - len, - } - } -} - -impl Iterator for SplitPacket { - type Item = Bytes; - - fn next(&mut self) -> Option { - if self.next_start <= self.pkt.len() { - let next = self - .pkt - .slice(self.next_start..self.next_end.min(self.pkt.len())); - - self.next_start += self.max_pkt_size; - self.next_end += self.max_pkt_size; - - Some(next) - } else { - None - } - } -} - -impl ExactSizeIterator for SplitPacket { - #[inline] - fn len(&self) -> usize { - self.len - } -} diff --git a/src/common/stream.rs b/src/common/stream.rs deleted file mode 100644 index 49d413b..0000000 --- a/src/common/stream.rs +++ /dev/null @@ -1,143 +0,0 @@ -use quinn::{RecvStream as QuinnRecvStream, SendStream as QuinnSendStream}; -use std::{ - io::{IoSlice, Result}, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -pub(crate) type StreamReg = Arc<()>; - -#[derive(Debug)] -pub struct SendStream(QuinnSendStream, StreamReg); - -impl SendStream { - #[inline] - pub(crate) fn new(send: QuinnSendStream, reg: StreamReg) -> Self { - Self(send, reg) - } - - #[inline] - pub async fn finish(&mut self) -> Result<()> { - self.0.finish().await?; - Ok(()) - } -} - -impl AsyncWrite for SendStream { - #[inline] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } - - #[inline] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } -} - -#[derive(Debug)] -pub struct RecvStream(QuinnRecvStream, StreamReg); - -impl RecvStream { - #[inline] - pub(crate) fn new(recv: QuinnRecvStream, reg: StreamReg) -> Self { - Self(recv, reg) - } -} - -impl AsyncRead for RecvStream { - #[inline] - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -#[derive(Debug)] -pub struct BiStream(SendStream, RecvStream); - -impl BiStream { - #[inline] - pub(crate) fn new(send: SendStream, recv: RecvStream) -> Self { - Self(send, recv) - } - - #[inline] - pub async fn finish(&mut self) -> Result<()> { - self.0.finish().await - } -} - -impl AsyncWrite for BiStream { - #[inline] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } - - #[inline] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } -} - -impl AsyncRead for BiStream { - #[inline] - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.1).poll_read(cx, buf) - } -} diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index 4b174ef..0000000 --- a/src/lib.rs +++ /dev/null @@ -1,23 +0,0 @@ -pub mod protocol; - -#[cfg(any(feature = "server", feature = "client"))] -mod common; - -#[cfg(feature = "server")] -pub mod server; - -#[cfg(feature = "client")] -pub mod client; - -#[cfg(any(feature = "server", feature = "client"))] -pub use crate::common::{ - packet::{state as packet_state, Packet, PacketBufferHandle}, - stream::{BiStream, RecvStream, SendStream}, - CongestionControl, UdpRelayMode, -}; - -#[cfg(feature = "client")] -pub use crate::client::{Client, ClientConfig, ClientError}; - -#[cfg(feature = "server")] -pub use crate::server::{Server, ServerConfig, ServerError}; diff --git a/src/protocol/marshaling.rs b/src/protocol/marshaling.rs deleted file mode 100644 index bd39c28..0000000 --- a/src/protocol/marshaling.rs +++ /dev/null @@ -1,247 +0,0 @@ -use super::{Address, Command, ProtocolError, TUIC_PROTOCOL_VERSION}; -use byteorder::{BigEndian, ReadBytesExt}; -use bytes::BufMut; -use std::{ - io::{Cursor, Error as IoError}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, - string::FromUtf8Error, -}; -use thiserror::Error; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -impl Command { - pub async fn read_from(r: &mut R) -> Result - where - R: AsyncRead + Unpin, - { - let ver = r.read_u8().await?; - - if ver != TUIC_PROTOCOL_VERSION { - return Err(MarshalingError::from(ProtocolError::UnsupportedVersion( - ver, - ))); - } - - let cmd = r.read_u8().await?; - - match cmd { - Self::TYPE_RESPOND => { - let resp = r.read_u8().await?; - match resp { - Self::RESPONSE_SUCCEEDED => Ok(Self::Respond(true)), - Self::RESPONSE_FAILED => Ok(Self::Respond(false)), - _ => Err(MarshalingError::from(ProtocolError::InvalidResponse(resp))), - } - } - Self::TYPE_AUTHENTICATE => { - let mut digest = [0; 32]; - r.read_exact(&mut digest).await?; - Ok(Self::Authenticate(digest)) - } - Self::TYPE_CONNECT => { - let addr = Address::read_from(r).await?; - Ok(Self::Connect { addr }) - } - Self::TYPE_PACKET => { - let mut buf = [0; 12]; - r.read_exact(&mut buf).await?; - let mut rdr = Cursor::new(buf); - - let assoc_id = ReadBytesExt::read_u32::(&mut rdr).unwrap(); - let pkt_id = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - let frag_total = ReadBytesExt::read_u8(&mut rdr).unwrap(); - let frag_id = ReadBytesExt::read_u8(&mut rdr).unwrap(); - let len = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - - let addr = if frag_id == 0 { - Some(Address::read_from(r).await?) - } else { - None - }; - - Ok(Self::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - }) - } - Self::TYPE_DISSOCIATE => { - let assoc_id = r.read_u32().await?; - Ok(Self::Dissociate { assoc_id }) - } - Self::TYPE_HEARTBEAT => Ok(Self::Heartbeat), - _ => Err(MarshalingError::from(ProtocolError::InvalidCommand(cmd))), - } - } - - pub async fn write_to(&self, w: &mut W) -> Result<(), IoError> - where - W: AsyncWrite + Unpin, - { - let mut buf = Vec::with_capacity(self.serialized_len()); - self.write_to_buf(&mut buf); - w.write_all(&buf).await?; - Ok(()) - } - - pub fn write_to_buf(&self, buf: &mut B) { - buf.put_u8(TUIC_PROTOCOL_VERSION); - - match self { - Self::Respond(is_succeeded) => { - buf.put_u8(Self::TYPE_RESPOND); - if *is_succeeded { - buf.put_u8(Self::RESPONSE_SUCCEEDED); - } else { - buf.put_u8(Self::RESPONSE_FAILED); - } - } - Self::Authenticate(digest) => { - buf.put_u8(Self::TYPE_AUTHENTICATE); - buf.put_slice(digest); - } - Self::Connect { addr } => { - buf.put_u8(Self::TYPE_CONNECT); - addr.write_to_buf(buf); - } - Self::Packet { - assoc_id, - pkt_id, - frag_total, - frag_id, - len, - addr, - } => { - buf.put_u8(Self::TYPE_PACKET); - buf.put_u32(*assoc_id); - buf.put_u16(*pkt_id); - buf.put_u8(*frag_total); - buf.put_u8(*frag_id); - buf.put_u16(*len); - - if *frag_id == 0 { - addr.as_ref().unwrap().write_to_buf(buf); - } - } - Self::Dissociate { assoc_id } => { - buf.put_u8(Self::TYPE_DISSOCIATE); - buf.put_u32(*assoc_id); - } - Self::Heartbeat => { - buf.put_u8(Self::TYPE_HEARTBEAT); - } - } - } -} - -impl Address { - pub async fn read_from(stream: &mut R) -> Result - where - R: AsyncRead + Unpin, - { - let addr_type = stream.read_u8().await?; - - match addr_type { - Self::TYPE_DOMAIN => { - let len = stream.read_u8().await? as usize; - - let mut buf = vec![0; len + 2]; - stream.read_exact(&mut buf).await?; - - let port = ReadBytesExt::read_u16::(&mut &buf[len..]).unwrap(); - buf.truncate(len); - - let addr = String::from_utf8(buf)?; - - Ok(Self::DomainAddress(addr, port)) - } - Self::TYPE_IPV4 => { - let mut buf = [0; 6]; - stream.read_exact(&mut buf).await?; - let mut rdr = Cursor::new(buf); - - let addr = Ipv4Addr::new( - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ReadBytesExt::read_u8(&mut rdr).unwrap(), - ); - - let port = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - - Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) - } - Self::TYPE_IPV6 => { - let mut buf = [0; 18]; - stream.read_exact(&mut buf).await?; - let mut rdr = Cursor::new(buf); - - let addr = Ipv6Addr::new( - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ReadBytesExt::read_u16::(&mut rdr).unwrap(), - ); - - let port = ReadBytesExt::read_u16::(&mut rdr).unwrap(); - - Ok(Self::SocketAddress(SocketAddr::from((addr, port)))) - } - _ => Err(MarshalingError::from(ProtocolError::InvalidAddressType( - addr_type, - ))), - } - } - - pub async fn write_to(&self, writer: &mut W) -> Result<(), IoError> - where - W: AsyncWrite + Unpin, - { - let mut buf = Vec::with_capacity(self.serialized_len()); - self.write_to_buf(&mut buf); - writer.write_all(&buf).await?; - Ok(()) - } - - pub fn write_to_buf(&self, buf: &mut B) { - match self { - Self::DomainAddress(addr, port) => { - buf.put_u8(Self::TYPE_DOMAIN); - buf.put_u8(addr.len() as u8); - buf.put_slice(addr.as_bytes()); - buf.put_u16(*port); - } - Self::SocketAddress(addr) => match addr { - SocketAddr::V4(addr) => { - buf.put_u8(Self::TYPE_IPV4); - buf.put_slice(&addr.ip().octets()); - buf.put_u16(addr.port()); - } - SocketAddr::V6(addr) => { - buf.put_u8(Self::TYPE_IPV6); - for seg in addr.ip().segments() { - buf.put_u16(seg); - } - buf.put_u16(addr.port()); - } - }, - } - } -} - -#[derive(Error, Debug)] -pub enum MarshalingError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Protocol(#[from] ProtocolError), - #[error("invalid address encoding: {0}")] - InvalidEncoding(#[from] FromUtf8Error), -} diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs deleted file mode 100644 index e179bf1..0000000 --- a/src/protocol/mod.rs +++ /dev/null @@ -1,161 +0,0 @@ -//! The TUIC protocol - -#[cfg(feature = "protocol_marshaling")] -mod marshaling; - -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - net::SocketAddr, -}; -use thiserror::Error; - -pub const TUIC_PROTOCOL_VERSION: u8 = 0x05; - -#[cfg(feature = "protocol_marshaling")] -pub use self::marshaling::MarshalingError; - -/// Command -/// -/// ```plain -/// +-----+------+----------+ -/// | VER | TYPE | OPT | -/// +-----+------+----------+ -/// | 1 | 1 | Variable | -/// +-----+------+----------+ -/// ``` -#[non_exhaustive] -#[derive(Clone, Debug)] -pub enum Command { - // +-----+ - // | REP | - // +-----+ - // | 1 | - // +-----+ - Respond(bool), - - // +-----+ - // | TKN | - // +-----+ - // | 32 | - // +-----+ - Authenticate([u8; 32]), - - // +----------+ - // | ADDR | - // +----------+ - // | Variable | - // +----------+ - Connect { - addr: Address, - }, - - // +----------+--------+------------+---------+-----+----------+ - // | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | LEN | ADDR | - // +----------+--------+------------+---------+-----+----------+ - // | 4 | 2 | 1 | 1 | 2 | Variable | - // +----------+--------+------------+---------+-----+----------+ - Packet { - assoc_id: u32, - pkt_id: u16, - frag_total: u8, - frag_id: u8, - len: u16, - addr: Option
, - }, - - // +----------+ - // | ASSOC_ID | - // +----------+ - // | 4 | - // +----------+ - Dissociate { - assoc_id: u32, - }, - - // +-+ - // | | - // +-+ - // | | - // +-+ - Heartbeat, -} - -impl Command { - pub const TYPE_RESPOND: u8 = 0xff; - pub const TYPE_AUTHENTICATE: u8 = 0x00; - pub const TYPE_CONNECT: u8 = 0x01; - pub const TYPE_PACKET: u8 = 0x02; - pub const TYPE_DISSOCIATE: u8 = 0x03; - pub const TYPE_HEARTBEAT: u8 = 0x04; - - pub const RESPONSE_SUCCEEDED: u8 = 0x00; - pub const RESPONSE_FAILED: u8 = 0xff; - - pub fn serialized_len(&self) -> usize { - 2 + match self { - Self::Respond(_) => 1, - Self::Authenticate { .. } => 32, - Self::Connect { addr } => addr.serialized_len(), - Self::Packet { addr, .. } => 10 + addr.as_ref().map_or(0, |addr| addr.serialized_len()), - Self::Dissociate { .. } => 4, - Self::Heartbeat => 0, - } - } -} - -/// Address -/// -/// ```plain -/// +------+----------+----------+ -/// | TYPE | ADDR | PORT | -/// +------+----------+----------+ -/// | 1 | Variable | 2 | -/// +------+----------+----------+ -/// ``` -/// -/// The address type can be one of the following: -/// 0x00: fully-qualified domain name (the first byte indicates the length of the domain name) -/// 0x01: IPv4 address -/// 0x02: IPv6 address -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub enum Address { - DomainAddress(String, u16), - SocketAddress(SocketAddr), -} - -impl Address { - pub const TYPE_DOMAIN: u8 = 0x00; - pub const TYPE_IPV4: u8 = 0x01; - pub const TYPE_IPV6: u8 = 0x02; - - pub fn serialized_len(&self) -> usize { - 1 + match self { - Address::DomainAddress(addr, _) => 1 + addr.len() + 2, - Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 6, - SocketAddr::V6(_) => 18, - }, - } - } -} - -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), - Self::SocketAddress(addr) => write!(f, "{addr}"), - } - } -} - -#[derive(Error, Debug)] -pub enum ProtocolError { - #[error("unsupported TUIC version: {0:#x}")] - UnsupportedVersion(u8), - #[error("invalid command: {0:#x}")] - InvalidCommand(u8), - #[error("invalid response: {0:#x}")] - InvalidResponse(u8), - #[error("invalid address type: {0:#x}")] - InvalidAddressType(u8), -} diff --git a/src/server/connection.rs b/src/server/connection.rs deleted file mode 100644 index 6a36996..0000000 --- a/src/server/connection.rs +++ /dev/null @@ -1,65 +0,0 @@ -use super::IncomingTasks; -use crate::UdpRelayMode; -use crossbeam_utils::atomic::AtomicCell; -use quinn::{ - Connecting as QuinnConnecting, Connection as QuinnConnection, - ConnectionError as QuinnConnectionError, NewConnection as QuinnNewConnection, -}; -use std::{io::Error as IoError, sync::Arc}; -use thiserror::Error; - -pub struct Connecting { - conn: QuinnConnecting, -} - -impl Connecting { - pub(super) fn new(conn: QuinnConnecting) -> Self { - Self { conn } - } - - pub async fn establish(self) -> Result<(Connection, IncomingTasks), ConnectionError> { - let QuinnNewConnection { - connection, - datagrams, - uni_streams, - .. - } = self - .conn - .await - .map_err(ConnectionError::from_quinn_connection_error)?; - - let udp_relay_mode = Arc::new(AtomicCell::new(None)); - - let conn = Connection::new(connection, udp_relay_mode.clone()); - let incoming = IncomingTasks::new(uni_streams, datagrams, udp_relay_mode); - - Ok((conn, incoming)) - } -} - -pub struct Connection { - conn: QuinnConnection, - udp_relay_mode: Arc>>, -} - -impl Connection { - fn new(conn: QuinnConnection, udp_relay_mode: Arc>>) -> Self { - Self { - conn, - udp_relay_mode, - } - } -} - -#[derive(Error, Debug)] -pub enum ConnectionError { - #[error(transparent)] - Io(#[from] IoError), -} - -impl ConnectionError { - #[inline] - fn from_quinn_connection_error(err: QuinnConnectionError) -> Self { - Self::Io(IoError::from(err)) - } -} diff --git a/src/server/incoming.rs b/src/server/incoming.rs deleted file mode 100644 index cd4aadf..0000000 --- a/src/server/incoming.rs +++ /dev/null @@ -1,24 +0,0 @@ -use crate::UdpRelayMode; -use crossbeam_utils::atomic::AtomicCell; -use quinn::{Datagrams, IncomingUniStreams}; -use std::sync::Arc; - -pub struct IncomingTasks { - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - udp_relay_mode: Arc>>, -} - -impl IncomingTasks { - pub(super) fn new( - uni_streams: IncomingUniStreams, - datagrams: Datagrams, - udp_relay_mode: Arc>>, - ) -> Self { - Self { - uni_streams, - datagrams, - udp_relay_mode, - } - } -} diff --git a/src/server/mod.rs b/src/server/mod.rs deleted file mode 100644 index f96f4f5..0000000 --- a/src/server/mod.rs +++ /dev/null @@ -1,137 +0,0 @@ -mod connection; -mod incoming; - -pub use self::{ - connection::{Connecting, Connection, ConnectionError}, - incoming::IncomingTasks, -}; - -use crate::CongestionControl; -use futures::StreamExt; -use quinn::{ - congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - Endpoint, EndpointConfig, IdleTimeout, Incoming, ServerConfig as QuinnServerConfig, VarInt, -}; -use rustls::{ - version, Certificate, Error as RustlsError, PrivateKey, ServerConfig as RustlsServerConfig, -}; -use std::{io::Error as IoError, net::UdpSocket, sync::Arc, time::Duration}; -use thiserror::Error; - -#[derive(Debug)] -pub struct Server { - endpoint: Endpoint, - incoming: Incoming, -} - -impl Server { - pub fn bind(cfg: ServerConfig, socket: UdpSocket) -> Result { - let mut crypto = RustlsServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&version::TLS13]) - .unwrap() - .with_no_client_auth() - .with_single_cert(cfg.certificate_chain, cfg.private_key)?; - - if cfg.allow_quic_0rtt { - crypto.max_early_data_size = u32::MAX; - } - - crypto.alpn_protocols = cfg.alpn_protocols; - - let mut quinn_config = QuinnServerConfig::with_crypto(Arc::new(crypto)); - let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); - - let max_idle_timeout = cfg.max_idle_timeout.map(|timeout| { - IdleTimeout::try_from(timeout).unwrap_or_else(|_| IdleTimeout::from(VarInt::MAX)) - }); - - transport.max_idle_timeout(max_idle_timeout); - - match cfg.congestion_controller { - CongestionControl::Cubic => { - transport.congestion_controller_factory(Arc::new(CubicConfig::default())); - } - CongestionControl::NewReno => { - transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); - } - CongestionControl::Bbr => { - transport.congestion_controller_factory(Arc::new(BbrConfig::default())); - } - } - - let (endpoint, incoming) = - Endpoint::new(EndpointConfig::default(), Some(quinn_config), socket)?; - - Ok(Self { endpoint, incoming }) - } - - pub fn reconfigure(&mut self, cfg: ServerConfig) -> Result<(), ServerError> { - let mut crypto = RustlsServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&version::TLS13]) - .unwrap() - .with_no_client_auth() - .with_single_cert(cfg.certificate_chain, cfg.private_key)?; - - if cfg.allow_quic_0rtt { - crypto.max_early_data_size = u32::MAX; - } - - crypto.alpn_protocols = cfg.alpn_protocols; - - let mut quinn_config = QuinnServerConfig::with_crypto(Arc::new(crypto)); - let transport = Arc::get_mut(&mut quinn_config.transport).unwrap(); - - let max_idle_timeout = cfg.max_idle_timeout.map(|timeout| { - IdleTimeout::try_from(timeout).unwrap_or_else(|_| IdleTimeout::from(VarInt::MAX)) - }); - - transport.max_idle_timeout(max_idle_timeout); - - match cfg.congestion_controller { - CongestionControl::Cubic => { - transport.congestion_controller_factory(Arc::new(CubicConfig::default())); - } - CongestionControl::NewReno => { - transport.congestion_controller_factory(Arc::new(NewRenoConfig::default())); - } - CongestionControl::Bbr => { - transport.congestion_controller_factory(Arc::new(BbrConfig::default())); - } - } - - self.endpoint.set_server_config(Some(quinn_config)); - - Ok(()) - } - - pub fn rebind(&mut self, socket: UdpSocket) -> Result<(), ServerError> { - self.endpoint.rebind(socket)?; - Ok(()) - } - - pub async fn accept(&mut self) -> Option { - self.incoming.next().await.map(Connecting::new) - } -} - -#[derive(Clone, Debug)] -pub struct ServerConfig { - pub certificate_chain: Vec, - pub private_key: PrivateKey, - pub alpn_protocols: Vec>, - pub allow_quic_0rtt: bool, - pub max_idle_timeout: Option, - pub congestion_controller: CongestionControl, -} - -#[derive(Error, Debug)] -pub enum ServerError { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Certificate(#[from] RustlsError), -} diff --git a/client/Cargo.toml b/tuic-client/Cargo.toml similarity index 100% rename from client/Cargo.toml rename to tuic-client/Cargo.toml diff --git a/client/src/main.rs b/tuic-client/src/main.rs similarity index 100% rename from client/src/main.rs rename to tuic-client/src/main.rs diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml new file mode 100644 index 0000000..d98ceed --- /dev/null +++ b/tuic-quinn/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tuic-quinn" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs new file mode 100644 index 0000000..7d12d9a --- /dev/null +++ b/tuic-quinn/src/lib.rs @@ -0,0 +1,14 @@ +pub fn add(left: usize, right: usize) -> usize { + left + right +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn it_works() { + let result = add(2, 2); + assert_eq!(result, 4); + } +} diff --git a/server/Cargo.toml b/tuic-server/Cargo.toml similarity index 100% rename from server/Cargo.toml rename to tuic-server/Cargo.toml diff --git a/server/src/main.rs b/tuic-server/src/main.rs similarity index 100% rename from server/src/main.rs rename to tuic-server/src/main.rs diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml new file mode 100644 index 0000000..353d51d --- /dev/null +++ b/tuic/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tuic" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs new file mode 100644 index 0000000..7d12d9a --- /dev/null +++ b/tuic/src/lib.rs @@ -0,0 +1,14 @@ +pub fn add(left: usize, right: usize) -> usize { + left + right +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn it_works() { + let result = add(2, 2); + assert_eq!(result, 4); + } +} From 8acd266f25d147205ead54ca483e5f6de90ce637 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 23 Jan 2023 18:54:53 +0900 Subject: [PATCH 038/103] protocol abstraction --- tuic/src/lib.rs | 15 +----- tuic/src/protocol/address.rs | 49 +++++++++++++++++++ tuic/src/protocol/authenticate.rs | 80 +++++++++++++++++++++++++++++++ tuic/src/protocol/connect.rs | 27 +++++++++++ tuic/src/protocol/dissociate.rs | 25 ++++++++++ tuic/src/protocol/heartbeat.rs | 23 +++++++++ tuic/src/protocol/mod.rs | 63 ++++++++++++++++++++++++ tuic/src/protocol/packet.rs | 46 ++++++++++++++++++ 8 files changed, 315 insertions(+), 13 deletions(-) create mode 100644 tuic/src/protocol/address.rs create mode 100644 tuic/src/protocol/authenticate.rs create mode 100644 tuic/src/protocol/connect.rs create mode 100644 tuic/src/protocol/dissociate.rs create mode 100644 tuic/src/protocol/heartbeat.rs create mode 100644 tuic/src/protocol/mod.rs create mode 100644 tuic/src/protocol/packet.rs diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 7d12d9a..47bbb5c 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -1,14 +1,3 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} +//! The TUIC protocol -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} +pub mod protocol; diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs new file mode 100644 index 0000000..0edd4ff --- /dev/null +++ b/tuic/src/protocol/address.rs @@ -0,0 +1,49 @@ +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + net::SocketAddr, +}; + +/// Address +/// +/// ```plain +/// +------+----------+----------+ +/// | TYPE | ADDR | PORT | +/// +------+----------+----------+ +/// | 1 | Variable | 2 | +/// +------+----------+----------+ +/// ``` +/// +/// The address type can be one of the following: +/// 0x00: fully-qualified domain name (the first byte indicates the length of the domain name) +/// 0x01: IPv4 address +/// 0x02: IPv6 address +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Address { + DomainAddress(String, u16), + SocketAddress(SocketAddr), +} + +impl Address { + pub const TYPE_DOMAIN: u8 = 0x00; + pub const TYPE_IPV4: u8 = 0x01; + pub const TYPE_IPV6: u8 = 0x02; + + pub fn len(&self) -> usize { + 1 + match self { + Address::DomainAddress(addr, _) => 1 + addr.len() + 2, + Address::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => 6, + SocketAddr::V6(_) => 18, + }, + } + } +} + +impl Display for Address { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), + Self::SocketAddress(addr) => write!(f, "{addr}"), + } + } +} diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs new file mode 100644 index 0000000..2c07ae4 --- /dev/null +++ b/tuic/src/protocol/authenticate.rs @@ -0,0 +1,80 @@ +// +--------+-----------+ +// | METHOD | OPT | +// +--------+-----------+ +// | 1 | Variable | +// +--------+-----------+ +#[derive(Clone, Debug)] +pub struct Authenticate +where + A: Method, +{ + pub method: A, +} + +impl Authenticate +where + A: Method, +{ + const CMD_TYPE: u8 = 0x00; + + pub fn new(method: A) -> Self { + Self { method } + } + + pub const fn cmd_type() -> u8 { + Self::CMD_TYPE + } + + pub fn len(&self) -> usize { + 1 + self.method.len() + } +} + +pub trait Method { + fn auth_type(&self) -> u8; + fn len(&self) -> usize; +} + +#[derive(Debug, Clone)] +pub struct None; + +impl None { + const AUTH_TYPE: u8 = 0x00; + + pub fn new() -> Self { + Self + } +} + +impl Method for None { + fn auth_type(&self) -> u8 { + Self::AUTH_TYPE + } + + fn len(&self) -> usize { + 0 + } +} + +#[derive(Debug, Clone)] +pub struct Blake3 { + pub token: [u8; 32], +} + +impl Blake3 { + const AUTH_TYPE: u8 = 0x01; + + pub fn new(token: [u8; 32]) -> Self { + Self { token } + } +} + +impl Method for Blake3 { + fn auth_type(&self) -> u8 { + Self::AUTH_TYPE + } + + fn len(&self) -> usize { + 32 + } +} \ No newline at end of file diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs new file mode 100644 index 0000000..f781b11 --- /dev/null +++ b/tuic/src/protocol/connect.rs @@ -0,0 +1,27 @@ +use super::Address; + +// +----------+ +// | ADDR | +// +----------+ +// | Variable | +// +----------+ +#[derive(Clone, Debug)] +pub struct Connect { + pub addr: Address, +} + +impl Connect { + const CMD_TYPE: u8 = 0x01; + + pub fn new(addr: Address) -> Self { + Self { addr } + } + + pub const fn cmd_type() -> u8 { + Self::CMD_TYPE + } + + pub fn len(&self) -> usize { + self.addr.len() + } +} diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs new file mode 100644 index 0000000..c200b78 --- /dev/null +++ b/tuic/src/protocol/dissociate.rs @@ -0,0 +1,25 @@ +// +----------+ +// | ASSOC_ID | +// +----------+ +// | 2 | +// +----------+ +#[derive(Clone, Debug)] +pub struct Dissociate { + pub assoc_id: u16, +} + +impl Dissociate { + const CMD_TYPE: u8 = 0x03; + + pub fn new(assoc_id: u16) -> Self { + Self { assoc_id } + } + + pub const fn cmd_type() -> u8 { + Self::CMD_TYPE + } + + pub fn len(&self) -> usize { + 2 + } +} diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs new file mode 100644 index 0000000..dd396e9 --- /dev/null +++ b/tuic/src/protocol/heartbeat.rs @@ -0,0 +1,23 @@ +// +-+ +// | | +// +-+ +// | | +// +-+ +#[derive(Clone, Debug)] +pub struct Heartbeat; + +impl Heartbeat { + const CMD_TYPE: u8 = 0x04; + + pub fn new() -> Self { + Self + } + + pub const fn cmd_type() -> u8 { + Self::CMD_TYPE + } + + pub fn len(&self) -> usize { + 0 + } +} diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs new file mode 100644 index 0000000..409eb6e --- /dev/null +++ b/tuic/src/protocol/mod.rs @@ -0,0 +1,63 @@ +use self::authenticate::Method as AuthenticationMethod; + +mod address; +mod connect; +mod dissociate; +mod heartbeat; +mod packet; + +pub mod authenticate; + +pub use self::{ + address::Address, authenticate::Authenticate, connect::Connect, dissociate::Dissociate, + heartbeat::Heartbeat, packet::Packet, +}; + +pub const VERSION: u8 = 0x05; + +/// Command +/// +/// ```plain +/// +-----+----------+----------+ +/// | VER | CMD_TYPE | OPT | +/// +-----+----------+----------+ +/// | 1 | 1 | Variable | +/// +-----+----------+----------+ +/// ``` +#[non_exhaustive] +#[derive(Clone, Debug)] +pub enum Command +where + A: AuthenticationMethod, +{ + Authenticate(Authenticate), + Connect(Connect), + Packet(Packet), + Dissociate(Dissociate), + Heartbeat(Heartbeat), +} + +impl Command +where + A: AuthenticationMethod, +{ + pub fn cmd_type(&self) -> u8 { + match self { + Self::Authenticate(_) => Authenticate::::cmd_type(), + Self::Connect(_) => Connect::cmd_type(), + Self::Packet(_) => Packet::cmd_type(), + Self::Dissociate(_) => Dissociate::cmd_type(), + Self::Heartbeat(_) => Heartbeat::cmd_type(), + } + } + + pub fn len(&self) -> usize { + 2 + match self { + Self::Authenticate(auth) => auth.len(), + Self::Connect(connect) => connect.len(), + Self::Packet(packet) => packet.len(), + Self::Dissociate(dissociate) => dissociate.len(), + Self::Heartbeat(heartbeat) => heartbeat.len(), + } + } +} diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs new file mode 100644 index 0000000..08525f8 --- /dev/null +++ b/tuic/src/protocol/packet.rs @@ -0,0 +1,46 @@ +use super::Address; + +// +----------+--------+------------+---------+-----+----------+ +// | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | LEN | ADDR | +// +----------+--------+------------+---------+-----+----------+ +// | 2 | 2 | 1 | 1 | 2 | Variable | +// +----------+--------+------------+---------+-----+----------+ +#[derive(Clone, Debug)] +pub struct Packet { + pub assoc_id: u16, + pub pkt_id: u16, + pub frag_total: u8, + pub frag_id: u8, + pub len: u16, + pub addr: Option
, +} + +impl Packet { + const CMD_TYPE: u8 = 0x02; + + pub fn new( + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + len: u16, + addr: Option
, + ) -> Self { + Self { + assoc_id, + pkt_id, + frag_total, + frag_id, + len, + addr, + } + } + + pub const fn cmd_type() -> u8 { + Self::CMD_TYPE + } + + pub fn len(&self) -> usize { + 2 + 2 + 1 + 1 + 2 + self.addr.as_ref().map_or(0, |addr| addr.len()) + } +} From de07d6a1e1f7e01e098021133decfb1eae9deffe Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 01:24:49 +0900 Subject: [PATCH 039/103] adding trait `Command` for header abstraction --- tuic/src/protocol/address.rs | 16 +++++++++++++--- tuic/src/protocol/authenticate.rs | 31 +++++++++++++++++++------------ tuic/src/protocol/connect.rs | 12 +++++++----- tuic/src/protocol/dissociate.rs | 12 ++++++++---- tuic/src/protocol/heartbeat.rs | 12 ++++++++---- tuic/src/protocol/mod.rs | 31 +++++++++++++++++++++---------- tuic/src/protocol/packet.rs | 12 +++++++----- 7 files changed, 83 insertions(+), 43 deletions(-) diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index 0edd4ff..3244719 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -24,9 +24,19 @@ pub enum Address { } impl Address { - pub const TYPE_DOMAIN: u8 = 0x00; - pub const TYPE_IPV4: u8 = 0x01; - pub const TYPE_IPV6: u8 = 0x02; + pub const TYPE_CODE_DOMAIN: u8 = 0x00; + pub const TYPE_CODE_IPV4: u8 = 0x01; + pub const TYPE_CODE_IPV6: u8 = 0x02; + + pub fn type_code(&self) -> u8 { + match self { + Self::DomainAddress(_, _) => Self::TYPE_CODE_DOMAIN, + Self::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => Self::TYPE_CODE_IPV4, + SocketAddr::V6(_) => Self::TYPE_CODE_IPV6, + }, + } + } pub fn len(&self) -> usize { 1 + match self { diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index 2c07ae4..c7919c2 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -1,3 +1,5 @@ +use super::Command; + // +--------+-----------+ // | METHOD | OPT | // +--------+-----------+ @@ -15,23 +17,28 @@ impl Authenticate where A: Method, { - const CMD_TYPE: u8 = 0x00; + pub(super) const TYPE_CODE: u8 = 0x00; pub fn new(method: A) -> Self { Self { method } } +} - pub const fn cmd_type() -> u8 { - Self::CMD_TYPE +impl Command for Authenticate +where + A: Method, +{ + fn type_code() -> u8 { + Self::TYPE_CODE } - pub fn len(&self) -> usize { + fn len(&self) -> usize { 1 + self.method.len() } } pub trait Method { - fn auth_type(&self) -> u8; + fn type_code(&self) -> u8; fn len(&self) -> usize; } @@ -39,7 +46,7 @@ pub trait Method { pub struct None; impl None { - const AUTH_TYPE: u8 = 0x00; + const TYPE_CODE: u8 = 0x00; pub fn new() -> Self { Self @@ -47,8 +54,8 @@ impl None { } impl Method for None { - fn auth_type(&self) -> u8 { - Self::AUTH_TYPE + fn type_code(&self) -> u8 { + Self::TYPE_CODE } fn len(&self) -> usize { @@ -62,7 +69,7 @@ pub struct Blake3 { } impl Blake3 { - const AUTH_TYPE: u8 = 0x01; + const TYPE_CODE: u8 = 0x01; pub fn new(token: [u8; 32]) -> Self { Self { token } @@ -70,11 +77,11 @@ impl Blake3 { } impl Method for Blake3 { - fn auth_type(&self) -> u8 { - Self::AUTH_TYPE + fn type_code(&self) -> u8 { + Self::TYPE_CODE } fn len(&self) -> usize { 32 } -} \ No newline at end of file +} diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index f781b11..286eaf7 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -1,4 +1,4 @@ -use super::Address; +use super::{Address, Command}; // +----------+ // | ADDR | @@ -11,17 +11,19 @@ pub struct Connect { } impl Connect { - const CMD_TYPE: u8 = 0x01; + pub(super) const TYPE_CODE: u8 = 0x01; pub fn new(addr: Address) -> Self { Self { addr } } +} - pub const fn cmd_type() -> u8 { - Self::CMD_TYPE +impl Command for Connect { + fn type_code() -> u8 { + Self::TYPE_CODE } - pub fn len(&self) -> usize { + fn len(&self) -> usize { self.addr.len() } } diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index c200b78..f443f06 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -1,3 +1,5 @@ +use super::Command; + // +----------+ // | ASSOC_ID | // +----------+ @@ -9,17 +11,19 @@ pub struct Dissociate { } impl Dissociate { - const CMD_TYPE: u8 = 0x03; + pub const TYPE_CODE: u8 = 0x03; pub fn new(assoc_id: u16) -> Self { Self { assoc_id } } +} - pub const fn cmd_type() -> u8 { - Self::CMD_TYPE +impl Command for Dissociate { + fn type_code() -> u8 { + Self::TYPE_CODE } - pub fn len(&self) -> usize { + fn len(&self) -> usize { 2 } } diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index dd396e9..efb7e52 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -1,3 +1,5 @@ +use super::Command; + // +-+ // | | // +-+ @@ -7,17 +9,19 @@ pub struct Heartbeat; impl Heartbeat { - const CMD_TYPE: u8 = 0x04; + pub const TYPE_CODE: u8 = 0x04; pub fn new() -> Self { Self } +} - pub const fn cmd_type() -> u8 { - Self::CMD_TYPE +impl Command for Heartbeat { + fn type_code() -> u8 { + Self::TYPE_CODE } - pub fn len(&self) -> usize { + fn len(&self) -> usize { 0 } } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index 409eb6e..37f8c89 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -15,18 +15,18 @@ pub use self::{ pub const VERSION: u8 = 0x05; -/// Command +/// Header /// /// ```plain /// +-----+----------+----------+ -/// | VER | CMD_TYPE | OPT | +/// | VER | TYPE | OPT | /// +-----+----------+----------+ /// | 1 | 1 | Variable | /// +-----+----------+----------+ /// ``` #[non_exhaustive] #[derive(Clone, Debug)] -pub enum Command +pub enum Header where A: AuthenticationMethod, { @@ -37,17 +37,23 @@ where Heartbeat(Heartbeat), } -impl Command +impl Header where A: AuthenticationMethod, { - pub fn cmd_type(&self) -> u8 { + pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::::TYPE_CODE; + pub const TYPE_CODE_CONNECT: u8 = Connect::TYPE_CODE; + pub const TYPE_CODE_PACKET: u8 = Packet::TYPE_CODE; + pub const TYPE_CODE_DISSOCIATE: u8 = Dissociate::TYPE_CODE; + pub const TYPE_CODE_HEARTBEAT: u8 = Heartbeat::TYPE_CODE; + + pub fn type_code(&self) -> u8 { match self { - Self::Authenticate(_) => Authenticate::::cmd_type(), - Self::Connect(_) => Connect::cmd_type(), - Self::Packet(_) => Packet::cmd_type(), - Self::Dissociate(_) => Dissociate::cmd_type(), - Self::Heartbeat(_) => Heartbeat::cmd_type(), + Self::Authenticate(_) => Authenticate::::type_code(), + Self::Connect(_) => Connect::type_code(), + Self::Packet(_) => Packet::type_code(), + Self::Dissociate(_) => Dissociate::type_code(), + Self::Heartbeat(_) => Heartbeat::type_code(), } } @@ -61,3 +67,8 @@ where } } } + +pub trait Command { + fn type_code() -> u8; + fn len(&self) -> usize; +} diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 08525f8..5e8d7f7 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -1,4 +1,4 @@ -use super::Address; +use super::{Address, Command}; // +----------+--------+------------+---------+-----+----------+ // | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | LEN | ADDR | @@ -16,7 +16,7 @@ pub struct Packet { } impl Packet { - const CMD_TYPE: u8 = 0x02; + pub(super) const TYPE_CODE: u8 = 0x02; pub fn new( assoc_id: u16, @@ -35,12 +35,14 @@ impl Packet { addr, } } +} - pub const fn cmd_type() -> u8 { - Self::CMD_TYPE +impl Command for Packet { + fn type_code() -> u8 { + Self::TYPE_CODE } - pub fn len(&self) -> usize { + fn len(&self) -> usize { 2 + 2 + 1 + 1 + 2 + self.addr.as_ref().map_or(0, |addr| addr.len()) } } From 00c10926c5cc99dedd845ebbbde3b15b807c6330 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 02:55:43 +0900 Subject: [PATCH 040/103] using TLS Keying Material Exporters for token --- tuic/src/protocol/authenticate.rs | 82 +++++-------------------------- tuic/src/protocol/mod.rs | 18 ++----- 2 files changed, 17 insertions(+), 83 deletions(-) diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index c7919c2..9f1a781 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -1,87 +1,29 @@ use super::Command; -// +--------+-----------+ -// | METHOD | OPT | -// +--------+-----------+ -// | 1 | Variable | -// +--------+-----------+ +// +-------+ +// | TOKEN | +// +-------+ +// | 8 | +// +-------+ #[derive(Clone, Debug)] -pub struct Authenticate -where - A: Method, -{ - pub method: A, +pub struct Authenticate { + pub token: [u8; 8], } -impl Authenticate -where - A: Method, -{ +impl Authenticate { pub(super) const TYPE_CODE: u8 = 0x00; - pub fn new(method: A) -> Self { - Self { method } + pub fn new(token: [u8; 8]) -> Self { + Self { token } } } -impl Command for Authenticate -where - A: Method, -{ +impl Command for Authenticate { fn type_code() -> u8 { Self::TYPE_CODE } fn len(&self) -> usize { - 1 + self.method.len() - } -} - -pub trait Method { - fn type_code(&self) -> u8; - fn len(&self) -> usize; -} - -#[derive(Debug, Clone)] -pub struct None; - -impl None { - const TYPE_CODE: u8 = 0x00; - - pub fn new() -> Self { - Self - } -} - -impl Method for None { - fn type_code(&self) -> u8 { - Self::TYPE_CODE - } - - fn len(&self) -> usize { - 0 - } -} - -#[derive(Debug, Clone)] -pub struct Blake3 { - pub token: [u8; 32], -} - -impl Blake3 { - const TYPE_CODE: u8 = 0x01; - - pub fn new(token: [u8; 32]) -> Self { - Self { token } - } -} - -impl Method for Blake3 { - fn type_code(&self) -> u8 { - Self::TYPE_CODE - } - - fn len(&self) -> usize { - 32 + 8 } } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index 37f8c89..406ebc3 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -1,5 +1,3 @@ -use self::authenticate::Method as AuthenticationMethod; - mod address; mod connect; mod dissociate; @@ -26,22 +24,16 @@ pub const VERSION: u8 = 0x05; /// ``` #[non_exhaustive] #[derive(Clone, Debug)] -pub enum Header -where - A: AuthenticationMethod, -{ - Authenticate(Authenticate), +pub enum Header { + Authenticate(Authenticate), Connect(Connect), Packet(Packet), Dissociate(Dissociate), Heartbeat(Heartbeat), } -impl Header -where - A: AuthenticationMethod, -{ - pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::::TYPE_CODE; +impl Header { + pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::TYPE_CODE; pub const TYPE_CODE_CONNECT: u8 = Connect::TYPE_CODE; pub const TYPE_CODE_PACKET: u8 = Packet::TYPE_CODE; pub const TYPE_CODE_DISSOCIATE: u8 = Dissociate::TYPE_CODE; @@ -49,7 +41,7 @@ where pub fn type_code(&self) -> u8 { match self { - Self::Authenticate(_) => Authenticate::::type_code(), + Self::Authenticate(_) => Authenticate::type_code(), Self::Connect(_) => Connect::type_code(), Self::Packet(_) => Packet::type_code(), Self::Dissociate(_) => Dissociate::type_code(), From 595ba01e2ddfa3cec55aaa76a4c710af5b9c4bd1 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 04:13:33 +0900 Subject: [PATCH 041/103] adding prototype abstraction --- tuic/Cargo.toml | 7 +- tuic/src/lib.rs | 3 + tuic/src/prototype/authenticate.rs | 16 ++++ tuic/src/prototype/connect.rs | 16 ++++ tuic/src/prototype/dissociate.rs | 16 ++++ tuic/src/prototype/heartbeat.rs | 13 +++ tuic/src/prototype/mod.rs | 131 +++++++++++++++++++++++++++++ tuic/src/prototype/packet.rs | 19 +++++ 8 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 tuic/src/prototype/authenticate.rs create mode 100644 tuic/src/prototype/connect.rs create mode 100644 tuic/src/prototype/dissociate.rs create mode 100644 tuic/src/prototype/heartbeat.rs create mode 100644 tuic/src/prototype/mod.rs create mode 100644 tuic/src/prototype/packet.rs diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 353d51d..cdefb9c 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -3,6 +3,11 @@ name = "tuic" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +prototype = ["parking_lot"] [dependencies] +parking_lot = { version = "0.12.1", default-features = false, optional = true } + +[dev-dependencies] +tuic = { path = ".", features = ["prototype"] } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 47bbb5c..de2d745 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -1,3 +1,6 @@ //! The TUIC protocol pub mod protocol; + +#[cfg(feature = "prototype")] +pub mod prototype; diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs new file mode 100644 index 0000000..06b4a94 --- /dev/null +++ b/tuic/src/prototype/authenticate.rs @@ -0,0 +1,16 @@ +use super::TaskRegister; +use crate::protocol::{Authenticate as AuthenticateHeader, Header}; + +pub struct Authenticate { + header: Header, + _task_reg: TaskRegister, +} + +impl Authenticate { + pub(super) fn new(task_reg: TaskRegister, token: [u8; 8]) -> Self { + Self { + header: Header::Authenticate(AuthenticateHeader::new(token)), + _task_reg: task_reg, + } + } +} diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs new file mode 100644 index 0000000..7db35aa --- /dev/null +++ b/tuic/src/prototype/connect.rs @@ -0,0 +1,16 @@ +use super::TaskRegister; +use crate::protocol::{Address, Connect as ConnectHeader, Header}; + +pub struct Connect { + header: Header, + _task_reg: TaskRegister, +} + +impl Connect { + pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { + Self { + header: Header::Connect(ConnectHeader::new(addr)), + _task_reg: task_reg, + } + } +} diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs new file mode 100644 index 0000000..e34c576 --- /dev/null +++ b/tuic/src/prototype/dissociate.rs @@ -0,0 +1,16 @@ +use super::TaskRegister; +use crate::protocol::{Dissociate as DissociateHeader, Header}; + +pub struct Dissociate { + header: Header, + _task_reg: TaskRegister, +} + +impl Dissociate { + pub(super) fn new(task_reg: TaskRegister, assoc_id: u16) -> Self { + Self { + header: Header::Dissociate(DissociateHeader::new(assoc_id)), + _task_reg: task_reg, + } + } +} diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs new file mode 100644 index 0000000..c20b0e2 --- /dev/null +++ b/tuic/src/prototype/heartbeat.rs @@ -0,0 +1,13 @@ +use crate::protocol::{Header, Heartbeat as HeartbeatHeader}; + +pub struct Heartbeat { + header: Header, +} + +impl Heartbeat { + pub(super) fn new() -> Self { + Self { + header: Header::Heartbeat(HeartbeatHeader::new()), + } + } +} diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs new file mode 100644 index 0000000..49cda85 --- /dev/null +++ b/tuic/src/prototype/mod.rs @@ -0,0 +1,131 @@ +use crate::protocol::Address; +use parking_lot::Mutex; +use std::{ + collections::{hash_map::Entry, HashMap}, + sync::{Arc, Weak}, +}; + +mod authenticate; +mod connect; +mod dissociate; +mod heartbeat; +mod packet; + +pub use self::{ + authenticate::Authenticate, connect::Connect, dissociate::Dissociate, heartbeat::Heartbeat, + packet::Packet, +}; + +pub struct Connection { + udp_sessions: Mutex, + local_active_task_count: ActiveTaskCount, +} + +impl Connection { + pub fn new() -> Self { + let local_active_task_count = ActiveTaskCount::new(); + + Self { + udp_sessions: Mutex::new(UdpSessions::new(local_active_task_count.clone())), + local_active_task_count, + } + } + + pub fn authenticate(&self, token: [u8; 8]) -> Authenticate { + Authenticate::new(self.local_active_task_count.reg(), token) + } + + pub fn connect(&self, addr: Address) -> Connect { + Connect::new(self.local_active_task_count.reg(), addr) + } + + pub fn packet<'a>( + &self, + assoc_id: u16, + addr: Address, + payload: &'a [u8], + frag_len: usize, + ) -> Packet<'a> { + self.udp_sessions + .lock() + .send(assoc_id, addr, payload, frag_len) + } + + pub fn dissociate(&self, assoc_id: u16) -> Dissociate { + self.udp_sessions.lock().dissociate(assoc_id) + } + + pub fn heartbeat(&self) -> Heartbeat { + Heartbeat::new() + } + + pub fn local_active_task_count(&self) -> usize { + self.local_active_task_count.get() + } +} + +#[derive(Clone)] +struct ActiveTaskCount(Arc<()>); +struct TaskRegister(Weak<()>); + +impl ActiveTaskCount { + fn new() -> Self { + Self(Arc::new(())) + } + + fn reg(&self) -> TaskRegister { + TaskRegister(Arc::downgrade(&self.0)) + } + + fn get(&self) -> usize { + Arc::weak_count(&self.0) + } +} + +struct UdpSessions { + sessions: HashMap, + local_active_task_count: ActiveTaskCount, +} + +impl UdpSessions { + fn new(local_active_task_count: ActiveTaskCount) -> Self { + Self { + sessions: HashMap::new(), + local_active_task_count, + } + } + + fn send<'a>( + &mut self, + assoc_id: u16, + addr: Address, + payload: &'a [u8], + frag_len: usize, + ) -> Packet<'a> { + match self.sessions.entry(assoc_id) { + Entry::Occupied(_) => {} + Entry::Vacant(entry) => { + entry.insert(UdpSession::new(self.local_active_task_count.reg())); + } + } + + Packet::new(assoc_id, addr, payload, frag_len) + } + + fn dissociate(&mut self, assoc_id: u16) -> Dissociate { + self.sessions.remove(&assoc_id); + Dissociate::new(self.local_active_task_count.reg(), assoc_id) + } +} + +struct UdpSession { + _task_reg: TaskRegister, +} + +impl UdpSession { + fn new(task_reg: TaskRegister) -> Self { + Self { + _task_reg: task_reg, + } + } +} diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs new file mode 100644 index 0000000..3e262eb --- /dev/null +++ b/tuic/src/prototype/packet.rs @@ -0,0 +1,19 @@ +use crate::protocol::Address; + +pub struct Packet<'a> { + assoc_id: u16, + addr: Address, + payload: &'a [u8], + frag_len: usize, +} + +impl<'a> Packet<'a> { + pub(super) fn new(assoc_id: u16, addr: Address, payload: &'a [u8], frag_len: usize) -> Self { + Self { + assoc_id, + addr, + payload, + frag_len, + } + } +} From 3caa232f13c34e3dec744ee18fcc3c58354d5794 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 05:00:10 +0900 Subject: [PATCH 042/103] adding `Address::None` --- tuic/src/protocol/address.rs | 51 +++++++++++++++++++++++++++--------- tuic/src/protocol/packet.rs | 6 ++--- 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index 3244719..da3a7f4 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -6,30 +6,37 @@ use std::{ /// Address /// /// ```plain -/// +------+----------+----------+ -/// | TYPE | ADDR | PORT | -/// +------+----------+----------+ -/// | 1 | Variable | 2 | -/// +------+----------+----------+ +/// +------+----------+ +/// | TYPE | ADDR | +/// +------+----------+ +/// | 1 | Variable | +/// +------+----------+ /// ``` /// /// The address type can be one of the following: -/// 0x00: fully-qualified domain name (the first byte indicates the length of the domain name) -/// 0x01: IPv4 address -/// 0x02: IPv6 address +/// +/// 0x00: None +/// 0x01: Fully-qualified domain name (the first byte indicates the length of the domain name) +/// 0x02: IPv4 address +/// 0x03: IPv6 address +/// +/// The port number is encoded in 2 bytes after the Domain name / IP address. #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum Address { + None, DomainAddress(String, u16), SocketAddress(SocketAddr), } impl Address { - pub const TYPE_CODE_DOMAIN: u8 = 0x00; - pub const TYPE_CODE_IPV4: u8 = 0x01; - pub const TYPE_CODE_IPV6: u8 = 0x02; + pub const TYPE_CODE_NONE: u8 = 0x00; + pub const TYPE_CODE_DOMAIN: u8 = 0x01; + pub const TYPE_CODE_IPV4: u8 = 0x02; + pub const TYPE_CODE_IPV6: u8 = 0x03; pub fn type_code(&self) -> u8 { match self { + Self::None => Self::TYPE_CODE_NONE, Self::DomainAddress(_, _) => Self::TYPE_CODE_DOMAIN, Self::SocketAddress(addr) => match addr { SocketAddr::V4(_) => Self::TYPE_CODE_IPV4, @@ -40,18 +47,36 @@ impl Address { pub fn len(&self) -> usize { 1 + match self { + Address::None => 0, Address::DomainAddress(addr, _) => 1 + addr.len() + 2, Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 6, - SocketAddr::V6(_) => 18, + SocketAddr::V4(_) => 1 * 4 + 2, + SocketAddr::V6(_) => 2 * 8 + 2, }, } } + + pub fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + pub fn is_domain(&self) -> bool { + matches!(self, Self::DomainAddress(_, _)) + } + + pub fn is_ipv4(&self) -> bool { + matches!(self, Self::SocketAddress(SocketAddr::V4(_))) + } + + pub fn is_ipv6(&self) -> bool { + matches!(self, Self::SocketAddress(SocketAddr::V6(_))) + } } impl Display for Address { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match self { + Self::None => write!(f, "none"), Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), Self::SocketAddress(addr) => write!(f, "{addr}"), } diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 5e8d7f7..68af7a9 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -12,7 +12,7 @@ pub struct Packet { pub frag_total: u8, pub frag_id: u8, pub len: u16, - pub addr: Option
, + pub addr: Address, } impl Packet { @@ -24,7 +24,7 @@ impl Packet { frag_total: u8, frag_id: u8, len: u16, - addr: Option
, + addr: Address, ) -> Self { Self { assoc_id, @@ -43,6 +43,6 @@ impl Command for Packet { } fn len(&self) -> usize { - 2 + 2 + 1 + 1 + 2 + self.addr.as_ref().map_or(0, |addr| addr.len()) + 2 + 2 + 1 + 1 + 2 + self.addr.len() } } From f7e29260513decfb2539c46b511ebc801b611b4d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 14:44:02 +0900 Subject: [PATCH 043/103] const-ifying `Command::new()` --- tuic/src/protocol/address.rs | 16 ++++++++-------- tuic/src/protocol/authenticate.rs | 2 +- tuic/src/protocol/connect.rs | 2 +- tuic/src/protocol/dissociate.rs | 4 ++-- tuic/src/protocol/heartbeat.rs | 4 ++-- tuic/src/protocol/packet.rs | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index da3a7f4..717629b 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -15,10 +15,10 @@ use std::{ /// /// The address type can be one of the following: /// -/// 0x00: None -/// 0x01: Fully-qualified domain name (the first byte indicates the length of the domain name) -/// 0x02: IPv4 address -/// 0x03: IPv6 address +/// 0xff: None +/// 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) +/// 0x01: IPv4 address +/// 0x02: IPv6 address /// /// The port number is encoded in 2 bytes after the Domain name / IP address. #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] @@ -29,10 +29,10 @@ pub enum Address { } impl Address { - pub const TYPE_CODE_NONE: u8 = 0x00; - pub const TYPE_CODE_DOMAIN: u8 = 0x01; - pub const TYPE_CODE_IPV4: u8 = 0x02; - pub const TYPE_CODE_IPV6: u8 = 0x03; + pub const TYPE_CODE_NONE: u8 = 0xff; + pub const TYPE_CODE_DOMAIN: u8 = 0x00; + pub const TYPE_CODE_IPV4: u8 = 0x01; + pub const TYPE_CODE_IPV6: u8 = 0x02; pub fn type_code(&self) -> u8 { match self { diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index 9f1a781..988230d 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -13,7 +13,7 @@ pub struct Authenticate { impl Authenticate { pub(super) const TYPE_CODE: u8 = 0x00; - pub fn new(token: [u8; 8]) -> Self { + pub const fn new(token: [u8; 8]) -> Self { Self { token } } } diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index 286eaf7..db57013 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -13,7 +13,7 @@ pub struct Connect { impl Connect { pub(super) const TYPE_CODE: u8 = 0x01; - pub fn new(addr: Address) -> Self { + pub const fn new(addr: Address) -> Self { Self { addr } } } diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index f443f06..d931d13 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -11,9 +11,9 @@ pub struct Dissociate { } impl Dissociate { - pub const TYPE_CODE: u8 = 0x03; + pub(super) const TYPE_CODE: u8 = 0x03; - pub fn new(assoc_id: u16) -> Self { + pub const fn new(assoc_id: u16) -> Self { Self { assoc_id } } } diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index efb7e52..7b03ad1 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -9,9 +9,9 @@ use super::Command; pub struct Heartbeat; impl Heartbeat { - pub const TYPE_CODE: u8 = 0x04; + pub(super) const TYPE_CODE: u8 = 0x04; - pub fn new() -> Self { + pub const fn new() -> Self { Self } } diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 68af7a9..83e4504 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -18,7 +18,7 @@ pub struct Packet { impl Packet { pub(super) const TYPE_CODE: u8 = 0x02; - pub fn new( + pub const fn new( assoc_id: u16, pkt_id: u16, frag_total: u8, From 99e48ca276b52d7bb057687a5294151e213ae35d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 15:46:41 +0900 Subject: [PATCH 044/103] adding UDP packet fragmentation --- tuic/src/protocol/address.rs | 11 +++++ tuic/src/protocol/packet.rs | 4 ++ tuic/src/prototype/authenticate.rs | 4 ++ tuic/src/prototype/connect.rs | 4 ++ tuic/src/prototype/dissociate.rs | 4 ++ tuic/src/prototype/heartbeat.rs | 4 ++ tuic/src/prototype/mod.rs | 37 +++++++++++----- tuic/src/prototype/packet.rs | 68 ++++++++++++++++++++++++++++-- 8 files changed, 122 insertions(+), 14 deletions(-) diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index 717629b..a45ac40 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -1,5 +1,6 @@ use std::{ fmt::{Display, Formatter, Result as FmtResult}, + mem, net::SocketAddr, }; @@ -56,6 +57,10 @@ impl Address { } } + pub fn take(&mut self) -> Self { + mem::take(self) + } + pub fn is_none(&self) -> bool { matches!(self, Self::None) } @@ -82,3 +87,9 @@ impl Display for Address { } } } + +impl Default for Address { + fn default() -> Self { + Self::None + } +} diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 83e4504..62a8219 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -35,6 +35,10 @@ impl Packet { addr, } } + + pub const fn len_without_addr() -> usize { + 2 + 2 + 1 + 1 + 2 + } } impl Command for Packet { diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index 06b4a94..fbccf57 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -13,4 +13,8 @@ impl Authenticate { _task_reg: task_reg, } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index 7db35aa..fddac8e 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -13,4 +13,8 @@ impl Connect { _task_reg: task_reg, } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs index e34c576..abf5cef 100644 --- a/tuic/src/prototype/dissociate.rs +++ b/tuic/src/prototype/dissociate.rs @@ -13,4 +13,8 @@ impl Dissociate { _task_reg: task_reg, } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs index c20b0e2..369a7ec 100644 --- a/tuic/src/prototype/heartbeat.rs +++ b/tuic/src/prototype/heartbeat.rs @@ -10,4 +10,8 @@ impl Heartbeat { header: Header::Heartbeat(HeartbeatHeader::new()), } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index 49cda85..b96f414 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -1,8 +1,11 @@ use crate::protocol::Address; use parking_lot::Mutex; use std::{ - collections::{hash_map::Entry, HashMap}, - sync::{Arc, Weak}, + collections::HashMap, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, Weak, + }, }; mod authenticate; @@ -102,14 +105,10 @@ impl UdpSessions { payload: &'a [u8], frag_len: usize, ) -> Packet<'a> { - match self.sessions.entry(assoc_id) { - Entry::Occupied(_) => {} - Entry::Vacant(entry) => { - entry.insert(UdpSession::new(self.local_active_task_count.reg())); - } - } - - Packet::new(assoc_id, addr, payload, frag_len) + self.sessions + .entry(assoc_id) + .or_insert_with(|| UdpSession::new(self.local_active_task_count.reg())) + .send(assoc_id, addr, payload, frag_len) } fn dissociate(&mut self, assoc_id: u16) -> Dissociate { @@ -119,13 +118,31 @@ impl UdpSessions { } struct UdpSession { + next_pkt_id: AtomicU16, _task_reg: TaskRegister, } impl UdpSession { fn new(task_reg: TaskRegister) -> Self { Self { + next_pkt_id: AtomicU16::new(0), _task_reg: task_reg, } } + + fn send<'a>( + &self, + assoc_id: u16, + addr: Address, + payload: &'a [u8], + frag_len: usize, + ) -> Packet<'a> { + Packet::new( + assoc_id, + self.next_pkt_id.fetch_add(1, Ordering::AcqRel), + addr, + payload, + frag_len, + ) + } } diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index 3e262eb..ccd17fc 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -1,19 +1,79 @@ -use crate::protocol::Address; +use crate::protocol::{Address, Header, Packet as PacketHeader}; pub struct Packet<'a> { assoc_id: u16, + pkt_id: u16, addr: Address, payload: &'a [u8], - frag_len: usize, + max_pkt_size: usize, + frag_total: u8, + next_frag_id: u8, + next_frag_start: usize, } impl<'a> Packet<'a> { - pub(super) fn new(assoc_id: u16, addr: Address, payload: &'a [u8], frag_len: usize) -> Self { + pub(super) fn new( + assoc_id: u16, + pkt_id: u16, + addr: Address, + payload: &'a [u8], + max_pkt_size: usize, + ) -> Self { + let first_frag_size = max_pkt_size - PacketHeader::len_without_addr() - addr.len(); + let frag_size_addr_none = + max_pkt_size - PacketHeader::len_without_addr() - Address::None.len(); + + let frag_total = if first_frag_size < payload.len() { + (1 + (payload.len() - first_frag_size) / frag_size_addr_none + 1) as u8 + } else { + 1u8 + }; + Self { assoc_id, + pkt_id, addr, payload, - frag_len, + max_pkt_size, + frag_total, + next_frag_id: 0, + next_frag_start: 0, } } } + +impl<'a> Iterator for Packet<'a> { + type Item = (Header, &'a [u8]); + + fn next(&mut self) -> Option { + if self.next_frag_id < self.frag_total { + let payload_size = + self.max_pkt_size - PacketHeader::len_without_addr() - self.addr.len(); + let next_frag_end = (self.next_frag_start + payload_size).min(self.payload.len()); + + let header = Header::Packet(PacketHeader::new( + self.assoc_id, + self.pkt_id, + self.frag_total, + self.next_frag_id, + (next_frag_end - self.next_frag_start) as u16, + self.addr.take(), + )); + + let payload = &self.payload[self.next_frag_start..next_frag_end]; + + self.next_frag_id += 1; + self.next_frag_start = next_frag_end; + + Some((header, payload)) + } else { + None + } + } +} + +impl ExactSizeIterator for Packet<'_> { + fn len(&self) -> usize { + self.frag_total as usize + } +} From a707943d5aeb497ea266fc4c9a9719d492100d56 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 15:51:24 +0900 Subject: [PATCH 045/103] privatizing fields in header abstraction --- tuic/src/protocol/authenticate.rs | 6 +++- tuic/src/protocol/connect.rs | 6 +++- tuic/src/protocol/dissociate.rs | 6 +++- tuic/src/protocol/packet.rs | 50 +++++++++++++++++++++++-------- 4 files changed, 52 insertions(+), 16 deletions(-) diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index 988230d..ba016da 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -7,7 +7,7 @@ use super::Command; // +-------+ #[derive(Clone, Debug)] pub struct Authenticate { - pub token: [u8; 8], + token: [u8; 8], } impl Authenticate { @@ -16,6 +16,10 @@ impl Authenticate { pub const fn new(token: [u8; 8]) -> Self { Self { token } } + + pub fn token(&self) -> &[u8; 8] { + &self.token + } } impl Command for Authenticate { diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index db57013..f32c6e7 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -7,7 +7,7 @@ use super::{Address, Command}; // +----------+ #[derive(Clone, Debug)] pub struct Connect { - pub addr: Address, + addr: Address, } impl Connect { @@ -16,6 +16,10 @@ impl Connect { pub const fn new(addr: Address) -> Self { Self { addr } } + + pub fn addr(&self) -> &Address { + &self.addr + } } impl Command for Connect { diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index d931d13..dc4dcf3 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -7,7 +7,7 @@ use super::Command; // +----------+ #[derive(Clone, Debug)] pub struct Dissociate { - pub assoc_id: u16, + assoc_id: u16, } impl Dissociate { @@ -16,6 +16,10 @@ impl Dissociate { pub const fn new(assoc_id: u16) -> Self { Self { assoc_id } } + + pub fn assoc_id(&self) -> u16 { + self.assoc_id + } } impl Command for Dissociate { diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 62a8219..c4d2a2b 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -1,18 +1,18 @@ use super::{Address, Command}; -// +----------+--------+------------+---------+-----+----------+ -// | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | LEN | ADDR | -// +----------+--------+------------+---------+-----+----------+ -// | 2 | 2 | 1 | 1 | 2 | Variable | -// +----------+--------+------------+---------+-----+----------+ +// +----------+--------+------------+---------+------+----------+ +// | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | SIZE | ADDR | +// +----------+--------+------------+---------+------+----------+ +// | 2 | 2 | 1 | 1 | 2 | Variable | +// +----------+--------+------------+---------+------+----------+ #[derive(Clone, Debug)] pub struct Packet { - pub assoc_id: u16, - pub pkt_id: u16, - pub frag_total: u8, - pub frag_id: u8, - pub len: u16, - pub addr: Address, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, } impl Packet { @@ -23,7 +23,7 @@ impl Packet { pkt_id: u16, frag_total: u8, frag_id: u8, - len: u16, + size: u16, addr: Address, ) -> Self { Self { @@ -31,11 +31,35 @@ impl Packet { pkt_id, frag_total, frag_id, - len, + size, addr, } } + pub fn assoc_id(&self) -> u16 { + self.assoc_id + } + + pub fn pkt_id(&self) -> u16 { + self.pkt_id + } + + pub fn frag_total(&self) -> u8 { + self.frag_total + } + + pub fn frag_id(&self) -> u8 { + self.frag_id + } + + pub fn size(&self) -> u16 { + self.size + } + + pub fn addr(&self) -> &Address { + &self.addr + } + pub const fn len_without_addr() -> usize { 2 + 2 + 1 + 1 + 2 } From 46ad3068d2abd1a8367a92d0e1a7e2f9a93ab81c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 24 Jan 2023 19:20:11 +0900 Subject: [PATCH 046/103] sharing abstraction between command Tx and Rx --- tuic/src/protocol/mod.rs | 3 +- tuic/src/prototype/authenticate.rs | 29 ++++++++++--- tuic/src/prototype/connect.rs | 29 ++++++++++--- tuic/src/prototype/dissociate.rs | 30 +++++++++---- tuic/src/prototype/heartbeat.rs | 23 ++++++++-- tuic/src/prototype/mod.rs | 67 +++++++++++++++--------------- tuic/src/prototype/packet.rs | 53 +++++++++++++++++++---- 7 files changed, 166 insertions(+), 68 deletions(-) diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index 406ebc3..8368314 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -1,11 +1,10 @@ mod address; +mod authenticate; mod connect; mod dissociate; mod heartbeat; mod packet; -pub mod authenticate; - pub use self::{ address::Address, authenticate::Authenticate, connect::Connect, dissociate::Dissociate, heartbeat::Heartbeat, packet::Packet, diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index fbccf57..fa02069 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -1,20 +1,37 @@ -use super::TaskRegister; +use super::{ + side::{self, Side, SideMarker}, + TaskRegister, +}; use crate::protocol::{Authenticate as AuthenticateHeader, Header}; -pub struct Authenticate { +pub struct Authenticate +where + M: SideMarker, +{ + inner: Side, + _marker: M, +} + +pub struct Tx { header: Header, _task_reg: TaskRegister, } -impl Authenticate { +pub struct Rx; + +impl Authenticate { pub(super) fn new(task_reg: TaskRegister, token: [u8; 8]) -> Self { Self { - header: Header::Authenticate(AuthenticateHeader::new(token)), - _task_reg: task_reg, + inner: Side::Tx(Tx { + header: Header::Authenticate(AuthenticateHeader::new(token)), + _task_reg: task_reg, + }), + _marker: side::Tx, } } pub fn header(&self) -> &Header { - &self.header + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.header } } diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index fddac8e..391c2d0 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -1,20 +1,37 @@ -use super::TaskRegister; +use super::{ + side::{self, Side, SideMarker}, + TaskRegister, +}; use crate::protocol::{Address, Connect as ConnectHeader, Header}; -pub struct Connect { +pub struct Connect +where + M: SideMarker, +{ + inner: Side, + _marker: M, +} + +struct Tx { header: Header, _task_reg: TaskRegister, } -impl Connect { +struct Rx; + +impl Connect { pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { Self { - header: Header::Connect(ConnectHeader::new(addr)), - _task_reg: task_reg, + inner: Side::Tx(Tx { + header: Header::Connect(ConnectHeader::new(addr)), + _task_reg: task_reg, + }), + _marker: side::Tx, } } pub fn header(&self) -> &Header { - &self.header + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.header } } diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs index abf5cef..c1088f5 100644 --- a/tuic/src/prototype/dissociate.rs +++ b/tuic/src/prototype/dissociate.rs @@ -1,20 +1,32 @@ -use super::TaskRegister; +use super::side::{self, Side, SideMarker}; use crate::protocol::{Dissociate as DissociateHeader, Header}; -pub struct Dissociate { - header: Header, - _task_reg: TaskRegister, +pub struct Dissociate +where + M: SideMarker, +{ + inner: Side, + _marker: M, } -impl Dissociate { - pub(super) fn new(task_reg: TaskRegister, assoc_id: u16) -> Self { +pub struct Tx { + header: Header, +} + +pub struct Rx; + +impl Dissociate { + pub(super) fn new(assoc_id: u16) -> Self { Self { - header: Header::Dissociate(DissociateHeader::new(assoc_id)), - _task_reg: task_reg, + inner: Side::Tx(Tx { + header: Header::Dissociate(DissociateHeader::new(assoc_id)), + }), + _marker: side::Tx, } } pub fn header(&self) -> &Header { - &self.header + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.header } } diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs index 369a7ec..17ecdfc 100644 --- a/tuic/src/prototype/heartbeat.rs +++ b/tuic/src/prototype/heartbeat.rs @@ -1,17 +1,32 @@ +use super::side::{self, Side, SideMarker}; use crate::protocol::{Header, Heartbeat as HeartbeatHeader}; -pub struct Heartbeat { +pub struct Heartbeat +where + M: SideMarker, +{ + inner: Side, + _marker: M, +} + +pub struct Tx { header: Header, } -impl Heartbeat { +pub struct Rx; + +impl Heartbeat { pub(super) fn new() -> Self { Self { - header: Header::Heartbeat(HeartbeatHeader::new()), + inner: Side::Tx(Tx { + header: Header::Heartbeat(HeartbeatHeader::new()), + }), + _marker: side::Tx, } } pub fn header(&self) -> &Header { - &self.header + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.header } } diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index b96f414..4f8c5e9 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -15,8 +15,11 @@ mod heartbeat; mod packet; pub use self::{ - authenticate::Authenticate, connect::Connect, dissociate::Dissociate, heartbeat::Heartbeat, - packet::Packet, + authenticate::Authenticate, + connect::Connect, + dissociate::Dissociate, + heartbeat::Heartbeat, + packet::{Fragment, Packet}, }; pub struct Connection { @@ -34,31 +37,28 @@ impl Connection { } } - pub fn authenticate(&self, token: [u8; 8]) -> Authenticate { + pub fn send_authenticate(&self, token: [u8; 8]) -> Authenticate { Authenticate::new(self.local_active_task_count.reg(), token) } - pub fn connect(&self, addr: Address) -> Connect { - Connect::new(self.local_active_task_count.reg(), addr) + pub fn send_connect(&self, addr: Address) -> Connect { + Connect::::new(self.local_active_task_count.reg(), addr) } - pub fn packet<'a>( + pub fn send_packet( &self, assoc_id: u16, addr: Address, - payload: &'a [u8], - frag_len: usize, - ) -> Packet<'a> { - self.udp_sessions - .lock() - .send(assoc_id, addr, payload, frag_len) + max_pkt_size: usize, + ) -> Packet { + self.udp_sessions.lock().send(assoc_id, addr, max_pkt_size) } - pub fn dissociate(&self, assoc_id: u16) -> Dissociate { + pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate { self.udp_sessions.lock().dissociate(assoc_id) } - pub fn heartbeat(&self) -> Heartbeat { + pub fn send_heartbeat(&self) -> Heartbeat { Heartbeat::new() } @@ -98,22 +98,16 @@ impl UdpSessions { } } - fn send<'a>( - &mut self, - assoc_id: u16, - addr: Address, - payload: &'a [u8], - frag_len: usize, - ) -> Packet<'a> { + fn send<'a>(&mut self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { self.sessions .entry(assoc_id) .or_insert_with(|| UdpSession::new(self.local_active_task_count.reg())) - .send(assoc_id, addr, payload, frag_len) + .send(assoc_id, addr, max_pkt_size) } - fn dissociate(&mut self, assoc_id: u16) -> Dissociate { + fn dissociate(&mut self, assoc_id: u16) -> Dissociate { self.sessions.remove(&assoc_id); - Dissociate::new(self.local_active_task_count.reg(), assoc_id) + Dissociate::new(assoc_id) } } @@ -130,19 +124,26 @@ impl UdpSession { } } - fn send<'a>( - &self, - assoc_id: u16, - addr: Address, - payload: &'a [u8], - frag_len: usize, - ) -> Packet<'a> { + fn send<'a>(&self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { Packet::new( assoc_id, self.next_pkt_id.fetch_add(1, Ordering::AcqRel), addr, - payload, - frag_len, + max_pkt_size, ) } } + +pub mod side { + pub struct Tx; + pub struct Rx; + + pub trait SideMarker {} + impl SideMarker for Tx {} + impl SideMarker for Rx {} + + pub(super) enum Side { + Tx(T), + Rx(R), + } +} diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index ccd17fc..4a139c6 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -1,23 +1,60 @@ +use super::side::{self, Side, SideMarker}; use crate::protocol::{Address, Header, Packet as PacketHeader}; -pub struct Packet<'a> { +pub struct Packet +where + M: SideMarker, +{ + inner: Side, + _marker: M, +} + +pub struct Tx { + assoc_id: u16, + pkt_id: u16, + addr: Address, + max_pkt_size: usize, +} + +pub struct Rx; + +impl Packet { + pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self { + Self { + inner: Side::Tx(Tx { + assoc_id, + pkt_id, + addr, + max_pkt_size, + }), + _marker: side::Tx, + } + } + + pub fn into_fragments<'a>(self, payload: &'a [u8]) -> Fragment<'a> { + let Side::Tx(tx) = self.inner else { unreachable!() }; + Fragment::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) + } +} + +pub struct Fragment<'a> { assoc_id: u16, pkt_id: u16, addr: Address, - payload: &'a [u8], max_pkt_size: usize, frag_total: u8, next_frag_id: u8, next_frag_start: usize, + payload: &'a [u8], } -impl<'a> Packet<'a> { - pub(super) fn new( +impl<'a> Fragment<'a> { + fn new( assoc_id: u16, pkt_id: u16, addr: Address, - payload: &'a [u8], max_pkt_size: usize, + payload: &'a [u8], ) -> Self { let first_frag_size = max_pkt_size - PacketHeader::len_without_addr() - addr.len(); let frag_size_addr_none = @@ -33,16 +70,16 @@ impl<'a> Packet<'a> { assoc_id, pkt_id, addr, - payload, max_pkt_size, frag_total, next_frag_id: 0, next_frag_start: 0, + payload, } } } -impl<'a> Iterator for Packet<'a> { +impl<'a> Iterator for Fragment<'a> { type Item = (Header, &'a [u8]); fn next(&mut self) -> Option { @@ -72,7 +109,7 @@ impl<'a> Iterator for Packet<'a> { } } -impl ExactSizeIterator for Packet<'_> { +impl ExactSizeIterator for Fragment<'_> { fn len(&self) -> usize { self.frag_total as usize } From 2c2fc7924b2e2de9be954ac2021d6be78d11fc4a Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 00:12:34 +0900 Subject: [PATCH 047/103] adding header destructors --- tuic/src/protocol/authenticate.rs | 6 ++++++ tuic/src/protocol/connect.rs | 6 ++++++ tuic/src/protocol/dissociate.rs | 10 ++++++++-- tuic/src/protocol/heartbeat.rs | 6 ++++++ tuic/src/protocol/packet.rs | 33 +++++++++++++++++++++---------- tuic/src/prototype/connect.rs | 16 +++++++++++++-- tuic/src/prototype/mod.rs | 6 +++++- 7 files changed, 68 insertions(+), 15 deletions(-) diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index ba016da..eebbafd 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -31,3 +31,9 @@ impl Command for Authenticate { 8 } } + +impl From for ([u8; 8],) { + fn from(auth: Authenticate) -> Self { + (auth.token,) + } +} diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index f32c6e7..1814558 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -31,3 +31,9 @@ impl Command for Connect { self.addr.len() } } + +impl From for (Address,) { + fn from(connect: Connect) -> Self { + (connect.addr,) + } +} diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index dc4dcf3..94734f5 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -17,8 +17,8 @@ impl Dissociate { Self { assoc_id } } - pub fn assoc_id(&self) -> u16 { - self.assoc_id + pub fn assoc_id(&self) -> &u16 { + &self.assoc_id } } @@ -31,3 +31,9 @@ impl Command for Dissociate { 2 } } + +impl From for (u16,) { + fn from(dissoc: Dissociate) -> Self { + (dissoc.assoc_id,) + } +} diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index 7b03ad1..4694444 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -25,3 +25,9 @@ impl Command for Heartbeat { 0 } } + +impl From for () { + fn from(hb: Heartbeat) -> Self { + () + } +} diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index c4d2a2b..71c79e5 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -36,24 +36,24 @@ impl Packet { } } - pub fn assoc_id(&self) -> u16 { - self.assoc_id + pub fn assoc_id(&self) -> &u16 { + &self.assoc_id } - pub fn pkt_id(&self) -> u16 { - self.pkt_id + pub fn pkt_id(&self) -> &u16 { + &self.pkt_id } - pub fn frag_total(&self) -> u8 { - self.frag_total + pub fn frag_total(&self) -> &u8 { + &self.frag_total } - pub fn frag_id(&self) -> u8 { - self.frag_id + pub fn frag_id(&self) -> &u8 { + &self.frag_id } - pub fn size(&self) -> u16 { - self.size + pub fn size(&self) -> &u16 { + &self.size } pub fn addr(&self) -> &Address { @@ -74,3 +74,16 @@ impl Command for Packet { 2 + 2 + 1 + 1 + 2 + self.addr.len() } } + +impl From for (u16, u16, u8, u8, u16, Address) { + fn from(pkt: Packet) -> Self { + ( + pkt.assoc_id, + pkt.pkt_id, + pkt.frag_total, + pkt.frag_id, + pkt.size, + pkt.addr, + ) + } +} diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index 391c2d0..646ea9b 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -17,8 +17,6 @@ struct Tx { _task_reg: TaskRegister, } -struct Rx; - impl Connect { pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { Self { @@ -35,3 +33,17 @@ impl Connect { &tx.header } } + +struct Rx { + addr: Address, +} + +impl Connect { + pub(super) fn new(header: ConnectHeader) -> Self { + let (addr,) = header.into(); + Self { + inner: Side::Rx(Rx { addr }), + _marker: side::Rx, + } + } +} diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index 4f8c5e9..2c0b509 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -1,4 +1,4 @@ -use crate::protocol::Address; +use crate::protocol::{Address, Connect as ConnectHeader}; use parking_lot::Mutex; use std::{ collections::HashMap, @@ -45,6 +45,10 @@ impl Connection { Connect::::new(self.local_active_task_count.reg(), addr) } + pub fn recv_connect(&self, header: ConnectHeader) -> Connect { + Connect::::new(header) + } + pub fn send_packet( &self, assoc_id: u16, From 6cfb00dd810e9244de00c004f3ee5b2529ed792a Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 00:26:30 +0900 Subject: [PATCH 048/103] seperating task `CONNECT` and `ASSOCIATE` count --- tuic/src/prototype/authenticate.rs | 9 ++------ tuic/src/prototype/connect.rs | 8 +++++-- tuic/src/prototype/mod.rs | 36 +++++++++++++++++------------- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index fa02069..be03ee8 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -1,7 +1,4 @@ -use super::{ - side::{self, Side, SideMarker}, - TaskRegister, -}; +use super::side::{self, Side, SideMarker}; use crate::protocol::{Authenticate as AuthenticateHeader, Header}; pub struct Authenticate @@ -14,17 +11,15 @@ where pub struct Tx { header: Header, - _task_reg: TaskRegister, } pub struct Rx; impl Authenticate { - pub(super) fn new(task_reg: TaskRegister, token: [u8; 8]) -> Self { + pub(super) fn new(token: [u8; 8]) -> Self { Self { inner: Side::Tx(Tx { header: Header::Authenticate(AuthenticateHeader::new(token)), - _task_reg: task_reg, }), _marker: side::Tx, } diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index 646ea9b..a3643a1 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -36,13 +36,17 @@ impl Connect { struct Rx { addr: Address, + _task_reg: TaskRegister, } impl Connect { - pub(super) fn new(header: ConnectHeader) -> Self { + pub(super) fn new(task_reg: TaskRegister, header: ConnectHeader) -> Self { let (addr,) = header.into(); Self { - inner: Side::Rx(Rx { addr }), + inner: Side::Rx(Rx { + addr, + _task_reg: task_reg, + }), _marker: side::Rx, } } diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index 2c0b509..86ca1d7 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -24,29 +24,31 @@ pub use self::{ pub struct Connection { udp_sessions: Mutex, - local_active_task_count: ActiveTaskCount, + task_connect_count: TaskCount, + task_associate_count: TaskCount, } impl Connection { pub fn new() -> Self { - let local_active_task_count = ActiveTaskCount::new(); + let task_associate_count = TaskCount::new(); Self { - udp_sessions: Mutex::new(UdpSessions::new(local_active_task_count.clone())), - local_active_task_count, + udp_sessions: Mutex::new(UdpSessions::new(task_associate_count.clone())), + task_connect_count: TaskCount::new(), + task_associate_count, } } pub fn send_authenticate(&self, token: [u8; 8]) -> Authenticate { - Authenticate::new(self.local_active_task_count.reg(), token) + Authenticate::new(token) } pub fn send_connect(&self, addr: Address) -> Connect { - Connect::::new(self.local_active_task_count.reg(), addr) + Connect::::new(self.task_connect_count.reg(), addr) } pub fn recv_connect(&self, header: ConnectHeader) -> Connect { - Connect::::new(header) + Connect::::new(self.task_connect_count.reg(), header) } pub fn send_packet( @@ -66,16 +68,20 @@ impl Connection { Heartbeat::new() } - pub fn local_active_task_count(&self) -> usize { - self.local_active_task_count.get() + pub fn task_connect_count(&self) -> usize { + self.task_connect_count.get() + } + + pub fn task_associate_count(&self) -> usize { + self.task_associate_count.get() } } #[derive(Clone)] -struct ActiveTaskCount(Arc<()>); +struct TaskCount(Arc<()>); struct TaskRegister(Weak<()>); -impl ActiveTaskCount { +impl TaskCount { fn new() -> Self { Self(Arc::new(())) } @@ -91,21 +97,21 @@ impl ActiveTaskCount { struct UdpSessions { sessions: HashMap, - local_active_task_count: ActiveTaskCount, + task_associate_count: TaskCount, } impl UdpSessions { - fn new(local_active_task_count: ActiveTaskCount) -> Self { + fn new(task_associate_count: TaskCount) -> Self { Self { sessions: HashMap::new(), - local_active_task_count, + task_associate_count, } } fn send<'a>(&mut self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { self.sessions .entry(assoc_id) - .or_insert_with(|| UdpSession::new(self.local_active_task_count.reg())) + .or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) .send(assoc_id, addr, max_pkt_size) } From 358fcd95f7b1b8e03a73eba092176346ba489634 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 17:34:59 +0900 Subject: [PATCH 049/103] packet assembling mechanism --- tuic/Cargo.toml | 3 +- tuic/src/protocol/address.rs | 8 +- tuic/src/protocol/heartbeat.rs | 2 +- tuic/src/prototype/authenticate.rs | 7 +- tuic/src/prototype/connect.rs | 15 +- tuic/src/prototype/dissociate.rs | 7 +- tuic/src/prototype/heartbeat.rs | 7 +- tuic/src/prototype/mod.rs | 270 +++++++++++++++++++++++++---- tuic/src/prototype/packet.rs | 77 +++++++- 9 files changed, 329 insertions(+), 67 deletions(-) diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index cdefb9c..3aab101 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,10 +4,11 @@ version = "0.1.0" edition = "2021" [features] -prototype = ["parking_lot"] +prototype = ["parking_lot", "thiserror"] [dependencies] parking_lot = { version = "0.12.1", default-features = false, optional = true } +thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] tuic = { path = ".", features = ["prototype"] } diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index a45ac40..78eaefc 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -16,10 +16,10 @@ use std::{ /// /// The address type can be one of the following: /// -/// 0xff: None -/// 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) -/// 0x01: IPv4 address -/// 0x02: IPv6 address +/// - 0xff: None +/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) +/// - 0x01: IPv4 address +/// - 0x02: IPv6 address /// /// The port number is encoded in 2 bytes after the Domain name / IP address. #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index 4694444..91087e8 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -27,7 +27,7 @@ impl Command for Heartbeat { } impl From for () { - fn from(hb: Heartbeat) -> Self { + fn from(_: Heartbeat) -> Self { () } } diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index be03ee8..118ee20 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -1,10 +1,7 @@ -use super::side::{self, Side, SideMarker}; +use super::side::{self, Side}; use crate::protocol::{Authenticate as AuthenticateHeader, Header}; -pub struct Authenticate -where - M: SideMarker, -{ +pub struct Authenticate { inner: Side, _marker: M, } diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index a3643a1..0e40479 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -1,13 +1,10 @@ use super::{ - side::{self, Side, SideMarker}, + side::{self, Side}, TaskRegister, }; use crate::protocol::{Address, Connect as ConnectHeader, Header}; -pub struct Connect -where - M: SideMarker, -{ +pub struct Connect { inner: Side, _marker: M, } @@ -40,8 +37,7 @@ struct Rx { } impl Connect { - pub(super) fn new(task_reg: TaskRegister, header: ConnectHeader) -> Self { - let (addr,) = header.into(); + pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { Self { inner: Side::Rx(Rx { addr, @@ -50,4 +46,9 @@ impl Connect { _marker: side::Rx, } } + + pub fn addr(&self) -> &Address { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.addr + } } diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs index c1088f5..bf7f728 100644 --- a/tuic/src/prototype/dissociate.rs +++ b/tuic/src/prototype/dissociate.rs @@ -1,10 +1,7 @@ -use super::side::{self, Side, SideMarker}; +use super::side::{self, Side}; use crate::protocol::{Dissociate as DissociateHeader, Header}; -pub struct Dissociate -where - M: SideMarker, -{ +pub struct Dissociate { inner: Side, _marker: M, } diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs index 17ecdfc..b5b96eb 100644 --- a/tuic/src/prototype/heartbeat.rs +++ b/tuic/src/prototype/heartbeat.rs @@ -1,10 +1,7 @@ -use super::side::{self, Side, SideMarker}; +use super::side::{self, Side}; use crate::protocol::{Header, Heartbeat as HeartbeatHeader}; -pub struct Heartbeat -where - M: SideMarker, -{ +pub struct Heartbeat { inner: Side, _marker: M, } diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index 86ca1d7..b79cd7e 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -1,4 +1,4 @@ -use crate::protocol::{Address, Connect as ConnectHeader}; +use crate::protocol::{Address, Connect as ConnectHeader, Packet as PacketHeader}; use parking_lot::Mutex; use std::{ collections::HashMap, @@ -6,7 +6,9 @@ use std::{ atomic::{AtomicU16, Ordering}, Arc, Weak, }, + time::{Duration, Instant}, }; +use thiserror::Error; mod authenticate; mod connect; @@ -22,18 +24,21 @@ pub use self::{ packet::{Fragment, Packet}, }; -pub struct Connection { - udp_sessions: Mutex, +pub struct Connection { + udp_sessions: Arc>>, task_connect_count: TaskCount, task_associate_count: TaskCount, } -impl Connection { +impl Connection +where + B: AsRef<[u8]>, +{ pub fn new() -> Self { let task_associate_count = TaskCount::new(); Self { - udp_sessions: Mutex::new(UdpSessions::new(task_associate_count.clone())), + udp_sessions: Arc::new(Mutex::new(UdpSessions::new(task_associate_count.clone()))), task_connect_count: TaskCount::new(), task_associate_count, } @@ -44,11 +49,12 @@ impl Connection { } pub fn send_connect(&self, addr: Address) -> Connect { - Connect::::new(self.task_connect_count.reg(), addr) + Connect::::new(self.task_connect_count.register(), addr) } pub fn recv_connect(&self, header: ConnectHeader) -> Connect { - Connect::::new(self.task_connect_count.reg(), header) + let (addr,) = header.into(); + Connect::::new(self.task_connect_count.register(), addr) } pub fn send_packet( @@ -56,8 +62,23 @@ impl Connection { assoc_id: u16, addr: Address, max_pkt_size: usize, - ) -> Packet { - self.udp_sessions.lock().send(assoc_id, addr, max_pkt_size) + ) -> Packet { + self.udp_sessions + .lock() + .send_packet(assoc_id, addr, max_pkt_size) + } + + pub fn recv_packet(&self, header: PacketHeader) -> Packet { + let (assoc_id, pkt_id, frag_total, frag_id, size, addr) = header.into(); + self.udp_sessions.lock().recv_packet( + self.udp_sessions.clone(), + assoc_id, + pkt_id, + frag_total, + frag_id, + size, + addr, + ) } pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate { @@ -75,6 +96,10 @@ impl Connection { pub fn task_associate_count(&self) -> usize { self.task_associate_count.get() } + + pub fn collect_garbage(&self, timeout: Duration) { + self.udp_sessions.lock().collect_garbage(timeout); + } } #[derive(Clone)] @@ -86,7 +111,7 @@ impl TaskCount { Self(Arc::new(())) } - fn reg(&self) -> TaskRegister { + fn register(&self) -> TaskRegister { TaskRegister(Arc::downgrade(&self.0)) } @@ -95,12 +120,25 @@ impl TaskCount { } } -struct UdpSessions { - sessions: HashMap, +pub mod side { + pub struct Tx; + pub struct Rx; + + pub(super) enum Side { + Tx(T), + Rx(R), + } +} + +struct UdpSessions { + sessions: HashMap>, task_associate_count: TaskCount, } -impl UdpSessions { +impl UdpSessions +where + B: AsRef<[u8]>, +{ fn new(task_associate_count: TaskCount) -> Self { Self { sessions: HashMap::new(), @@ -108,52 +146,224 @@ impl UdpSessions { } } - fn send<'a>(&mut self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { + fn send_packet<'a>( + &mut self, + assoc_id: u16, + addr: Address, + max_pkt_size: usize, + ) -> Packet { self.sessions .entry(assoc_id) - .or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) - .send(assoc_id, addr, max_pkt_size) + .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .send_packet(assoc_id, addr, max_pkt_size) + } + + fn recv_packet<'a>( + &mut self, + sessions: Arc>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Packet { + self.sessions + .entry(assoc_id) + .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) } fn dissociate(&mut self, assoc_id: u16) -> Dissociate { self.sessions.remove(&assoc_id); Dissociate::new(assoc_id) } + + fn insert( + &mut self, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + data: B, + ) -> Result, AssembleError> + where + A: Assembled, + { + self.sessions + .entry(assoc_id) + .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .insert(pkt_id, frag_total, frag_id, size, addr, data) + } + + fn collect_garbage(&mut self, timeout: Duration) { + for (_, session) in self.sessions.iter_mut() { + session.collect_garbage(timeout); + } + } } -struct UdpSession { +struct UdpSession { + pkt_buf: HashMap>, next_pkt_id: AtomicU16, _task_reg: TaskRegister, } -impl UdpSession { +impl UdpSession +where + B: AsRef<[u8]>, +{ fn new(task_reg: TaskRegister) -> Self { Self { + pkt_buf: HashMap::new(), next_pkt_id: AtomicU16::new(0), _task_reg: task_reg, } } - fn send<'a>(&self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { - Packet::new( + fn send_packet( + &self, + assoc_id: u16, + addr: Address, + max_pkt_size: usize, + ) -> Packet { + Packet::::new( assoc_id, self.next_pkt_id.fetch_add(1, Ordering::AcqRel), addr, max_pkt_size, ) } -} -pub mod side { - pub struct Tx; - pub struct Rx; + fn recv_packet( + &self, + sessions: Arc>>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Packet { + Packet::::new(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) + } - pub trait SideMarker {} - impl SideMarker for Tx {} - impl SideMarker for Rx {} + fn insert( + &mut self, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + data: B, + ) -> Result, AssembleError> + where + A: Assembled, + { + let res = self + .pkt_buf + .entry(pkt_id) + .or_insert_with(|| PacketBuffer::new(frag_total)) + .insert(frag_total, frag_id, size, addr, data)?; - pub(super) enum Side { - Tx(T), - Rx(R), + if res.is_some() { + self.pkt_buf.remove(&pkt_id); + } + + Ok(res) + } + + fn collect_garbage(&mut self, timeout: Duration) { + self.pkt_buf.retain(|_, buf| buf.c_time.elapsed() < timeout); } } + +struct PacketBuffer { + buf: Vec>, + frag_total: u8, + frag_received: u8, + addr: Address, + c_time: Instant, +} + +impl PacketBuffer +where + B: AsRef<[u8]>, +{ + fn new(frag_total: u8) -> Self { + let mut buf = Vec::with_capacity(frag_total as usize); + buf.resize_with(frag_total as usize, || None); + + Self { + buf, + frag_total, + frag_received: 0, + addr: Address::None, + c_time: Instant::now(), + } + } + + fn insert( + &mut self, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + data: B, + ) -> Result, AssembleError> + where + A: Assembled, + { + if data.as_ref().len() != size as usize { + return Err(AssembleError::InvalidFragmentSize); + } + + if frag_id >= frag_total { + return Err(AssembleError::InvalidFragmentId); + } + + if (frag_id == 0 && addr.is_none()) || (frag_id != 0 && !addr.is_none()) { + return Err(AssembleError::InvalidAddress); + } + + if self.buf[frag_id as usize].is_some() { + return Err(AssembleError::DuplicateFragment); + } + + self.buf[frag_id as usize] = Some(data); + self.frag_received += 1; + + if frag_id == 0 { + self.addr = addr; + } + + if self.frag_received == self.frag_total { + let iter = self.buf.iter_mut().map(|x| x.take().unwrap()); + Ok(Some((A::assemble(iter)?, self.addr.take()))) + } else { + Ok(None) + } + } +} + +pub trait Assembled +where + Self: Sized, + B: AsRef<[u8]>, +{ + fn assemble(buf: impl IntoIterator) -> Result; +} + +#[derive(Debug, Error)] +pub enum AssembleError { + #[error("invalid fragment size")] + InvalidFragmentSize, + #[error("invalid fragment id")] + InvalidFragmentId, + #[error("invalid address")] + InvalidAddress, + #[error("duplicate fragment")] + DuplicateFragment, +} diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index 4a139c6..56ae3bc 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -1,11 +1,13 @@ -use super::side::{self, Side, SideMarker}; +use super::{ + side::{self, Side}, + AssembleError, Assembled, UdpSessions, +}; use crate::protocol::{Address, Header, Packet as PacketHeader}; +use parking_lot::Mutex; +use std::sync::Arc; -pub struct Packet -where - M: SideMarker, -{ - inner: Side, +pub struct Packet { + inner: Side>, _marker: M, } @@ -16,9 +18,10 @@ pub struct Tx { max_pkt_size: usize, } -pub struct Rx; - -impl Packet { +impl Packet +where + B: AsRef<[u8]>, +{ pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self { Self { inner: Side::Tx(Tx { @@ -37,6 +40,62 @@ impl Packet { } } +pub struct Rx { + sessions: Arc>>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, +} + +impl Packet +where + B: AsRef<[u8]>, +{ + pub(super) fn new( + sessions: Arc>>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Self { + Self { + inner: Side::Rx(Rx { + sessions, + assoc_id, + pkt_id, + frag_total, + frag_id, + size, + addr, + }), + _marker: side::Rx, + } + } + + pub fn assemble(self, data: B) -> Result, AssembleError> + where + A: Assembled, + { + let Side::Rx(rx) = self.inner else { unreachable!() }; + let mut sessions = rx.sessions.lock(); + + sessions.insert( + rx.assoc_id, + rx.pkt_id, + rx.frag_total, + rx.frag_id, + rx.size, + rx.addr, + data, + ) + } +} + pub struct Fragment<'a> { assoc_id: u16, pkt_id: u16, From f175aad8e8283bfd7bd742a2714f621cafa404fa Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 18:18:38 +0900 Subject: [PATCH 050/103] generalizing payload for `Fragment` --- tuic/src/prototype/packet.rs | 50 +++++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index 56ae3bc..5afd443 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::protocol::{Address, Header, Packet as PacketHeader}; use parking_lot::Mutex; -use std::sync::Arc; +use std::{marker::PhantomData, slice, sync::Arc}; pub struct Packet { inner: Side>, @@ -34,7 +34,10 @@ where } } - pub fn into_fragments<'a>(self, payload: &'a [u8]) -> Fragment<'a> { + pub fn into_fragments<'a, P>(self, payload: P) -> Fragment<'a, P> + where + P: AsRef<[u8]>, + { let Side::Tx(tx) = self.inner else { unreachable!() }; Fragment::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) } @@ -96,7 +99,10 @@ where } } -pub struct Fragment<'a> { +pub struct Fragment<'a, P> +where + P: 'a, +{ assoc_id: u16, pkt_id: u16, addr: Address, @@ -104,23 +110,21 @@ pub struct Fragment<'a> { frag_total: u8, next_frag_id: u8, next_frag_start: usize, - payload: &'a [u8], + payload: P, + _marker: PhantomData<&'a P>, } -impl<'a> Fragment<'a> { - fn new( - assoc_id: u16, - pkt_id: u16, - addr: Address, - max_pkt_size: usize, - payload: &'a [u8], - ) -> Self { +impl<'a, P> Fragment<'a, P> +where + P: AsRef<[u8]> + 'a, +{ + fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize, payload: P) -> Self { let first_frag_size = max_pkt_size - PacketHeader::len_without_addr() - addr.len(); let frag_size_addr_none = max_pkt_size - PacketHeader::len_without_addr() - Address::None.len(); - let frag_total = if first_frag_size < payload.len() { - (1 + (payload.len() - first_frag_size) / frag_size_addr_none + 1) as u8 + let frag_total = if first_frag_size < payload.as_ref().len() { + (1 + (payload.as_ref().len() - first_frag_size) / frag_size_addr_none + 1) as u8 } else { 1u8 }; @@ -134,18 +138,23 @@ impl<'a> Fragment<'a> { next_frag_id: 0, next_frag_start: 0, payload, + _marker: PhantomData, } } } -impl<'a> Iterator for Fragment<'a> { +impl<'a, P> Iterator for Fragment<'a, P> +where + P: AsRef<[u8]> + 'a, +{ type Item = (Header, &'a [u8]); fn next(&mut self) -> Option { if self.next_frag_id < self.frag_total { let payload_size = self.max_pkt_size - PacketHeader::len_without_addr() - self.addr.len(); - let next_frag_end = (self.next_frag_start + payload_size).min(self.payload.len()); + let next_frag_end = + (self.next_frag_start + payload_size).min(self.payload.as_ref().len()); let header = Header::Packet(PacketHeader::new( self.assoc_id, @@ -156,7 +165,9 @@ impl<'a> Iterator for Fragment<'a> { self.addr.take(), )); - let payload = &self.payload[self.next_frag_start..next_frag_end]; + let payload_ptr = &(self.payload.as_ref()[self.next_frag_start]) as *const u8; + let payload = + unsafe { slice::from_raw_parts(payload_ptr, next_frag_end - self.next_frag_start) }; self.next_frag_id += 1; self.next_frag_start = next_frag_end; @@ -168,7 +179,10 @@ impl<'a> Iterator for Fragment<'a> { } } -impl ExactSizeIterator for Fragment<'_> { +impl

ExactSizeIterator for Fragment<'_, P> +impl

ExactSizeIterator for Fragments<'_, P> where P: AsRef<[u8]>, { From a4d178c95e0f7058260e0561de1cdabf0e92a585 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 20:08:05 +0900 Subject: [PATCH 054/103] assembling packet to existed buffer --- tuic/src/model/mod.rs | 55 ++++++++++++++++++++++++++-------------- tuic/src/model/packet.rs | 7 ++--- 2 files changed, 38 insertions(+), 24 deletions(-) diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index 4cd8cb6..f526e41 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -5,6 +5,7 @@ use crate::protocol::{ use parking_lot::Mutex; use std::{ collections::HashMap, + mem, sync::{ atomic::{AtomicU16, Ordering}, Arc, Weak, @@ -202,7 +203,7 @@ where Dissociate::::new(assoc_id) } - fn insert( + fn insert( &mut self, assoc_id: u16, pkt_id: u16, @@ -211,10 +212,7 @@ where size: u16, addr: Address, data: B, - ) -> Result, AssembleError> - where - A: Assembled, - { + ) -> Result>, AssembleError> { self.sessions .entry(assoc_id) .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) @@ -273,7 +271,7 @@ where Packet::::new(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) } - fn insert( + fn insert( &mut self, pkt_id: u16, frag_total: u8, @@ -281,10 +279,7 @@ where size: u16, addr: Address, data: B, - ) -> Result, AssembleError> - where - A: Assembled, - { + ) -> Result>, AssembleError> { let res = self .pkt_buf .entry(pkt_id) @@ -328,17 +323,14 @@ where } } - fn insert( + fn insert( &mut self, frag_total: u8, frag_id: u8, size: u16, addr: Address, data: B, - ) -> Result, AssembleError> - where - A: Assembled, - { + ) -> Result>, AssembleError> { if data.as_ref().len() != size as usize { return Err(AssembleError::InvalidFragmentSize); } @@ -363,20 +355,45 @@ where } if self.frag_received == self.frag_total { - let iter = self.buf.iter_mut().map(|x| x.take().unwrap()); - Ok(Some((A::assemble(iter)?, self.addr.take()))) + Ok(Some(Assemblable::new( + mem::take(&mut self.buf), + self.addr.take(), + ))) } else { Ok(None) } } } -pub trait Assembled +pub struct Assemblable { + buf: Vec>, + addr: Address, +} + +impl Assemblable +where + B: AsRef<[u8]>, +{ + fn new(buf: Vec>, addr: Address) -> Self { + Self { buf, addr } + } + + pub fn assemble(self, buf: &mut A) -> Address + where + A: Assembler, + { + let data = self.buf.into_iter().map(|b| b.unwrap()); + buf.assemble(data); + self.addr + } +} + +pub trait Assembler where Self: Sized, B: AsRef<[u8]>, { - fn assemble(buf: impl IntoIterator) -> Result; + fn assemble(&mut self, data: impl IntoIterator); } #[derive(Debug, Error)] diff --git a/tuic/src/model/packet.rs b/tuic/src/model/packet.rs index d4c4483..181fc7c 100644 --- a/tuic/src/model/packet.rs +++ b/tuic/src/model/packet.rs @@ -1,6 +1,6 @@ use super::{ side::{self, Side}, - AssembleError, Assembled, UdpSessions, + Assemblable, AssembleError, UdpSessions, }; use crate::protocol::{Address, Header, Packet as PacketHeader}; use parking_lot::Mutex; @@ -77,10 +77,7 @@ where } } - pub fn assemble(self, data: B) -> Result, AssembleError> - where - A: Assembled, - { + pub fn assemble(self, data: B) -> Result>, AssembleError> { let Side::Rx(rx) = self.inner else { unreachable!() }; let mut sessions = rx.sessions.lock(); From 101d4427eb12cb83e5a77444eeacddfde993dcdd Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 26 Jan 2023 19:05:16 +0900 Subject: [PATCH 055/103] implement client-side `connect` and `packet` --- tuic-quinn/Cargo.toml | 8 +- tuic-quinn/src/lib.rs | 165 +++++++++++++++++++++++++++++++++--- tuic-quinn/src/marshal.rs | 16 ++++ tuic-quinn/src/unmarshal.rs | 7 ++ 4 files changed, 184 insertions(+), 12 deletions(-) create mode 100644 tuic-quinn/src/marshal.rs create mode 100644 tuic-quinn/src/unmarshal.rs diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml index d98ceed..f990737 100644 --- a/tuic-quinn/Cargo.toml +++ b/tuic-quinn/Cargo.toml @@ -3,6 +3,10 @@ name = "tuic-quinn" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +async-trait = { version = "0.1.62", default-features = false } +bytes = { version = "1.3.0", default-features = false, features = ["std"] } +futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"] } +quinn = { version = "0.9.3", default-features = false, features = ["futures-io"] } +thiserror = { version = "1.0.38", default-features = false } +tuic = { path = "../tuic", default-features = false, features = ["model"] } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 7d12d9a..e9a5249 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -1,14 +1,159 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} +use self::{marshal::Marshal, side::Side}; +use bytes::Bytes; +use futures_util::{io::Cursor, AsyncRead, AsyncWrite, AsyncWriteExt}; +use quinn::{ + Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream, +}; +use std::{ + io::Error as IoError, + pin::Pin, + task::{Context, Poll}, +}; +use thiserror::Error; +use tuic::{ + model::{ + side::{Rx, Tx}, + Connect as ConnectModel, Connection as ConnectionModel, + }, + protocol::{Address, Header}, +}; -#[cfg(test)] -mod tests { - use super::*; +mod marshal; +mod unmarshal; - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); +pub mod side { + pub struct Client; + pub struct Server; + + pub(super) enum Side { + Client(C), + Server(S), } } + +pub struct Connection<'conn, Side> { + conn: &'conn QuinnConnection, + model: ConnectionModel, + _marker: Side, +} + +impl<'conn> Connection<'conn, side::Client> { + pub fn new(conn: &'conn QuinnConnection) -> Self { + Self { + conn, + model: ConnectionModel::new(), + _marker: side::Client, + } + } + + pub async fn connect(&self, addr: Address) -> Result { + let (mut send, recv) = self.conn.open_bi().await?; + let model = self.model.send_connect(addr); + model.header().marshal(&mut send).await?; + Ok(Connect::new(Side::Client(model), send, recv)) + } + + pub async fn packet_native( + &self, + pkt: impl AsRef<[u8]>, + assoc_id: u16, + addr: Address, + ) -> Result<(), Error> { + let Some(max_pkt_size) = self.conn.max_datagram_size() else { + return Err(Error::SendDatagram(SendDatagramError::Disabled)); + }; + + let model = self.model.send_packet(assoc_id, addr, max_pkt_size); + + for (header, frag) in model.into_fragments(pkt) { + let mut buf = Cursor::new(vec![0; header.len() + frag.len()]); + header.marshal(&mut buf).await?; + buf.write_all(frag).await.unwrap(); + self.conn.send_datagram(Bytes::from(buf.into_inner()))?; + } + + Ok(()) + } + + pub async fn packet_quic( + &self, + pkt: impl AsRef<[u8]>, + assoc_id: u16, + addr: Address, + ) -> Result<(), Error> { + let model = self.model.send_packet(assoc_id, addr, u16::MAX as usize); + let mut frags = model.into_fragments(pkt); + let (header, frag) = frags.next().unwrap(); + assert!(frags.next().is_none()); + + let mut send = self.conn.open_uni().await?; + header.marshal(&mut send).await?; + AsyncWriteExt::write_all(&mut send, frag).await?; + + Ok(()) + } +} + +pub struct Connect { + model: Side, ConnectModel>, + send: SendStream, + recv: RecvStream, +} + +impl Connect { + fn new( + model: Side, ConnectModel>, + send: SendStream, + recv: RecvStream, + ) -> Self { + Self { model, send, recv } + } + + pub fn addr(&self) -> &Address { + match &self.model { + Side::Client(model) => { + let Header::Connect(connect) = model.header() else { unreachable!() }; + &connect.addr() + } + Side::Server(model) => model.addr(), + } + } +} + +impl AsyncRead for Connect { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.get_mut().recv), cx, buf) + } +} + +impl AsyncWrite for Connect { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.get_mut().send), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.get_mut().send), cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_close(Pin::new(&mut self.get_mut().send), cx) + } +} + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + SendDatagram(#[from] SendDatagramError), +} diff --git a/tuic-quinn/src/marshal.rs b/tuic-quinn/src/marshal.rs new file mode 100644 index 0000000..66c2d85 --- /dev/null +++ b/tuic-quinn/src/marshal.rs @@ -0,0 +1,16 @@ +use async_trait::async_trait; +use futures_util::AsyncWrite; +use std::io::Error as IoError; +use tuic::protocol::Header; + +#[async_trait] +pub(super) trait Marshal { + async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError>; +} + +#[async_trait] +impl Marshal for Header { + async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError> { + todo!() + } +} diff --git a/tuic-quinn/src/unmarshal.rs b/tuic-quinn/src/unmarshal.rs new file mode 100644 index 0000000..c7e6e11 --- /dev/null +++ b/tuic-quinn/src/unmarshal.rs @@ -0,0 +1,7 @@ +use async_trait::async_trait; +use futures_util::AsyncRead; + +#[async_trait] +trait Unmarshal { + fn unmarshal(&self, s: &mut impl AsyncRead) -> Result<(), ()>; +} From 10cee5276e8a0b429faf2a1cab9acadbb1aad0bc Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 26 Jan 2023 20:47:20 +0900 Subject: [PATCH 056/103] adding uni stream parsing methods --- tuic-quinn/src/lib.rs | 135 +++++++++++++++++++++++++++++------ tuic-quinn/src/unmarshal.rs | 19 ++++- tuic/src/model/mod.rs | 31 ++++++-- tuic/src/model/packet.rs | 25 +++++++ tuic/src/protocol/connect.rs | 4 +- tuic/src/protocol/mod.rs | 2 +- 6 files changed, 183 insertions(+), 33 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index e9a5249..b247b11 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -1,6 +1,10 @@ -use self::{marshal::Marshal, side::Side}; +use self::{ + marshal::Marshal, + side::Side, + unmarshal::{Unmarshal, UnmarshalError}, +}; use bytes::Bytes; -use futures_util::{io::Cursor, AsyncRead, AsyncWrite, AsyncWriteExt}; +use futures_util::{io::Cursor, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use quinn::{ Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream, }; @@ -13,7 +17,8 @@ use thiserror::Error; use tuic::{ model::{ side::{Rx, Tx}, - Connect as ConnectModel, Connection as ConnectionModel, + AssembleError, Connect as ConnectModel, Connection as ConnectionModel, + Packet as PacketModel, }, protocol::{Address, Header}, }; @@ -33,26 +38,11 @@ pub mod side { pub struct Connection<'conn, Side> { conn: &'conn QuinnConnection, - model: ConnectionModel, + model: ConnectionModel>, _marker: Side, } -impl<'conn> Connection<'conn, side::Client> { - pub fn new(conn: &'conn QuinnConnection) -> Self { - Self { - conn, - model: ConnectionModel::new(), - _marker: side::Client, - } - } - - pub async fn connect(&self, addr: Address) -> Result { - let (mut send, recv) = self.conn.open_bi().await?; - let model = self.model.send_connect(addr); - model.header().marshal(&mut send).await?; - Ok(Connect::new(Side::Client(model), send, recv)) - } - +impl<'conn, Side> Connection<'conn, Side> { pub async fn packet_native( &self, pkt: impl AsRef<[u8]>, @@ -92,6 +82,92 @@ impl<'conn> Connection<'conn, side::Client> { Ok(()) } + + async fn accept_packet_quic( + &self, + model: PacketModel>, + mut recv: &mut RecvStream, + ) -> Result, Address, u16)>, Error> { + let mut buf = vec![0; *model.size() as usize]; + AsyncReadExt::read_exact(&mut recv, &mut buf).await?; + let mut asm = Vec::new(); + + Ok(model + .assemble(buf)? + .map(|pkt| pkt.assemble(&mut asm)) + .map(|(addr, assoc_id)| (asm, addr, assoc_id))) + } +} + +impl<'conn> Connection<'conn, side::Client> { + pub fn new(conn: &'conn QuinnConnection) -> Self { + Self { + conn, + model: ConnectionModel::new(), + _marker: side::Client, + } + } + + pub async fn connect(&self, addr: Address) -> Result { + let (mut send, recv) = self.conn.open_bi().await?; + let model = self.model.send_connect(addr); + model.header().marshal(&mut send).await?; + Ok(Connect::new(Side::Client(model), send, recv)) + } + + pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result { + match Header::unmarshal(&mut recv).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + Ok(Task::Packet( + self.accept_packet_quic(model, &mut recv).await?, + )) + } + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(hb) => { + let _ = self.model.recv_heartbeat(hb); + Ok(Task::Heartbeat) + } + _ => unreachable!(), + } + } +} + +impl<'conn> Connection<'conn, side::Server> { + pub fn new(conn: &'conn QuinnConnection) -> Self { + Self { + conn, + model: ConnectionModel::new(), + _marker: side::Server, + } + } + + pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result { + match Header::unmarshal(&mut recv).await? { + Header::Authenticate(auth) => { + let model = self.model.recv_authenticate(auth); + Ok(Task::Authenticate(*model.token())) + } + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + Ok(Task::Packet( + self.accept_packet_quic(model, &mut recv).await?, + )) + } + Header::Dissociate(dissoc) => { + let _ = self.model.recv_dissociate(dissoc); + Ok(Task::Dissociate) + } + Header::Heartbeat(hb) => { + let _ = self.model.recv_heartbeat(hb); + Ok(Task::Heartbeat) + } + _ => unreachable!(), + } + } } pub struct Connect { @@ -112,8 +188,8 @@ impl Connect { pub fn addr(&self) -> &Address { match &self.model { Side::Client(model) => { - let Header::Connect(connect) = model.header() else { unreachable!() }; - &connect.addr() + let Header::Connect(conn) = model.header() else { unreachable!() }; + &conn.addr() } Side::Server(model) => model.addr(), } @@ -148,6 +224,15 @@ impl AsyncWrite for Connect { } } +#[non_exhaustive] +pub enum Task { + Authenticate([u8; 8]), + Connect(Connect), + Packet(Option<(Vec, Address, u16)>), + Dissociate, + Heartbeat, +} + #[derive(Debug, Error)] pub enum Error { #[error(transparent)] @@ -156,4 +241,10 @@ pub enum Error { Connection(#[from] ConnectionError), #[error(transparent)] SendDatagram(#[from] SendDatagramError), + #[error(transparent)] + Unmarshal(#[from] UnmarshalError), + #[error(transparent)] + Assemble(#[from] AssembleError), + #[error("{0}")] + BadCommand(&'static str), } diff --git a/tuic-quinn/src/unmarshal.rs b/tuic-quinn/src/unmarshal.rs index c7e6e11..1f1c9b7 100644 --- a/tuic-quinn/src/unmarshal.rs +++ b/tuic-quinn/src/unmarshal.rs @@ -1,7 +1,22 @@ use async_trait::async_trait; use futures_util::AsyncRead; +use thiserror::Error; +use tuic::protocol::Header; #[async_trait] -trait Unmarshal { - fn unmarshal(&self, s: &mut impl AsyncRead) -> Result<(), ()>; +pub(super) trait Unmarshal +where + Self: Sized, +{ + async fn unmarshal(s: &mut impl AsyncRead) -> Result; } + +#[async_trait] +impl Unmarshal for Header { + async fn unmarshal(s: &mut impl AsyncRead) -> Result { + todo!() + } +} + +#[derive(Debug, Error)] +pub enum UnmarshalError {} diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index f526e41..47b29a6 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -216,7 +216,7 @@ where self.sessions .entry(assoc_id) .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) - .insert(pkt_id, frag_total, frag_id, size, addr, data) + .insert(assoc_id, pkt_id, frag_total, frag_id, size, addr, data) } fn collect_garbage(&mut self, timeout: Duration) { @@ -273,6 +273,7 @@ where fn insert( &mut self, + assoc_id: u16, pkt_id: u16, frag_total: u8, frag_id: u8, @@ -284,7 +285,7 @@ where .pkt_buf .entry(pkt_id) .or_insert_with(|| PacketBuffer::new(frag_total)) - .insert(frag_total, frag_id, size, addr, data)?; + .insert(assoc_id, frag_total, frag_id, size, addr, data)?; if res.is_some() { self.pkt_buf.remove(&pkt_id); @@ -325,6 +326,7 @@ where fn insert( &mut self, + assoc_id: u16, frag_total: u8, frag_id: u8, size: u16, @@ -358,6 +360,7 @@ where Ok(Some(Assemblable::new( mem::take(&mut self.buf), self.addr.take(), + assoc_id, ))) } else { Ok(None) @@ -368,23 +371,28 @@ where pub struct Assemblable { buf: Vec>, addr: Address, + assoc_id: u16, } impl Assemblable where B: AsRef<[u8]>, { - fn new(buf: Vec>, addr: Address) -> Self { - Self { buf, addr } + fn new(buf: Vec>, addr: Address, assoc_id: u16) -> Self { + Self { + buf, + addr, + assoc_id, + } } - pub fn assemble(self, buf: &mut A) -> Address + pub fn assemble(self, buf: &mut A) -> (Address, u16) where A: Assembler, { let data = self.buf.into_iter().map(|b| b.unwrap()); buf.assemble(data); - self.addr + (self.addr, self.assoc_id) } } @@ -396,6 +404,17 @@ where fn assemble(&mut self, data: impl IntoIterator); } +impl Assembler for Vec +where + B: AsRef<[u8]>, +{ + fn assemble(&mut self, data: impl IntoIterator) { + for d in data { + self.extend_from_slice(d.as_ref()); + } + } +} + #[derive(Debug, Error)] pub enum AssembleError { #[error("invalid fragment size")] diff --git a/tuic/src/model/packet.rs b/tuic/src/model/packet.rs index 181fc7c..ea336e3 100644 --- a/tuic/src/model/packet.rs +++ b/tuic/src/model/packet.rs @@ -38,6 +38,16 @@ impl Packet { let Side::Tx(tx) = self.inner else { unreachable!() }; Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) } + + pub fn assoc_id(&self) -> &u16 { + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.assoc_id + } + + pub fn addr(&self) -> &Address { + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.addr + } } pub struct Rx { @@ -91,6 +101,21 @@ where data, ) } + + pub fn assoc_id(&self) -> &u16 { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.assoc_id + } + + pub fn addr(&self) -> &Address { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.addr + } + + pub fn size(&self) -> &u16 { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.size + } } pub struct Fragments<'a, P> diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index 1814558..20a4d96 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -33,7 +33,7 @@ impl Command for Connect { } impl From for (Address,) { - fn from(connect: Connect) -> Self { - (connect.addr,) + fn from(conn: Connect) -> Self { + (conn.addr,) } } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index 8368314..c890893 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -51,7 +51,7 @@ impl Header { pub fn len(&self) -> usize { 2 + match self { Self::Authenticate(auth) => auth.len(), - Self::Connect(connect) => connect.len(), + Self::Connect(conn) => conn.len(), Self::Packet(packet) => packet.len(), Self::Dissociate(dissociate) => dissociate.len(), Self::Heartbeat(heartbeat) => heartbeat.len(), From 20384701a0c72535a9b03463f7ff9780e1dbe443 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 00:16:51 +0900 Subject: [PATCH 057/103] adding datagram and bi stream parsing methods --- tuic-quinn/src/lib.rs | 112 +++++++++++++++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 17 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index b247b11..0cfc3e1 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -38,7 +38,7 @@ pub mod side { pub struct Connection<'conn, Side> { conn: &'conn QuinnConnection, - model: ConnectionModel>, + model: ConnectionModel, _marker: Side, } @@ -46,8 +46,8 @@ impl<'conn, Side> Connection<'conn, Side> { pub async fn packet_native( &self, pkt: impl AsRef<[u8]>, - assoc_id: u16, addr: Address, + assoc_id: u16, ) -> Result<(), Error> { let Some(max_pkt_size) = self.conn.max_datagram_size() else { return Err(Error::SendDatagram(SendDatagramError::Disabled)); @@ -68,8 +68,8 @@ impl<'conn, Side> Connection<'conn, Side> { pub async fn packet_quic( &self, pkt: impl AsRef<[u8]>, - assoc_id: u16, addr: Address, + assoc_id: u16, ) -> Result<(), Error> { let model = self.model.send_packet(assoc_id, addr, u16::MAX as usize); let mut frags = model.into_fragments(pkt); @@ -79,23 +79,58 @@ impl<'conn, Side> Connection<'conn, Side> { let mut send = self.conn.open_uni().await?; header.marshal(&mut send).await?; AsyncWriteExt::write_all(&mut send, frag).await?; + send.close().await?; Ok(()) } + pub async fn accept_datagram(&self, dg: Bytes) -> Result { + let mut dg = Cursor::new(dg); + + match Header::unmarshal(&mut dg).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + let pos = dg.position() as usize; + let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) + } + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(hb) => { + let _ = self.model.recv_heartbeat(hb); + Ok(Task::Heartbeat) + } + _ => unreachable!(), + } + } + async fn accept_packet_quic( &self, - model: PacketModel>, + model: PacketModel, mut recv: &mut RecvStream, - ) -> Result, Address, u16)>, Error> { + ) -> Result, Error> { let mut buf = vec![0; *model.size() as usize]; AsyncReadExt::read_exact(&mut recv, &mut buf).await?; let mut asm = Vec::new(); Ok(model - .assemble(buf)? + .assemble(Bytes::from(buf))? .map(|pkt| pkt.assemble(&mut asm)) - .map(|(addr, assoc_id)| (asm, addr, assoc_id))) + .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id))) + } + + async fn accept_packet_native( + &self, + model: PacketModel, + data: Bytes, + ) -> Result, Error> { + let mut asm = Vec::new(); + + Ok(model + .assemble(data)? + .map(|pkt| pkt.assemble(&mut asm)) + .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id))) } } @@ -108,6 +143,14 @@ impl<'conn> Connection<'conn, side::Client> { } } + pub async fn authenticate(&self, token: [u8; 8]) -> Result<(), Error> { + let mut send = self.conn.open_uni().await?; + let model = self.model.send_authenticate(token); + model.header().marshal(&mut send).await?; + send.close().await?; + Ok(()) + } + pub async fn connect(&self, addr: Address) -> Result { let (mut send, recv) = self.conn.open_bi().await?; let model = self.model.send_connect(addr); @@ -115,7 +158,7 @@ impl<'conn> Connection<'conn, side::Client> { Ok(Connect::new(Side::Client(model), send, recv)) } - pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result { + pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { match Header::unmarshal(&mut recv).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), @@ -126,13 +169,33 @@ impl<'conn> Connection<'conn, side::Client> { )) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), - Header::Heartbeat(hb) => { - let _ = self.model.recv_heartbeat(hb); - Ok(Task::Heartbeat) - } + Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), _ => unreachable!(), } } + + pub async fn accept_bi_stream( + &self, + _send: SendStream, + mut recv: RecvStream, + ) -> Result { + match Header::unmarshal(&mut recv).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(_) => Err(Error::BadCommand("packet")), + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + _ => unreachable!(), + } + } + + pub async fn heartbeat(&self) -> Result<(), Error> { + let model = self.model.send_heartbeat(); + let mut buf = Vec::with_capacity(model.header().len()); + model.header().marshal(&mut buf).await.unwrap(); + self.conn.send_datagram(Bytes::from(buf))?; + Ok(()) + } } impl<'conn> Connection<'conn, side::Server> { @@ -144,7 +207,7 @@ impl<'conn> Connection<'conn, side::Server> { } } - pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result { + pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { match Header::unmarshal(&mut recv).await? { Header::Authenticate(auth) => { let model = self.model.recv_authenticate(auth); @@ -161,10 +224,25 @@ impl<'conn> Connection<'conn, side::Server> { let _ = self.model.recv_dissociate(dissoc); Ok(Task::Dissociate) } - Header::Heartbeat(hb) => { - let _ = self.model.recv_heartbeat(hb); - Ok(Task::Heartbeat) + Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + _ => unreachable!(), + } + } + + pub async fn accept_bi_stream( + &self, + send: SendStream, + mut recv: RecvStream, + ) -> Result { + match Header::unmarshal(&mut recv).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(conn) => { + let model = self.model.recv_connect(conn); + Ok(Task::Connect(Connect::new(Side::Server(model), send, recv))) } + Header::Packet(_) => Err(Error::BadCommand("packet")), + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), _ => unreachable!(), } } @@ -228,7 +306,7 @@ impl AsyncWrite for Connect { pub enum Task { Authenticate([u8; 8]), Connect(Connect), - Packet(Option<(Vec, Address, u16)>), + Packet(Option<(Bytes, Address, u16)>), Dissociate, Heartbeat, } From e3760f4a015e94dd4ad543a535902c5ea8de873c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 00:30:06 +0900 Subject: [PATCH 058/103] moving (un)marshal to crate `tuic` --- tuic-quinn/Cargo.toml | 3 +-- tuic-quinn/src/lib.rs | 11 ++--------- tuic/Cargo.toml | 5 ++++- tuic/src/lib.rs | 18 +++++++++++++++++- {tuic-quinn => tuic}/src/marshal.rs | 6 +++--- {tuic-quinn => tuic}/src/unmarshal.rs | 6 +++--- 6 files changed, 30 insertions(+), 19 deletions(-) rename {tuic-quinn => tuic}/src/marshal.rs (77%) rename {tuic-quinn => tuic}/src/unmarshal.rs (81%) diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml index f990737..07e1ec3 100644 --- a/tuic-quinn/Cargo.toml +++ b/tuic-quinn/Cargo.toml @@ -4,9 +4,8 @@ version = "0.1.0" edition = "2021" [dependencies] -async-trait = { version = "0.1.62", default-features = false } bytes = { version = "1.3.0", default-features = false, features = ["std"] } futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io"] } thiserror = { version = "1.0.38", default-features = false } -tuic = { path = "../tuic", default-features = false, features = ["model"] } +tuic = { path = "../tuic", default-features = false, features = ["marshal", "model"] } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 0cfc3e1..64f3ca2 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -1,8 +1,4 @@ -use self::{ - marshal::Marshal, - side::Side, - unmarshal::{Unmarshal, UnmarshalError}, -}; +use self::side::Side; use bytes::Bytes; use futures_util::{io::Cursor, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use quinn::{ @@ -20,12 +16,9 @@ use tuic::{ AssembleError, Connect as ConnectModel, Connection as ConnectionModel, Packet as PacketModel, }, - protocol::{Address, Header}, + Address, Header, Marshal, Unmarshal, UnmarshalError, }; -mod marshal; -mod unmarshal; - pub mod side { pub struct Client; pub struct Server; diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 490a448..526b169 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,11 +4,14 @@ version = "0.1.0" edition = "2021" [features] +marshal = ["async-trait", "futures-io"] model = ["parking_lot", "thiserror"] [dependencies] +async-trait = { version = "0.1.62", default-features = false, optional = true } +futures-io = { version = "0.3.25", default-features = false, features = ["std"], optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] -tuic = { path = ".", features = ["model"] } +tuic = { path = ".", features = ["marshal", "model"] } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 4eaef2f..48dbda5 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -1,6 +1,22 @@ //! The TUIC protocol -pub mod protocol; +mod protocol; + +#[cfg(feature = "marshal")] +mod marshal; + +#[cfg(feature = "marshal")] +mod unmarshal; + +pub use self::protocol::{ + Address, Authenticate, Command, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, +}; + +#[cfg(feature = "marshal")] +pub use self::{ + marshal::Marshal, + unmarshal::{Unmarshal, UnmarshalError}, +}; #[cfg(feature = "model")] pub mod model; diff --git a/tuic-quinn/src/marshal.rs b/tuic/src/marshal.rs similarity index 77% rename from tuic-quinn/src/marshal.rs rename to tuic/src/marshal.rs index 66c2d85..86ddbc9 100644 --- a/tuic-quinn/src/marshal.rs +++ b/tuic/src/marshal.rs @@ -1,10 +1,10 @@ +use crate::protocol::Header; use async_trait::async_trait; -use futures_util::AsyncWrite; +use futures_io::AsyncWrite; use std::io::Error as IoError; -use tuic::protocol::Header; #[async_trait] -pub(super) trait Marshal { +pub trait Marshal { async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError>; } diff --git a/tuic-quinn/src/unmarshal.rs b/tuic/src/unmarshal.rs similarity index 81% rename from tuic-quinn/src/unmarshal.rs rename to tuic/src/unmarshal.rs index 1f1c9b7..d933fc8 100644 --- a/tuic-quinn/src/unmarshal.rs +++ b/tuic/src/unmarshal.rs @@ -1,10 +1,10 @@ +use crate::protocol::Header; use async_trait::async_trait; -use futures_util::AsyncRead; +use futures_io::AsyncRead; use thiserror::Error; -use tuic::protocol::Header; #[async_trait] -pub(super) trait Unmarshal +pub trait Unmarshal where Self: Sized, { From fc6444f86ed82ee4ab0eaa692cadb7bc1d697807 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 00:40:06 +0900 Subject: [PATCH 059/103] better error messages for `AssembleError` --- tuic/src/model/mod.rs | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index 47b29a6..98f59b4 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -333,20 +333,26 @@ where addr: Address, data: B, ) -> Result>, AssembleError> { - if data.as_ref().len() != size as usize { - return Err(AssembleError::InvalidFragmentSize); - } + assert_eq!(data.as_ref().len(), size as usize); if frag_id >= frag_total { - return Err(AssembleError::InvalidFragmentId); + return Err(AssembleError::InvalidFragmentId(frag_total, frag_id)); } - if (frag_id == 0 && addr.is_none()) || (frag_id != 0 && !addr.is_none()) { - return Err(AssembleError::InvalidAddress); + if frag_id == 0 && addr.is_none() { + return Err(AssembleError::InvalidAddress( + "no address in first fragment", + )); + } + + if frag_id != 0 && !addr.is_none() { + return Err(AssembleError::InvalidAddress( + "address in non-first fragment", + )); } if self.buf[frag_id as usize].is_some() { - return Err(AssembleError::DuplicatedFragment); + return Err(AssembleError::DuplicatedFragment(frag_id)); } self.buf[frag_id as usize] = Some(data); @@ -417,12 +423,10 @@ where #[derive(Debug, Error)] pub enum AssembleError { - #[error("invalid fragment size")] - InvalidFragmentSize, - #[error("invalid fragment id")] - InvalidFragmentId, - #[error("invalid address")] - InvalidAddress, - #[error("duplicated fragment")] - DuplicatedFragment, + #[error("invalid fragment id {1} in total {0} fragments")] + InvalidFragmentId(u8, u8), + #[error("{0}")] + InvalidAddress(&'static str), + #[error("duplicated fragment: {0}")] + DuplicatedFragment(u8), } From 823ae55024ed257d56cc3d85001a1921f885ab00 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 00:42:23 +0900 Subject: [PATCH 060/103] `collect_garbage()` for `tuic_quinn::Connection` --- tuic-quinn/src/lib.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 64f3ca2..cd6973e 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -8,6 +8,7 @@ use std::{ io::Error as IoError, pin::Pin, task::{Context, Poll}, + time::Duration, }; use thiserror::Error; use tuic::{ @@ -98,6 +99,10 @@ impl<'conn, Side> Connection<'conn, Side> { } } + pub fn collect_garbage(&self, timeout: Duration) { + self.model.collect_garbage(timeout); + } + async fn accept_packet_quic( &self, model: PacketModel, From 437776abb030d5be1c7309f2d45a2848b0d50b7c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 01:33:46 +0900 Subject: [PATCH 061/103] stop client from receiving heartbeat --- tuic-quinn/src/lib.rs | 72 +++++++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 27 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index cd6973e..2b25756 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -78,27 +78,6 @@ impl<'conn, Side> Connection<'conn, Side> { Ok(()) } - pub async fn accept_datagram(&self, dg: Bytes) -> Result { - let mut dg = Cursor::new(dg); - - match Header::unmarshal(&mut dg).await? { - Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), - Header::Connect(_) => Err(Error::BadCommand("connect")), - Header::Packet(pkt) => { - let model = self.model.recv_packet(pkt); - let pos = dg.position() as usize; - let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); - Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) - } - Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), - Header::Heartbeat(hb) => { - let _ = self.model.recv_heartbeat(hb); - Ok(Task::Heartbeat) - } - _ => unreachable!(), - } - } - pub fn collect_garbage(&self, timeout: Duration) { self.model.collect_garbage(timeout); } @@ -156,6 +135,14 @@ impl<'conn> Connection<'conn, side::Client> { Ok(Connect::new(Side::Client(model), send, recv)) } + pub async fn heartbeat(&self) -> Result<(), Error> { + let model = self.model.send_heartbeat(); + let mut buf = Vec::with_capacity(model.header().len()); + model.header().marshal(&mut buf).await.unwrap(); + self.conn.send_datagram(Bytes::from(buf))?; + Ok(()) + } + pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { match Header::unmarshal(&mut recv).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), @@ -187,12 +174,22 @@ impl<'conn> Connection<'conn, side::Client> { } } - pub async fn heartbeat(&self) -> Result<(), Error> { - let model = self.model.send_heartbeat(); - let mut buf = Vec::with_capacity(model.header().len()); - model.header().marshal(&mut buf).await.unwrap(); - self.conn.send_datagram(Bytes::from(buf))?; - Ok(()) + pub async fn accept_datagram(&self, dg: Bytes) -> Result { + let mut dg = Cursor::new(dg); + + match Header::unmarshal(&mut dg).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + let pos = dg.position() as usize; + let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) + } + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + _ => unreachable!(), + } } } @@ -244,6 +241,27 @@ impl<'conn> Connection<'conn, side::Server> { _ => unreachable!(), } } + + pub async fn accept_datagram(&self, dg: Bytes) -> Result { + let mut dg = Cursor::new(dg); + + match Header::unmarshal(&mut dg).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + let pos = dg.position() as usize; + let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) + } + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(hb) => { + let _ = self.model.recv_heartbeat(hb); + Ok(Task::Heartbeat) + } + _ => unreachable!(), + } + } } pub struct Connect { From 011e397c60b7ac58a125f2602cfe65abd6540faa Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 01:41:44 +0900 Subject: [PATCH 062/103] get rid of `async_trait` in (un)marshal --- tuic-quinn/Cargo.toml | 2 +- tuic-quinn/src/lib.rs | 24 ++++++++++++------------ tuic/Cargo.toml | 5 ++--- tuic/src/lib.rs | 19 ++++++++----------- tuic/src/marshal.rs | 11 ++--------- tuic/src/unmarshal.rs | 14 ++------------ 6 files changed, 27 insertions(+), 48 deletions(-) diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml index 07e1ec3..f51f7ce 100644 --- a/tuic-quinn/Cargo.toml +++ b/tuic-quinn/Cargo.toml @@ -8,4 +8,4 @@ bytes = { version = "1.3.0", default-features = false, features = ["std"] } futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io"] } thiserror = { version = "1.0.38", default-features = false } -tuic = { path = "../tuic", default-features = false, features = ["marshal", "model"] } +tuic = { path = "../tuic", default-features = false, features = ["async_marshal", "model"] } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 2b25756..3c8b2ea 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -17,7 +17,7 @@ use tuic::{ AssembleError, Connect as ConnectModel, Connection as ConnectionModel, Packet as PacketModel, }, - Address, Header, Marshal, Unmarshal, UnmarshalError, + Address, Header, UnmarshalError, }; pub mod side { @@ -51,7 +51,7 @@ impl<'conn, Side> Connection<'conn, Side> { for (header, frag) in model.into_fragments(pkt) { let mut buf = Cursor::new(vec![0; header.len() + frag.len()]); - header.marshal(&mut buf).await?; + header.async_marshal(&mut buf).await?; buf.write_all(frag).await.unwrap(); self.conn.send_datagram(Bytes::from(buf.into_inner()))?; } @@ -71,7 +71,7 @@ impl<'conn, Side> Connection<'conn, Side> { assert!(frags.next().is_none()); let mut send = self.conn.open_uni().await?; - header.marshal(&mut send).await?; + header.async_marshal(&mut send).await?; AsyncWriteExt::write_all(&mut send, frag).await?; send.close().await?; @@ -123,7 +123,7 @@ impl<'conn> Connection<'conn, side::Client> { pub async fn authenticate(&self, token: [u8; 8]) -> Result<(), Error> { let mut send = self.conn.open_uni().await?; let model = self.model.send_authenticate(token); - model.header().marshal(&mut send).await?; + model.header().async_marshal(&mut send).await?; send.close().await?; Ok(()) } @@ -131,20 +131,20 @@ impl<'conn> Connection<'conn, side::Client> { pub async fn connect(&self, addr: Address) -> Result { let (mut send, recv) = self.conn.open_bi().await?; let model = self.model.send_connect(addr); - model.header().marshal(&mut send).await?; + model.header().async_marshal(&mut send).await?; Ok(Connect::new(Side::Client(model), send, recv)) } pub async fn heartbeat(&self) -> Result<(), Error> { let model = self.model.send_heartbeat(); let mut buf = Vec::with_capacity(model.header().len()); - model.header().marshal(&mut buf).await.unwrap(); + model.header().async_marshal(&mut buf).await.unwrap(); self.conn.send_datagram(Bytes::from(buf))?; Ok(()) } pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { - match Header::unmarshal(&mut recv).await? { + match Header::async_unmarshal(&mut recv).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { @@ -164,7 +164,7 @@ impl<'conn> Connection<'conn, side::Client> { _send: SendStream, mut recv: RecvStream, ) -> Result { - match Header::unmarshal(&mut recv).await? { + match Header::async_unmarshal(&mut recv).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(_) => Err(Error::BadCommand("packet")), @@ -177,7 +177,7 @@ impl<'conn> Connection<'conn, side::Client> { pub async fn accept_datagram(&self, dg: Bytes) -> Result { let mut dg = Cursor::new(dg); - match Header::unmarshal(&mut dg).await? { + match Header::async_unmarshal(&mut dg).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { @@ -203,7 +203,7 @@ impl<'conn> Connection<'conn, side::Server> { } pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { - match Header::unmarshal(&mut recv).await? { + match Header::async_unmarshal(&mut recv).await? { Header::Authenticate(auth) => { let model = self.model.recv_authenticate(auth); Ok(Task::Authenticate(*model.token())) @@ -229,7 +229,7 @@ impl<'conn> Connection<'conn, side::Server> { send: SendStream, mut recv: RecvStream, ) -> Result { - match Header::unmarshal(&mut recv).await? { + match Header::async_unmarshal(&mut recv).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(conn) => { let model = self.model.recv_connect(conn); @@ -245,7 +245,7 @@ impl<'conn> Connection<'conn, side::Server> { pub async fn accept_datagram(&self, dg: Bytes) -> Result { let mut dg = Cursor::new(dg); - match Header::unmarshal(&mut dg).await? { + match Header::async_unmarshal(&mut dg).await? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 526b169..f19b12a 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,14 +4,13 @@ version = "0.1.0" edition = "2021" [features] -marshal = ["async-trait", "futures-io"] +async_marshal = ["futures-io"] model = ["parking_lot", "thiserror"] [dependencies] -async-trait = { version = "0.1.62", default-features = false, optional = true } futures-io = { version = "0.3.25", default-features = false, features = ["std"], optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] -tuic = { path = ".", features = ["marshal", "model"] } +tuic = { path = ".", features = ["async_marshal", "model"] } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 48dbda5..05c95ad 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -2,21 +2,18 @@ mod protocol; -#[cfg(feature = "marshal")] -mod marshal; - -#[cfg(feature = "marshal")] -mod unmarshal; - pub use self::protocol::{ Address, Authenticate, Command, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, }; -#[cfg(feature = "marshal")] -pub use self::{ - marshal::Marshal, - unmarshal::{Unmarshal, UnmarshalError}, -}; +#[cfg(feature = "async_marshal")] +mod marshal; + +#[cfg(feature = "async_marshal")] +mod unmarshal; + +#[cfg(feature = "async_marshal")] +pub use self::unmarshal::UnmarshalError; #[cfg(feature = "model")] pub mod model; diff --git a/tuic/src/marshal.rs b/tuic/src/marshal.rs index 86ddbc9..2cbeba4 100644 --- a/tuic/src/marshal.rs +++ b/tuic/src/marshal.rs @@ -1,16 +1,9 @@ use crate::protocol::Header; -use async_trait::async_trait; use futures_io::AsyncWrite; use std::io::Error as IoError; -#[async_trait] -pub trait Marshal { - async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError>; -} - -#[async_trait] -impl Marshal for Header { - async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError> { +impl Header { + pub async fn async_marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError> { todo!() } } diff --git a/tuic/src/unmarshal.rs b/tuic/src/unmarshal.rs index d933fc8..aa1fd56 100644 --- a/tuic/src/unmarshal.rs +++ b/tuic/src/unmarshal.rs @@ -1,19 +1,9 @@ use crate::protocol::Header; -use async_trait::async_trait; use futures_io::AsyncRead; use thiserror::Error; -#[async_trait] -pub trait Unmarshal -where - Self: Sized, -{ - async fn unmarshal(s: &mut impl AsyncRead) -> Result; -} - -#[async_trait] -impl Unmarshal for Header { - async fn unmarshal(s: &mut impl AsyncRead) -> Result { +impl Header { + pub async fn async_unmarshal(s: &mut impl AsyncRead) -> Result { todo!() } } From 6229a08e61f227409901abcc2a51e27b8c0952cd Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 13:10:30 +0900 Subject: [PATCH 063/103] implement async (un)marshal for protocol --- tuic-quinn/src/lib.rs | 8 +- tuic/Cargo.toml | 5 +- tuic/src/lib.rs | 2 +- tuic/src/marshal.rs | 86 ++++++++++++++++-- tuic/src/model/authenticate.rs | 4 +- tuic/src/model/dissociate.rs | 4 +- tuic/src/model/packet.rs | 12 +-- tuic/src/protocol/address.rs | 95 -------------------- tuic/src/protocol/authenticate.rs | 14 ++- tuic/src/protocol/connect.rs | 10 +-- tuic/src/protocol/dissociate.rs | 14 ++- tuic/src/protocol/heartbeat.rs | 10 +-- tuic/src/protocol/mod.rs | 114 +++++++++++++++++++++--- tuic/src/protocol/packet.rs | 40 ++++----- tuic/src/unmarshal.rs | 140 ++++++++++++++++++++++++++++-- 15 files changed, 372 insertions(+), 186 deletions(-) delete mode 100644 tuic/src/protocol/address.rs diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 3c8b2ea..d84a33c 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -87,7 +87,7 @@ impl<'conn, Side> Connection<'conn, Side> { model: PacketModel, mut recv: &mut RecvStream, ) -> Result, Error> { - let mut buf = vec![0; *model.size() as usize]; + let mut buf = vec![0; model.size() as usize]; AsyncReadExt::read_exact(&mut recv, &mut buf).await?; let mut asm = Vec::new(); @@ -183,7 +183,7 @@ impl<'conn> Connection<'conn, side::Client> { Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; - let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), @@ -206,7 +206,7 @@ impl<'conn> Connection<'conn, side::Server> { match Header::async_unmarshal(&mut recv).await? { Header::Authenticate(auth) => { let model = self.model.recv_authenticate(auth); - Ok(Task::Authenticate(*model.token())) + Ok(Task::Authenticate(model.token())) } Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { @@ -251,7 +251,7 @@ impl<'conn> Connection<'conn, side::Server> { Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; - let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index f19b12a..2774890 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,11 +4,12 @@ version = "0.1.0" edition = "2021" [features] -async_marshal = ["futures-io"] +async_marshal = ["bytes", "futures-util"] model = ["parking_lot", "thiserror"] [dependencies] -futures-io = { version = "0.3.25", default-features = false, features = ["std"], optional = true } +bytes = { version = "1.3.0", default-features = false, features = ["std"], optional = true } +futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"], optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 05c95ad..cfa0ac6 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -3,7 +3,7 @@ mod protocol; pub use self::protocol::{ - Address, Authenticate, Command, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, + Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, }; #[cfg(feature = "async_marshal")] diff --git a/tuic/src/marshal.rs b/tuic/src/marshal.rs index 2cbeba4..2c756bf 100644 --- a/tuic/src/marshal.rs +++ b/tuic/src/marshal.rs @@ -1,9 +1,85 @@ -use crate::protocol::Header; -use futures_io::AsyncWrite; -use std::io::Error as IoError; +use crate::protocol::{ + Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, +}; +use bytes::BufMut; +use futures_util::{AsyncWrite, AsyncWriteExt}; +use std::{io::Error as IoError, net::SocketAddr}; impl Header { - pub async fn async_marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError> { - todo!() + pub async fn async_marshal(&self, s: &mut (impl AsyncWrite + Unpin)) -> Result<(), IoError> { + let mut buf = vec![0; self.len()]; + self.write(&mut buf); + s.write_all(&buf).await + } + + pub fn write(&self, buf: &mut impl BufMut) { + buf.put_u8(VERSION); + buf.put_u8(self.type_code()); + + match self { + Self::Authenticate(auth) => auth.write(buf), + Self::Connect(conn) => conn.write(buf), + Self::Packet(packet) => packet.write(buf), + Self::Dissociate(dissociate) => dissociate.write(buf), + Self::Heartbeat(heartbeat) => heartbeat.write(buf), + } } } + +impl Address { + fn write(&self, buf: &mut impl BufMut) { + buf.put_u8(self.type_code()); + + match self { + Self::None => {} + Self::DomainAddress(domain, port) => { + buf.put_u8(domain.len() as u8); + buf.put_slice(domain.as_bytes()); + buf.put_u16(*port); + } + Self::SocketAddress(SocketAddr::V4(addr)) => { + buf.put_slice(&addr.ip().octets()); + buf.put_u16(addr.port()); + } + Self::SocketAddress(SocketAddr::V6(addr)) => { + for seg in addr.ip().segments() { + buf.put_u16(seg); + } + buf.put_u16(addr.port()); + } + } + } +} + +impl Authenticate { + fn write(&self, buf: &mut impl BufMut) { + buf.put_slice(&self.token()); + } +} + +impl Connect { + fn write(&self, buf: &mut impl BufMut) { + self.addr().write(buf); + } +} + +impl Packet { + fn write(&self, buf: &mut impl BufMut) { + buf.put_u16(self.assoc_id()); + buf.put_u16(self.pkt_id()); + buf.put_u8(self.frag_total()); + buf.put_u8(self.frag_id()); + buf.put_u16(self.size()); + self.addr().write(buf); + } +} + +impl Dissociate { + fn write(&self, buf: &mut impl BufMut) { + buf.put_u16(self.assoc_id()); + } +} + +impl Heartbeat { + fn write(&self, _buf: &mut impl BufMut) {} +} diff --git a/tuic/src/model/authenticate.rs b/tuic/src/model/authenticate.rs index ca54f39..e02940d 100644 --- a/tuic/src/model/authenticate.rs +++ b/tuic/src/model/authenticate.rs @@ -38,8 +38,8 @@ impl Authenticate { } } - pub fn token(&self) -> &[u8; 8] { + pub fn token(&self) -> [u8; 8] { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.token + rx.token } } diff --git a/tuic/src/model/dissociate.rs b/tuic/src/model/dissociate.rs index e6e4072..79321b0 100644 --- a/tuic/src/model/dissociate.rs +++ b/tuic/src/model/dissociate.rs @@ -38,8 +38,8 @@ impl Dissociate { } } - pub fn assoc_id(&self) -> &u16 { + pub fn assoc_id(&self) -> u16 { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.assoc_id + rx.assoc_id } } diff --git a/tuic/src/model/packet.rs b/tuic/src/model/packet.rs index ea336e3..6bc5fe1 100644 --- a/tuic/src/model/packet.rs +++ b/tuic/src/model/packet.rs @@ -39,9 +39,9 @@ impl Packet { Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) } - pub fn assoc_id(&self) -> &u16 { + pub fn assoc_id(&self) -> u16 { let Side::Tx(tx) = &self.inner else { unreachable!() }; - &tx.assoc_id + tx.assoc_id } pub fn addr(&self) -> &Address { @@ -102,9 +102,9 @@ where ) } - pub fn assoc_id(&self) -> &u16 { + pub fn assoc_id(&self) -> u16 { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.assoc_id + rx.assoc_id } pub fn addr(&self) -> &Address { @@ -112,9 +112,9 @@ where &rx.addr } - pub fn size(&self) -> &u16 { + pub fn size(&self) -> u16 { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.size + rx.size } } diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs deleted file mode 100644 index 78eaefc..0000000 --- a/tuic/src/protocol/address.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - mem, - net::SocketAddr, -}; - -/// Address -/// -/// ```plain -/// +------+----------+ -/// | TYPE | ADDR | -/// +------+----------+ -/// | 1 | Variable | -/// +------+----------+ -/// ``` -/// -/// The address type can be one of the following: -/// -/// - 0xff: None -/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) -/// - 0x01: IPv4 address -/// - 0x02: IPv6 address -/// -/// The port number is encoded in 2 bytes after the Domain name / IP address. -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub enum Address { - None, - DomainAddress(String, u16), - SocketAddress(SocketAddr), -} - -impl Address { - pub const TYPE_CODE_NONE: u8 = 0xff; - pub const TYPE_CODE_DOMAIN: u8 = 0x00; - pub const TYPE_CODE_IPV4: u8 = 0x01; - pub const TYPE_CODE_IPV6: u8 = 0x02; - - pub fn type_code(&self) -> u8 { - match self { - Self::None => Self::TYPE_CODE_NONE, - Self::DomainAddress(_, _) => Self::TYPE_CODE_DOMAIN, - Self::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => Self::TYPE_CODE_IPV4, - SocketAddr::V6(_) => Self::TYPE_CODE_IPV6, - }, - } - } - - pub fn len(&self) -> usize { - 1 + match self { - Address::None => 0, - Address::DomainAddress(addr, _) => 1 + addr.len() + 2, - Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 1 * 4 + 2, - SocketAddr::V6(_) => 2 * 8 + 2, - }, - } - } - - pub fn take(&mut self) -> Self { - mem::take(self) - } - - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } - - pub fn is_domain(&self) -> bool { - matches!(self, Self::DomainAddress(_, _)) - } - - pub fn is_ipv4(&self) -> bool { - matches!(self, Self::SocketAddress(SocketAddr::V4(_))) - } - - pub fn is_ipv6(&self) -> bool { - matches!(self, Self::SocketAddress(SocketAddr::V6(_))) - } -} - -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Self::None => write!(f, "none"), - Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), - Self::SocketAddress(addr) => write!(f, "{addr}"), - } - } -} - -impl Default for Address { - fn default() -> Self { - Self::None - } -} diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index eebbafd..b17f791 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -1,5 +1,3 @@ -use super::Command; - // +-------+ // | TOKEN | // +-------+ @@ -11,23 +9,21 @@ pub struct Authenticate { } impl Authenticate { - pub(super) const TYPE_CODE: u8 = 0x00; + const TYPE_CODE: u8 = 0x00; pub const fn new(token: [u8; 8]) -> Self { Self { token } } - pub fn token(&self) -> &[u8; 8] { - &self.token + pub fn token(&self) -> [u8; 8] { + self.token } -} -impl Command for Authenticate { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { 8 } } diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index 20a4d96..6139c83 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -1,4 +1,4 @@ -use super::{Address, Command}; +use super::Address; // +----------+ // | ADDR | @@ -11,7 +11,7 @@ pub struct Connect { } impl Connect { - pub(super) const TYPE_CODE: u8 = 0x01; + const TYPE_CODE: u8 = 0x01; pub const fn new(addr: Address) -> Self { Self { addr } @@ -20,14 +20,12 @@ impl Connect { pub fn addr(&self) -> &Address { &self.addr } -} -impl Command for Connect { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { self.addr.len() } } diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index 94734f5..86caa19 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -1,5 +1,3 @@ -use super::Command; - // +----------+ // | ASSOC_ID | // +----------+ @@ -11,23 +9,21 @@ pub struct Dissociate { } impl Dissociate { - pub(super) const TYPE_CODE: u8 = 0x03; + const TYPE_CODE: u8 = 0x03; pub const fn new(assoc_id: u16) -> Self { Self { assoc_id } } - pub fn assoc_id(&self) -> &u16 { - &self.assoc_id + pub fn assoc_id(&self) -> u16 { + self.assoc_id } -} -impl Command for Dissociate { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { 2 } } diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index 91087e8..dd8143a 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -1,5 +1,3 @@ -use super::Command; - // +-+ // | | // +-+ @@ -9,19 +7,17 @@ use super::Command; pub struct Heartbeat; impl Heartbeat { - pub(super) const TYPE_CODE: u8 = 0x04; + const TYPE_CODE: u8 = 0x04; pub const fn new() -> Self { Self } -} -impl Command for Heartbeat { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { 0 } } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index c890893..3ade4cb 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -1,4 +1,9 @@ -mod address; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + mem, + net::SocketAddr, +}; + mod authenticate; mod connect; mod dissociate; @@ -6,8 +11,8 @@ mod heartbeat; mod packet; pub use self::{ - address::Address, authenticate::Authenticate, connect::Connect, dissociate::Dissociate, - heartbeat::Heartbeat, packet::Packet, + authenticate::Authenticate, connect::Connect, dissociate::Dissociate, heartbeat::Heartbeat, + packet::Packet, }; pub const VERSION: u8 = 0x05; @@ -32,13 +37,13 @@ pub enum Header { } impl Header { - pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::TYPE_CODE; - pub const TYPE_CODE_CONNECT: u8 = Connect::TYPE_CODE; - pub const TYPE_CODE_PACKET: u8 = Packet::TYPE_CODE; - pub const TYPE_CODE_DISSOCIATE: u8 = Dissociate::TYPE_CODE; - pub const TYPE_CODE_HEARTBEAT: u8 = Heartbeat::TYPE_CODE; + pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::type_code(); + pub const TYPE_CODE_CONNECT: u8 = Connect::type_code(); + pub const TYPE_CODE_PACKET: u8 = Packet::type_code(); + pub const TYPE_CODE_DISSOCIATE: u8 = Dissociate::type_code(); + pub const TYPE_CODE_HEARTBEAT: u8 = Heartbeat::type_code(); - pub fn type_code(&self) -> u8 { + pub const fn type_code(&self) -> u8 { match self { Self::Authenticate(_) => Authenticate::type_code(), Self::Connect(_) => Connect::type_code(), @@ -59,7 +64,92 @@ impl Header { } } -pub trait Command { - fn type_code() -> u8; - fn len(&self) -> usize; +/// Address +/// +/// ```plain +/// +------+----------+ +/// | TYPE | ADDR | +/// +------+----------+ +/// | 1 | Variable | +/// +------+----------+ +/// ``` +/// +/// The address type can be one of the following: +/// +/// - 0xff: None +/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) +/// - 0x01: IPv4 address +/// - 0x02: IPv6 address +/// +/// The port number is encoded in 2 bytes after the Domain name / IP address. +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Address { + None, + DomainAddress(String, u16), + SocketAddress(SocketAddr), +} + +impl Address { + pub const TYPE_CODE_NONE: u8 = 0xff; + pub const TYPE_CODE_DOMAIN: u8 = 0x00; + pub const TYPE_CODE_IPV4: u8 = 0x01; + pub const TYPE_CODE_IPV6: u8 = 0x02; + + pub const fn type_code(&self) -> u8 { + match self { + Self::None => Self::TYPE_CODE_NONE, + Self::DomainAddress(_, _) => Self::TYPE_CODE_DOMAIN, + Self::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => Self::TYPE_CODE_IPV4, + SocketAddr::V6(_) => Self::TYPE_CODE_IPV6, + }, + } + } + + pub fn len(&self) -> usize { + 1 + match self { + Address::None => 0, + Address::DomainAddress(addr, _) => 1 + addr.len() + 2, + Address::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => 1 * 4 + 2, + SocketAddr::V6(_) => 2 * 8 + 2, + }, + } + } + + pub fn take(&mut self) -> Self { + mem::take(self) + } + + pub fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + pub fn is_domain(&self) -> bool { + matches!(self, Self::DomainAddress(_, _)) + } + + pub fn is_ipv4(&self) -> bool { + matches!(self, Self::SocketAddress(SocketAddr::V4(_))) + } + + pub fn is_ipv6(&self) -> bool { + matches!(self, Self::SocketAddress(SocketAddr::V6(_))) + } +} + +impl Display for Address { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::None => write!(f, "none"), + Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), + Self::SocketAddress(addr) => write!(f, "{addr}"), + } + } +} + +impl Default for Address { + fn default() -> Self { + Self::None + } } diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 71c79e5..d9c6a5d 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -1,4 +1,4 @@ -use super::{Address, Command}; +use super::Address; // +----------+--------+------------+---------+------+----------+ // | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | SIZE | ADDR | @@ -16,7 +16,7 @@ pub struct Packet { } impl Packet { - pub(super) const TYPE_CODE: u8 = 0x02; + const TYPE_CODE: u8 = 0x02; pub const fn new( assoc_id: u16, @@ -36,42 +36,40 @@ impl Packet { } } - pub fn assoc_id(&self) -> &u16 { - &self.assoc_id + pub fn assoc_id(&self) -> u16 { + self.assoc_id } - pub fn pkt_id(&self) -> &u16 { - &self.pkt_id + pub fn pkt_id(&self) -> u16 { + self.pkt_id } - pub fn frag_total(&self) -> &u8 { - &self.frag_total + pub fn frag_total(&self) -> u8 { + self.frag_total } - pub fn frag_id(&self) -> &u8 { - &self.frag_id + pub fn frag_id(&self) -> u8 { + self.frag_id } - pub fn size(&self) -> &u16 { - &self.size + pub fn size(&self) -> u16 { + self.size } pub fn addr(&self) -> &Address { &self.addr } - pub const fn len_without_addr() -> usize { - 2 + 2 + 1 + 1 + 2 - } -} - -impl Command for Packet { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { - 2 + 2 + 1 + 1 + 2 + self.addr.len() + pub fn len(&self) -> usize { + Self::len_without_addr() + self.addr.len() + } + + pub const fn len_without_addr() -> usize { + 2 + 2 + 1 + 1 + 2 } } diff --git a/tuic/src/unmarshal.rs b/tuic/src/unmarshal.rs index aa1fd56..de2bb1a 100644 --- a/tuic/src/unmarshal.rs +++ b/tuic/src/unmarshal.rs @@ -1,12 +1,142 @@ -use crate::protocol::Header; -use futures_io::AsyncRead; +use crate::protocol::{ + Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, +}; +use futures_util::{AsyncRead, AsyncReadExt}; +use std::{io::Error as IoError, net::SocketAddr, string::FromUtf8Error}; use thiserror::Error; impl Header { - pub async fn async_unmarshal(s: &mut impl AsyncRead) -> Result { - todo!() + pub async fn async_unmarshal(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let ver = buf[0]; + + if ver != VERSION { + return Err(UnmarshalError::InvalidVersion(ver)); + } + + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let cmd = buf[0]; + + match cmd { + Header::TYPE_CODE_AUTHENTICATE => { + Authenticate::async_read(s).await.map(Self::Authenticate) + } + Header::TYPE_CODE_CONNECT => Connect::async_read(s).await.map(Self::Connect), + Header::TYPE_CODE_PACKET => Packet::async_read(s).await.map(Self::Packet), + Header::TYPE_CODE_DISSOCIATE => Dissociate::async_read(s).await.map(Self::Dissociate), + Header::TYPE_CODE_HEARTBEAT => Heartbeat::async_read(s).await.map(Self::Heartbeat), + _ => Err(UnmarshalError::InvalidCommand(cmd)), + } + } +} + +impl Address { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let type_code = buf[0]; + + match type_code { + Address::TYPE_CODE_NONE => Ok(Self::None), + Address::TYPE_CODE_DOMAIN => { + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let len = buf[0] as usize; + + let mut buf = vec![0; len + 2]; + s.read_exact(&mut buf).await?; + let port = u16::from_be_bytes([buf[len], buf[len + 1]]); + buf.truncate(len); + let domain = String::from_utf8(buf)?; + + Ok(Self::DomainAddress(domain, port)) + } + Address::TYPE_CODE_IPV4 => { + let mut buf = [0; 6]; + s.read_exact(&mut buf).await?; + let ip = [buf[0], buf[1], buf[2], buf[3]]; + let port = u16::from_be_bytes([buf[4], buf[5]]); + Ok(Self::SocketAddress(SocketAddr::from((ip, port)))) + } + Address::TYPE_CODE_IPV6 => { + let mut buf = [0; 18]; + s.read_exact(&mut buf).await?; + let ip = [ + u16::from_be_bytes([buf[0], buf[1]]), + u16::from_be_bytes([buf[2], buf[3]]), + u16::from_be_bytes([buf[4], buf[5]]), + u16::from_be_bytes([buf[6], buf[7]]), + u16::from_be_bytes([buf[8], buf[9]]), + u16::from_be_bytes([buf[10], buf[11]]), + u16::from_be_bytes([buf[12], buf[13]]), + u16::from_be_bytes([buf[14], buf[15]]), + ]; + let port = u16::from_be_bytes([buf[16], buf[17]]); + + Ok(Self::SocketAddress(SocketAddr::from((ip, port)))) + } + _ => Err(UnmarshalError::InvalidAddressType(type_code)), + } + } +} + +impl Authenticate { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 8]; + s.read_exact(&mut buf).await?; + Ok(Self::new(buf)) + } +} + +impl Connect { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + Ok(Self::new(Address::async_read(s).await?)) + } +} + +impl Packet { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 8]; + s.read_exact(&mut buf).await?; + + let assoc_id = u16::from_be_bytes([buf[0], buf[1]]); + let pkt_id = u16::from_be_bytes([buf[2], buf[3]]); + let frag_total = buf[4]; + let frag_id = buf[5]; + let size = u16::from_be_bytes([buf[6], buf[7]]); + let addr = Address::async_read(s).await?; + + Ok(Self::new(assoc_id, pkt_id, frag_total, frag_id, size, addr)) + } +} + +impl Dissociate { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 2]; + s.read_exact(&mut buf).await?; + let assoc_id = u16::from_be_bytes(buf); + Ok(Self::new(assoc_id)) + } +} + +impl Heartbeat { + async fn async_read(_s: &mut (impl AsyncRead + Unpin)) -> Result { + Ok(Self::new()) } } #[derive(Debug, Error)] -pub enum UnmarshalError {} +pub enum UnmarshalError { + #[error(transparent)] + Io(#[from] IoError), + #[error("invalid version: {0}")] + InvalidVersion(u8), + #[error("invalid command: {0}")] + InvalidCommand(u8), + #[error("invalid address type: {0}")] + InvalidAddressType(u8), + #[error("address parsing error: {0}")] + AddressParse(#[from] FromUtf8Error), +} From 8b985b15867b0411a626bb854b0a6bb40b741ac7 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 13:22:45 +0900 Subject: [PATCH 064/103] implement blocking (un)marshal for protocol --- tuic/Cargo.toml | 3 +- tuic/src/lib.rs | 6 +- tuic/src/marshal.rs | 13 ++++- tuic/src/unmarshal.rs | 126 +++++++++++++++++++++++++++++++++++++++++- 4 files changed, 142 insertions(+), 6 deletions(-) diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 2774890..59cc861 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [features] async_marshal = ["bytes", "futures-util"] +marshal = ["bytes"] model = ["parking_lot", "thiserror"] [dependencies] @@ -14,4 +15,4 @@ parking_lot = { version = "0.12.1", default-features = false, optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] -tuic = { path = ".", features = ["async_marshal", "model"] } +tuic = { path = ".", features = ["async_marshal", "marshal", "model"] } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index cfa0ac6..bb3aaf5 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -6,13 +6,13 @@ pub use self::protocol::{ Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, }; -#[cfg(feature = "async_marshal")] +#[cfg(any(feature = "async_marshal", feature = "marshal"))] mod marshal; -#[cfg(feature = "async_marshal")] +#[cfg(any(feature = "async_marshal", feature = "marshal"))] mod unmarshal; -#[cfg(feature = "async_marshal")] +#[cfg(any(feature = "async_marshal", feature = "marshal"))] pub use self::unmarshal::UnmarshalError; #[cfg(feature = "model")] diff --git a/tuic/src/marshal.rs b/tuic/src/marshal.rs index 2c756bf..94a9068 100644 --- a/tuic/src/marshal.rs +++ b/tuic/src/marshal.rs @@ -3,15 +3,26 @@ use crate::protocol::{ }; use bytes::BufMut; use futures_util::{AsyncWrite, AsyncWriteExt}; -use std::{io::Error as IoError, net::SocketAddr}; +use std::{ + io::{Error as IoError, Write}, + net::SocketAddr, +}; impl Header { + #[cfg(feature = "async_marshal")] pub async fn async_marshal(&self, s: &mut (impl AsyncWrite + Unpin)) -> Result<(), IoError> { let mut buf = vec![0; self.len()]; self.write(&mut buf); s.write_all(&buf).await } + #[cfg(feature = "marshal")] + pub fn marshal(&self, s: &mut impl Write) -> Result<(), IoError> { + let mut buf = vec![0; self.len()]; + self.write(&mut buf); + s.write_all(&buf) + } + pub fn write(&self, buf: &mut impl BufMut) { buf.put_u8(VERSION); buf.put_u8(self.type_code()); diff --git a/tuic/src/unmarshal.rs b/tuic/src/unmarshal.rs index de2bb1a..7b4b48a 100644 --- a/tuic/src/unmarshal.rs +++ b/tuic/src/unmarshal.rs @@ -2,10 +2,15 @@ use crate::protocol::{ Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, }; use futures_util::{AsyncRead, AsyncReadExt}; -use std::{io::Error as IoError, net::SocketAddr, string::FromUtf8Error}; +use std::{ + io::{Error as IoError, Read}, + net::SocketAddr, + string::FromUtf8Error, +}; use thiserror::Error; impl Header { + #[cfg(feature = "async_marshal")] pub async fn async_unmarshal(s: &mut (impl AsyncRead + Unpin)) -> Result { let mut buf = [0; 1]; s.read_exact(&mut buf).await?; @@ -30,9 +35,34 @@ impl Header { _ => Err(UnmarshalError::InvalidCommand(cmd)), } } + + #[cfg(feature = "marshal")] + pub fn unmarshal(s: &mut impl Read) -> Result { + let mut buf = [0; 1]; + s.read_exact(&mut buf)?; + let ver = buf[0]; + + if ver != VERSION { + return Err(UnmarshalError::InvalidVersion(ver)); + } + + let mut buf = [0; 1]; + s.read_exact(&mut buf)?; + let cmd = buf[0]; + + match cmd { + Header::TYPE_CODE_AUTHENTICATE => Authenticate::read(s).map(Self::Authenticate), + Header::TYPE_CODE_CONNECT => Connect::read(s).map(Self::Connect), + Header::TYPE_CODE_PACKET => Packet::read(s).map(Self::Packet), + Header::TYPE_CODE_DISSOCIATE => Dissociate::read(s).map(Self::Dissociate), + Header::TYPE_CODE_HEARTBEAT => Heartbeat::read(s).map(Self::Heartbeat), + _ => Err(UnmarshalError::InvalidCommand(cmd)), + } + } } impl Address { + #[cfg(feature = "async_marshal")] async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { let mut buf = [0; 1]; s.read_exact(&mut buf).await?; @@ -80,23 +110,87 @@ impl Address { _ => Err(UnmarshalError::InvalidAddressType(type_code)), } } + + #[cfg(feature = "marshal")] + fn read(s: &mut impl Read) -> Result { + let mut buf = [0; 1]; + s.read_exact(&mut buf)?; + let type_code = buf[0]; + + match type_code { + Address::TYPE_CODE_NONE => Ok(Self::None), + Address::TYPE_CODE_DOMAIN => { + let mut buf = [0; 1]; + s.read_exact(&mut buf)?; + let len = buf[0] as usize; + + let mut buf = vec![0; len + 2]; + s.read_exact(&mut buf)?; + let port = u16::from_be_bytes([buf[len], buf[len + 1]]); + buf.truncate(len); + let domain = String::from_utf8(buf)?; + + Ok(Self::DomainAddress(domain, port)) + } + Address::TYPE_CODE_IPV4 => { + let mut buf = [0; 6]; + s.read_exact(&mut buf)?; + let ip = [buf[0], buf[1], buf[2], buf[3]]; + let port = u16::from_be_bytes([buf[4], buf[5]]); + Ok(Self::SocketAddress(SocketAddr::from((ip, port)))) + } + Address::TYPE_CODE_IPV6 => { + let mut buf = [0; 18]; + s.read_exact(&mut buf)?; + let ip = [ + u16::from_be_bytes([buf[0], buf[1]]), + u16::from_be_bytes([buf[2], buf[3]]), + u16::from_be_bytes([buf[4], buf[5]]), + u16::from_be_bytes([buf[6], buf[7]]), + u16::from_be_bytes([buf[8], buf[9]]), + u16::from_be_bytes([buf[10], buf[11]]), + u16::from_be_bytes([buf[12], buf[13]]), + u16::from_be_bytes([buf[14], buf[15]]), + ]; + let port = u16::from_be_bytes([buf[16], buf[17]]); + + Ok(Self::SocketAddress(SocketAddr::from((ip, port)))) + } + _ => Err(UnmarshalError::InvalidAddressType(type_code)), + } + } } impl Authenticate { + #[cfg(feature = "async_marshal")] async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { let mut buf = [0; 8]; s.read_exact(&mut buf).await?; Ok(Self::new(buf)) } + + #[cfg(feature = "marshal")] + fn read(s: &mut impl Read) -> Result { + let mut buf = [0; 8]; + s.read_exact(&mut buf)?; + Ok(Self::new(buf)) + } } impl Connect { + #[cfg(feature = "async_marshal")] async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { Ok(Self::new(Address::async_read(s).await?)) } + + #[cfg(feature = "marshal")] + fn read(s: &mut impl Read) -> Result { + Ok(Self::new(Address::read(s)?)) + } } impl Packet { + #[cfg(feature = "async_marshal")] async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { let mut buf = [0; 8]; s.read_exact(&mut buf).await?; @@ -110,21 +204,51 @@ impl Packet { Ok(Self::new(assoc_id, pkt_id, frag_total, frag_id, size, addr)) } + + #[cfg(feature = "marshal")] + fn read(s: &mut impl Read) -> Result { + let mut buf = [0; 8]; + s.read_exact(&mut buf)?; + + let assoc_id = u16::from_be_bytes([buf[0], buf[1]]); + let pkt_id = u16::from_be_bytes([buf[2], buf[3]]); + let frag_total = buf[4]; + let frag_id = buf[5]; + let size = u16::from_be_bytes([buf[6], buf[7]]); + let addr = Address::read(s)?; + + Ok(Self::new(assoc_id, pkt_id, frag_total, frag_id, size, addr)) + } } impl Dissociate { + #[cfg(feature = "async_marshal")] async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { let mut buf = [0; 2]; s.read_exact(&mut buf).await?; let assoc_id = u16::from_be_bytes(buf); Ok(Self::new(assoc_id)) } + + #[cfg(feature = "marshal")] + fn read(s: &mut impl Read) -> Result { + let mut buf = [0; 2]; + s.read_exact(&mut buf)?; + let assoc_id = u16::from_be_bytes(buf); + Ok(Self::new(assoc_id)) + } } impl Heartbeat { + #[cfg(feature = "async_marshal")] async fn async_read(_s: &mut (impl AsyncRead + Unpin)) -> Result { Ok(Self::new()) } + + #[cfg(feature = "marshal")] + fn read(_s: &mut impl Read) -> Result { + Ok(Self::new()) + } } #[derive(Debug, Error)] From 8d4c3dc8da5f5896362cb63120e640a18a8ec863 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 13:41:57 +0900 Subject: [PATCH 065/103] removing several unnecessary async marks --- tuic-quinn/Cargo.toml | 2 +- tuic-quinn/src/lib.rs | 41 +++++++++++++++++++++++++---------------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml index f51f7ce..fa901c1 100644 --- a/tuic-quinn/Cargo.toml +++ b/tuic-quinn/Cargo.toml @@ -8,4 +8,4 @@ bytes = { version = "1.3.0", default-features = false, features = ["std"] } futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io"] } thiserror = { version = "1.0.38", default-features = false } -tuic = { path = "../tuic", default-features = false, features = ["async_marshal", "model"] } +tuic = { path = "../tuic", default-features = false, features = ["async_marshal", "marshal", "model"] } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index d84a33c..9d41dc2 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -1,11 +1,11 @@ use self::side::Side; -use bytes::Bytes; -use futures_util::{io::Cursor, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use bytes::{BufMut, Bytes}; +use futures_util::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use quinn::{ Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream, }; use std::{ - io::Error as IoError, + io::{Cursor, Error as IoError}, pin::Pin, task::{Context, Poll}, time::Duration, @@ -37,7 +37,7 @@ pub struct Connection<'conn, Side> { } impl<'conn, Side> Connection<'conn, Side> { - pub async fn packet_native( + pub fn packet_native( &self, pkt: impl AsRef<[u8]>, addr: Address, @@ -50,10 +50,10 @@ impl<'conn, Side> Connection<'conn, Side> { let model = self.model.send_packet(assoc_id, addr, max_pkt_size); for (header, frag) in model.into_fragments(pkt) { - let mut buf = Cursor::new(vec![0; header.len() + frag.len()]); - header.async_marshal(&mut buf).await?; - buf.write_all(frag).await.unwrap(); - self.conn.send_datagram(Bytes::from(buf.into_inner()))?; + let mut buf = vec![0; header.len() + frag.len()]; + header.write(&mut buf); + buf.put_slice(frag); + self.conn.send_datagram(Bytes::from(buf))?; } Ok(()) @@ -97,7 +97,7 @@ impl<'conn, Side> Connection<'conn, Side> { .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id))) } - async fn accept_packet_native( + fn accept_packet_native( &self, model: PacketModel, data: Bytes, @@ -174,17 +174,17 @@ impl<'conn> Connection<'conn, side::Client> { } } - pub async fn accept_datagram(&self, dg: Bytes) -> Result { + pub fn accept_datagram(&self, dg: Bytes) -> Result { let mut dg = Cursor::new(dg); - match Header::async_unmarshal(&mut dg).await? { + match Header::unmarshal(&mut dg)? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; let buf = dg.into_inner().slice(pos..pos + model.size() as usize); - Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) + Ok(Task::Packet(self.accept_packet_native(model, buf)?)) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), @@ -242,17 +242,17 @@ impl<'conn> Connection<'conn, side::Server> { } } - pub async fn accept_datagram(&self, dg: Bytes) -> Result { + pub fn accept_datagram(&self, dg: Bytes) -> Result { let mut dg = Cursor::new(dg); - match Header::async_unmarshal(&mut dg).await? { + match Header::unmarshal(&mut dg)? { Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; let buf = dg.into_inner().slice(pos..pos + model.size() as usize); - Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) + Ok(Task::Packet(self.accept_packet_native(model, buf)?)) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), Header::Heartbeat(hb) => { @@ -336,9 +336,18 @@ pub enum Error { #[error(transparent)] SendDatagram(#[from] SendDatagramError), #[error(transparent)] - Unmarshal(#[from] UnmarshalError), + Unmarshal(UnmarshalError), #[error(transparent)] Assemble(#[from] AssembleError), #[error("{0}")] BadCommand(&'static str), } + +impl From for Error { + fn from(err: UnmarshalError) -> Self { + match err { + UnmarshalError::Io(err) => Self::Io(err), + err => Self::Unmarshal(err), + } + } +} From 227a5ddf10da0f5a9d478fde474e4d8d29c4d988 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 13:44:05 +0900 Subject: [PATCH 066/103] new line at end of files --- .gitignore | 2 +- Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 9029b80..a80cdf2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ .vscode/ target/ .DS_Store -Cargo.lock \ No newline at end of file +Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index 05881e9..cede657 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,2 +1,2 @@ [workspace] -members = ["tuic", "tuic-quinn", "tuic-server", "tuic-client"] \ No newline at end of file +members = ["tuic", "tuic-quinn", "tuic-server", "tuic-client"] From 6561afe3a9fd90fb6e3bc5101c6c6ff0055a9170 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 14:20:59 +0900 Subject: [PATCH 067/103] better error handling --- tuic-quinn/src/lib.rs | 117 +++++++++++++++++++++++++++--------------- 1 file changed, 75 insertions(+), 42 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 9d41dc2..d4b6875 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -144,32 +144,42 @@ impl<'conn> Connection<'conn, side::Client> { } pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { - match Header::async_unmarshal(&mut recv).await? { - Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), - Header::Connect(_) => Err(Error::BadCommand("connect")), + let header = match Header::async_unmarshal(&mut recv).await { + Ok(header) => header, + Err(err) => return Err(Error::UnmarshalUniStream(err, recv)), + }; + + match header { + Header::Authenticate(_) => Err(Error::BadCommandUniStream("authenticate", recv)), + Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); Ok(Task::Packet( self.accept_packet_quic(model, &mut recv).await?, )) } - Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), - Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + Header::Dissociate(_) => Err(Error::BadCommandUniStream("dissociate", recv)), + Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)), _ => unreachable!(), } } pub async fn accept_bi_stream( &self, - _send: SendStream, + send: SendStream, mut recv: RecvStream, ) -> Result { - match Header::async_unmarshal(&mut recv).await? { - Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), - Header::Connect(_) => Err(Error::BadCommand("connect")), - Header::Packet(_) => Err(Error::BadCommand("packet")), - Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), - Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + let header = match Header::async_unmarshal(&mut recv).await { + Ok(header) => header, + Err(err) => return Err(Error::UnmarshalBiStream(err, send, recv)), + }; + + match header { + Header::Authenticate(_) => Err(Error::BadCommandBiStream("authenticate", send, recv)), + Header::Connect(_) => Err(Error::BadCommandBiStream("connect", send, recv)), + Header::Packet(_) => Err(Error::BadCommandBiStream("packet", send, recv)), + Header::Dissociate(_) => Err(Error::BadCommandBiStream("dissociate", send, recv)), + Header::Heartbeat(_) => Err(Error::BadCommandBiStream("heartbeat", send, recv)), _ => unreachable!(), } } @@ -177,17 +187,24 @@ impl<'conn> Connection<'conn, side::Client> { pub fn accept_datagram(&self, dg: Bytes) -> Result { let mut dg = Cursor::new(dg); - match Header::unmarshal(&mut dg)? { - Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), - Header::Connect(_) => Err(Error::BadCommand("connect")), + let header = match Header::unmarshal(&mut dg) { + Ok(header) => header, + Err(err) => return Err(Error::UnmarshalDatagram(err, dg.into_inner())), + }; + + match header { + Header::Authenticate(_) => { + Err(Error::BadCommandDatagram("authenticate", dg.into_inner())) + } + Header::Connect(_) => Err(Error::BadCommandDatagram("connect", dg.into_inner())), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(self.accept_packet_native(model, buf)?)) } - Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), - Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())), + Header::Heartbeat(_) => Err(Error::BadCommandDatagram("heartbeat", dg.into_inner())), _ => unreachable!(), } } @@ -203,12 +220,17 @@ impl<'conn> Connection<'conn, side::Server> { } pub async fn accept_uni_stream(&self, mut recv: RecvStream) -> Result { - match Header::async_unmarshal(&mut recv).await? { + let header = match Header::async_unmarshal(&mut recv).await { + Ok(header) => header, + Err(err) => return Err(Error::UnmarshalUniStream(err, recv)), + }; + + match header { Header::Authenticate(auth) => { let model = self.model.recv_authenticate(auth); Ok(Task::Authenticate(model.token())) } - Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); Ok(Task::Packet( @@ -219,7 +241,7 @@ impl<'conn> Connection<'conn, side::Server> { let _ = self.model.recv_dissociate(dissoc); Ok(Task::Dissociate) } - Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)), _ => unreachable!(), } } @@ -229,15 +251,20 @@ impl<'conn> Connection<'conn, side::Server> { send: SendStream, mut recv: RecvStream, ) -> Result { - match Header::async_unmarshal(&mut recv).await? { - Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + let header = match Header::async_unmarshal(&mut recv).await { + Ok(header) => header, + Err(err) => return Err(Error::UnmarshalBiStream(err, send, recv)), + }; + + match header { + Header::Authenticate(_) => Err(Error::BadCommandBiStream("authenticate", send, recv)), Header::Connect(conn) => { let model = self.model.recv_connect(conn); Ok(Task::Connect(Connect::new(Side::Server(model), send, recv))) } - Header::Packet(_) => Err(Error::BadCommand("packet")), - Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), - Header::Heartbeat(_) => Err(Error::BadCommand("heartbeat")), + Header::Packet(_) => Err(Error::BadCommandBiStream("packet", send, recv)), + Header::Dissociate(_) => Err(Error::BadCommandBiStream("dissociate", send, recv)), + Header::Heartbeat(_) => Err(Error::BadCommandBiStream("heartbeat", send, recv)), _ => unreachable!(), } } @@ -245,16 +272,23 @@ impl<'conn> Connection<'conn, side::Server> { pub fn accept_datagram(&self, dg: Bytes) -> Result { let mut dg = Cursor::new(dg); - match Header::unmarshal(&mut dg)? { - Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), - Header::Connect(_) => Err(Error::BadCommand("connect")), + let header = match Header::unmarshal(&mut dg) { + Ok(header) => header, + Err(err) => return Err(Error::UnmarshalDatagram(err, dg.into_inner())), + }; + + match header { + Header::Authenticate(_) => { + Err(Error::BadCommandDatagram("authenticate", dg.into_inner())) + } + Header::Connect(_) => Err(Error::BadCommandDatagram("connect", dg.into_inner())), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(self.accept_packet_native(model, buf)?)) } - Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())), Header::Heartbeat(hb) => { let _ = self.model.recv_heartbeat(hb); Ok(Task::Heartbeat) @@ -336,18 +370,17 @@ pub enum Error { #[error(transparent)] SendDatagram(#[from] SendDatagramError), #[error(transparent)] - Unmarshal(UnmarshalError), - #[error(transparent)] Assemble(#[from] AssembleError), - #[error("{0}")] - BadCommand(&'static str), -} - -impl From for Error { - fn from(err: UnmarshalError) -> Self { - match err { - UnmarshalError::Io(err) => Self::Io(err), - err => Self::Unmarshal(err), - } - } + #[error("error unmarshaling uni_stream: {0}")] + UnmarshalUniStream(UnmarshalError, RecvStream), + #[error("error unmarshaling bi_stream: {0}")] + UnmarshalBiStream(UnmarshalError, SendStream, RecvStream), + #[error("error unmarshaling datagram: {0}")] + UnmarshalDatagram(UnmarshalError, Bytes), + #[error("bad command `{0}` from uni_stream")] + BadCommandUniStream(&'static str, RecvStream), + #[error("bad command `{0}` from bi_stream")] + BadCommandBiStream(&'static str, SendStream, RecvStream), + #[error("bad command `{0}` from datagram")] + BadCommandDatagram(&'static str, Bytes), } From d7689c775682afb3b4e4dde7afd68e1b0e56ba40 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 14:26:30 +0900 Subject: [PATCH 068/103] expose `assoc_id` in `Task::Dissociate` --- tuic-quinn/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index d4b6875..4619a21 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -238,8 +238,8 @@ impl<'conn> Connection<'conn, side::Server> { )) } Header::Dissociate(dissoc) => { - let _ = self.model.recv_dissociate(dissoc); - Ok(Task::Dissociate) + let model = self.model.recv_dissociate(dissoc); + Ok(Task::Dissociate(model.assoc_id())) } Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)), _ => unreachable!(), @@ -357,7 +357,7 @@ pub enum Task { Authenticate([u8; 8]), Connect(Connect), Packet(Option<(Bytes, Address, u16)>), - Dissociate, + Dissociate(u16), Heartbeat, } From de63151e5d39203e567ea78b8c514f8800659a9c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 14:37:47 +0900 Subject: [PATCH 069/103] allowing packet from uni_stream to be fragmented --- tuic-quinn/src/lib.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 4619a21..dfeb8c9 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -66,14 +66,13 @@ impl<'conn, Side> Connection<'conn, Side> { assoc_id: u16, ) -> Result<(), Error> { let model = self.model.send_packet(assoc_id, addr, u16::MAX as usize); - let mut frags = model.into_fragments(pkt); - let (header, frag) = frags.next().unwrap(); - assert!(frags.next().is_none()); - let mut send = self.conn.open_uni().await?; - header.async_marshal(&mut send).await?; - AsyncWriteExt::write_all(&mut send, frag).await?; - send.close().await?; + for (header, frag) in model.into_fragments(pkt) { + let mut send = self.conn.open_uni().await?; + header.async_marshal(&mut send).await?; + AsyncWriteExt::write_all(&mut send, frag).await?; + send.close().await?; + } Ok(()) } From 937a21a0cd41c0e26bfaa513f024ee57c57925ed Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 28 Jan 2023 20:15:24 +0900 Subject: [PATCH 070/103] removing lifetime limitation of `tuic-quinn` --- tuic-quinn/src/lib.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index dfeb8c9..636b15a 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -30,13 +30,13 @@ pub mod side { } } -pub struct Connection<'conn, Side> { - conn: &'conn QuinnConnection, +pub struct Connection { + conn: QuinnConnection, model: ConnectionModel, _marker: Side, } -impl<'conn, Side> Connection<'conn, Side> { +impl Connection { pub fn packet_native( &self, pkt: impl AsRef<[u8]>, @@ -77,6 +77,14 @@ impl<'conn, Side> Connection<'conn, Side> { Ok(()) } + pub fn task_connect_count(&self) -> usize { + self.model.task_connect_count() + } + + pub fn task_associate_count(&self) -> usize { + self.model.task_associate_count() + } + pub fn collect_garbage(&self, timeout: Duration) { self.model.collect_garbage(timeout); } @@ -110,8 +118,8 @@ impl<'conn, Side> Connection<'conn, Side> { } } -impl<'conn> Connection<'conn, side::Client> { - pub fn new(conn: &'conn QuinnConnection) -> Self { +impl Connection { + pub fn new(conn: QuinnConnection) -> Self { Self { conn, model: ConnectionModel::new(), @@ -209,8 +217,8 @@ impl<'conn> Connection<'conn, side::Client> { } } -impl<'conn> Connection<'conn, side::Server> { - pub fn new(conn: &'conn QuinnConnection) -> Self { +impl Connection { + pub fn new(conn: QuinnConnection) -> Self { Self { conn, model: ConnectionModel::new(), From e4eca91cea27e1efa75ac0bda41c271782f9a160 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 29 Jan 2023 02:05:27 +0900 Subject: [PATCH 071/103] mark connection model as clone-able --- tuic-quinn/src/lib.rs | 3 +++ tuic/src/model/mod.rs | 1 + 2 files changed, 4 insertions(+) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 636b15a..986eeb0 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -21,7 +21,9 @@ use tuic::{ }; pub mod side { + #[derive(Clone)] pub struct Client; + #[derive(Clone)] pub struct Server; pub(super) enum Side { @@ -30,6 +32,7 @@ pub mod side { } } +#[derive(Clone)] pub struct Connection { conn: QuinnConnection, model: ConnectionModel, diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index 98f59b4..a4486dc 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -28,6 +28,7 @@ pub use self::{ packet::{Fragments, Packet}, }; +#[derive(Clone)] pub struct Connection { udp_sessions: Arc>>, task_connect_count: TaskCount, From 5db0a69f3aa738cfd9117200e9489386b6a6b202 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 29 Jan 2023 19:48:23 +0900 Subject: [PATCH 072/103] implement client endpoint & connection --- tuic-client/Cargo.toml | 13 +- tuic-client/src/config.rs | 61 +++++++++ tuic-client/src/connection.rs | 224 ++++++++++++++++++++++++++++++++++ tuic-client/src/error.rs | 18 +++ tuic-client/src/main.rs | 26 +++- tuic-client/src/socks5.rs | 7 ++ 6 files changed, 345 insertions(+), 4 deletions(-) create mode 100644 tuic-client/src/config.rs create mode 100644 tuic-client/src/connection.rs create mode 100644 tuic-client/src/error.rs create mode 100644 tuic-client/src/socks5.rs diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index 035a0de..cce7a68 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -3,6 +3,15 @@ name = "tuic-client" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +bytes = { version = "1.3.0", default-features = false, features = ["std"] } +lexopt = { version = "0.3.0", default-features = false } +once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } +parking_lot = { version = "0.12.1", default-features = false } +quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } +serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } +serde_json = { version = "1.0.91", default-features = false, features = ["std"] } +thiserror = { version = "1.0.38", default-features = false } +tokio = { version = "1.24.2", default-features = false, features = ["macros", "parking_lot", "rt-multi-thread", "time"] } +tuic = { path = "../tuic", default-features = false } +tuic-quinn = { path = "../tuic-quinn", default-features = false } diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs new file mode 100644 index 0000000..f590ee6 --- /dev/null +++ b/tuic-client/src/config.rs @@ -0,0 +1,61 @@ +use lexopt::{Arg, Error as ArgumentError, Parser}; +use serde::Deserialize; +use serde_json::Error as SerdeError; +use std::{ffi::OsString, fs::File, io::Error as IoError}; +use thiserror::Error; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Config {} + +impl Config { + pub fn parse(args: A) -> Result + where + A: IntoIterator, + A::Item: Into, + { + let mut parser = Parser::from_iter(args); + let mut path = None; + + while let Some(arg) = parser.next()? { + match arg { + Arg::Short('c') | Arg::Long("config") => { + if path.is_none() { + path = Some(parser.value()?); + } else { + return Err(ConfigError::Argument(arg.unexpected())); + } + } + Arg::Short('v') | Arg::Long("version") => { + return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))) + } + Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(todo!())), + _ => return Err(ConfigError::Argument(arg.unexpected())), + } + } + + if path.is_none() { + return Err(ConfigError::NoConfig); + } + + let file = File::open(path.unwrap())?; + + Ok(serde_json::from_reader(file)?) + } +} + +#[derive(Debug, Error)] +pub enum ConfigError { + #[error("transparent")] + Argument(#[from] ArgumentError), + #[error("no config file specified")] + NoConfig, + #[error("{0}")] + Version(&'static str), + #[error("{0}")] + Help(&'static str), + #[error("transparent")] + Io(#[from] IoError), + #[error("transparent")] + Serde(#[from] SerdeError), +} diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs new file mode 100644 index 0000000..f7f8d7d --- /dev/null +++ b/tuic-client/src/connection.rs @@ -0,0 +1,224 @@ +use crate::{error::Error, socks5}; +use bytes::Bytes; +use once_cell::sync::OnceCell; +use parking_lot::Mutex; +use quinn::{ + Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt, +}; +use std::{ + net::SocketAddr, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Weak, + }, + time::Duration, +}; +use tokio::{ + sync::{Mutex as AsyncMutex, OnceCell as AsyncOnceCell}, + time, +}; +use tuic_quinn::{side, Connection as Model, Task}; + +static ENDPOINT: OnceCell> = OnceCell::new(); +static CONNECTION: AsyncOnceCell> = AsyncOnceCell::const_new(); + +const DEFAULT_CONCURRENT_STREAMS: usize = 32; + +struct Endpoint { + ep: QuinnEndpoint, +} + +impl Endpoint { + fn new() -> Result { + let ep = QuinnEndpoint::client(SocketAddr::from(([0, 0, 0, 0], 0)))?; + Ok(Self { ep }) + } + + async fn connect(&self) -> Result { + let conn = self + .ep + .connect(SocketAddr::from(([127, 0, 0, 1], 8080)), "localhost")? + .await + .map(Connection::new)?; + + tokio::spawn(conn.clone().accept()); + + Ok(conn) + } +} + +#[derive(Clone)] +pub struct Connection { + conn: QuinnConnection, + model: Model, + remote_uni_stream_cnt: StreamCount, + remote_bi_stream_cnt: StreamCount, + max_concurrent_uni_streams: Arc, + max_concurrent_bi_streams: Arc, +} + +impl Connection { + fn new(conn: QuinnConnection) -> Self { + Self { + conn: conn.clone(), + model: Model::::new(conn), + remote_uni_stream_cnt: StreamCount::new(), + remote_bi_stream_cnt: StreamCount::new(), + max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), + max_concurrent_bi_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), + } + } + + pub async fn get() -> Result { + let try_init_conn = async { + ENDPOINT + .get_or_try_init(|| Endpoint::new().map(Mutex::new)) + .map(|ep| ep.lock())? + .connect() + .await + .map(AsyncMutex::new) + }; + + let try_get_conn = async { + let conn = CONNECTION + .get_or_try_init(|| try_init_conn) + .await? + .lock() + .await; + + Ok::<_, Error>(conn) + }; + + let mut conn = time::timeout(Duration::from_secs(5), try_get_conn) + .await + .map_err(|_| Error::Timeout)??; + + if conn.is_closed() { + let new_conn = ENDPOINT.get().unwrap().lock().connect().await?; + *conn = new_conn; + } + + Ok(conn.clone()) + } + + fn is_closed(&self) -> bool { + self.conn.close_reason().is_some() + } + + async fn accept_uni_stream(&self) -> Result<(RecvStream, StreamRegister), Error> { + let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed); + + if self.remote_uni_stream_cnt.get() == max { + self.max_concurrent_uni_streams + .store(max * 2, Ordering::Relaxed); + + self.conn + .set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32)); + } + + let recv = self.conn.accept_uni().await?; + let reg = self.remote_uni_stream_cnt.register(); + Ok((recv, reg)) + } + + async fn accept_bi_stream(&self) -> Result<(SendStream, RecvStream, StreamRegister), Error> { + let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed); + + if self.remote_bi_stream_cnt.get() == max { + self.max_concurrent_bi_streams + .store(max * 2, Ordering::Relaxed); + + self.conn + .set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32)); + } + + let (send, recv) = self.conn.accept_bi().await?; + let reg = self.remote_bi_stream_cnt.register(); + Ok((send, recv, reg)) + } + + async fn accept_datagram(&self) -> Result { + Ok(self.conn.read_datagram().await?) + } + + async fn handle_uni_stream(self, recv: RecvStream, _reg: StreamRegister) { + let res = match self.model.accept_uni_stream(recv).await { + Err(err) => Err(Error::from(err)), + Ok(Task::Packet(Some((pkt, addr, assoc_id)))) => { + socks5::recv_pkt(pkt, addr, assoc_id).await + } + _ => unreachable!(), + }; + + match res { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + + async fn handle_bi_stream(self, send: SendStream, recv: RecvStream, _reg: StreamRegister) { + let res = match self.model.accept_bi_stream(send, recv).await { + Err(err) => Err(Error::from(err)), + _ => unreachable!(), + }; + + match res { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + + async fn handle_datagram(self, dg: Bytes) { + let res = match self.model.accept_datagram(dg) { + Err(err) => Err(Error::from(err)), + Ok(Task::Packet(Some((pkt, addr, assoc_id)))) => { + socks5::recv_pkt(pkt, addr, assoc_id).await + } + _ => unreachable!(), + }; + + match res { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + + async fn accept(self) { + let err = loop { + tokio::select! { + res = self.accept_uni_stream() => match res { + Ok((recv, reg)) => tokio::spawn(self.clone().handle_uni_stream(recv, reg)), + Err(err) => break err, + }, + res = self.accept_bi_stream() => match res { + Ok((send, recv, reg)) => tokio::spawn(self.clone().handle_bi_stream(send, recv, reg)), + Err(err) => break err, + }, + res = self.accept_datagram() => match res { + Ok(dg) => tokio::spawn(self.clone().handle_datagram(dg)), + Err(err) => break err, + }, + }; + }; + + eprintln!("{err}"); + } +} + +#[derive(Clone)] +struct StreamCount(Arc<()>); +struct StreamRegister(Weak<()>); + +impl StreamCount { + fn new() -> Self { + Self(Arc::new(())) + } + + fn register(&self) -> StreamRegister { + StreamRegister(Arc::downgrade(&self.0)) + } + + fn get(&self) -> usize { + Arc::weak_count(&self.0) + } +} diff --git a/tuic-client/src/error.rs b/tuic-client/src/error.rs new file mode 100644 index 0000000..c250262 --- /dev/null +++ b/tuic-client/src/error.rs @@ -0,0 +1,18 @@ +use quinn::{ConnectError, ConnectionError}; +use std::io::Error as IoError; +use thiserror::Error; +use tuic_quinn::Error as ModelError; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Connect(#[from] ConnectError), + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + Model(#[from] ModelError), + #[error("timeout")] + Timeout, +} diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index e7a11a9..8910e62 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -1,3 +1,25 @@ -fn main() { - println!("Hello, world!"); +use self::{ + config::{Config, ConfigError}, + connection::Connection, +}; +use std::{env, process}; + +mod config; +mod connection; +mod error; +mod socks5; + +#[tokio::main] +async fn main() { + let cfg = match Config::parse(env::args_os()) { + Ok(cfg) => cfg, + Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { + println!("{msg}"); + process::exit(0); + } + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + }; } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs new file mode 100644 index 0000000..8ae37ef --- /dev/null +++ b/tuic-client/src/socks5.rs @@ -0,0 +1,7 @@ +use crate::error::Error; +use bytes::Bytes; +use tuic::Address; + +pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { + todo!() +} From 837910527fd06823d2e654c350dfb06819b947bb Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sun, 29 Jan 2023 23:52:27 +0900 Subject: [PATCH 073/103] only read header when encountering task `Packet` --- tuic-quinn/src/lib.rs | 86 ++++++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 986eeb0..4fa60b6 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -91,34 +91,6 @@ impl Connection { pub fn collect_garbage(&self, timeout: Duration) { self.model.collect_garbage(timeout); } - - async fn accept_packet_quic( - &self, - model: PacketModel, - mut recv: &mut RecvStream, - ) -> Result, Error> { - let mut buf = vec![0; model.size() as usize]; - AsyncReadExt::read_exact(&mut recv, &mut buf).await?; - let mut asm = Vec::new(); - - Ok(model - .assemble(Bytes::from(buf))? - .map(|pkt| pkt.assemble(&mut asm)) - .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id))) - } - - fn accept_packet_native( - &self, - model: PacketModel, - data: Bytes, - ) -> Result, Error> { - let mut asm = Vec::new(); - - Ok(model - .assemble(data)? - .map(|pkt| pkt.assemble(&mut asm)) - .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id))) - } } impl Connection { @@ -164,9 +136,7 @@ impl Connection { Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); - Ok(Task::Packet( - self.accept_packet_quic(model, &mut recv).await?, - )) + Ok(Task::Packet(Packet::new(model, PacketSource::Quic(recv)))) } Header::Dissociate(_) => Err(Error::BadCommandUniStream("dissociate", recv)), Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)), @@ -210,8 +180,13 @@ impl Connection { Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; - let buf = dg.into_inner().slice(pos..pos + model.size() as usize); - Ok(Task::Packet(self.accept_packet_native(model, buf)?)) + let mut buf = dg.into_inner(); + if (pos + model.size() as usize) < buf.len() { + buf = buf.slice(pos..pos + model.size() as usize); + Ok(Task::Packet(Packet::new(model, PacketSource::Native(buf)))) + } else { + Err(Error::PayloadLength(model.size() as usize, buf.len() - pos)) + } } Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())), Header::Heartbeat(_) => Err(Error::BadCommandDatagram("heartbeat", dg.into_inner())), @@ -243,9 +218,7 @@ impl Connection { Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)), Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); - Ok(Task::Packet( - self.accept_packet_quic(model, &mut recv).await?, - )) + Ok(Task::Packet(Packet::new(model, PacketSource::Quic(recv)))) } Header::Dissociate(dissoc) => { let model = self.model.recv_dissociate(dissoc); @@ -296,7 +269,7 @@ impl Connection { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; let buf = dg.into_inner().slice(pos..pos + model.size() as usize); - Ok(Task::Packet(self.accept_packet_native(model, buf)?)) + Ok(Task::Packet(Packet::new(model, PacketSource::Native(buf)))) } Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())), Header::Heartbeat(hb) => { @@ -362,11 +335,46 @@ impl AsyncWrite for Connect { } } +pub struct Packet { + model: PacketModel, + src: PacketSource, +} + +enum PacketSource { + Quic(RecvStream), + Native(Bytes), +} + +impl Packet { + fn new(model: PacketModel, src: PacketSource) -> Self { + Self { src, model } + } + + pub async fn accept(self) -> Result, Error> { + let pkt = match self.src { + PacketSource::Quic(mut recv) => { + let mut buf = vec![0; self.model.size() as usize]; + AsyncReadExt::read_exact(&mut recv, &mut buf).await?; + Bytes::from(buf) + } + PacketSource::Native(pkt) => pkt, + }; + + let mut asm = Vec::new(); + + Ok(self + .model + .assemble(pkt)? + .map(|pkt| pkt.assemble(&mut asm)) + .map(|(addr, assoc_id)| (Bytes::from(asm), addr, assoc_id))) + } +} + #[non_exhaustive] pub enum Task { Authenticate([u8; 8]), Connect(Connect), - Packet(Option<(Bytes, Address, u16)>), + Packet(Packet), Dissociate(u16), Heartbeat, } @@ -379,6 +387,8 @@ pub enum Error { Connection(#[from] ConnectionError), #[error(transparent)] SendDatagram(#[from] SendDatagramError), + #[error("expecting payload length {0} but got {1}")] + PayloadLength(usize, usize), #[error(transparent)] Assemble(#[from] AssembleError), #[error("error unmarshaling uni_stream: {0}")] From 290545636c2bfc8bcd85199b7db83431f945725a Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 30 Jan 2023 00:54:56 +0900 Subject: [PATCH 074/103] implement socks5 command `connect` --- tuic-client/Cargo.toml | 5 +- tuic-client/src/connection.rs | 39 ++++++++------ tuic-client/src/main.rs | 12 +++-- tuic-client/src/socks5.rs | 95 +++++++++++++++++++++++++++++++++-- 4 files changed, 127 insertions(+), 24 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index cce7a68..2238f24 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -7,11 +7,14 @@ edition = "2021" bytes = { version = "1.3.0", default-features = false, features = ["std"] } lexopt = { version = "0.3.0", default-features = false } once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } -parking_lot = { version = "0.12.1", default-features = false } +parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] } +socks5-proto = { version = "0.3.3", default-features = false } +socks5-server = { version = "0.8.3", default-features = false } thiserror = { version = "1.0.38", default-features = false } tokio = { version = "1.24.2", default-features = false, features = ["macros", "parking_lot", "rt-multi-thread", "time"] } +tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false } diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index f7f8d7d..bea7e1d 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -17,7 +17,8 @@ use tokio::{ sync::{Mutex as AsyncMutex, OnceCell as AsyncOnceCell}, time, }; -use tuic_quinn::{side, Connection as Model, Task}; +use tuic::Address; +use tuic_quinn::{side, Connect, Connection as Model, Task}; static ENDPOINT: OnceCell> = OnceCell::new(); static CONNECTION: AsyncOnceCell> = AsyncOnceCell::const_new(); @@ -80,25 +81,29 @@ impl Connection { }; let try_get_conn = async { - let conn = CONNECTION + let mut conn = CONNECTION .get_or_try_init(|| try_init_conn) .await? .lock() .await; - Ok::<_, Error>(conn) + if conn.is_closed() { + let new_conn = ENDPOINT.get().unwrap().lock().connect().await?; + *conn = new_conn; + } + + Ok::<_, Error>(conn.clone()) }; - let mut conn = time::timeout(Duration::from_secs(5), try_get_conn) + let conn = time::timeout(Duration::from_secs(5), try_get_conn) .await .map_err(|_| Error::Timeout)??; - if conn.is_closed() { - let new_conn = ENDPOINT.get().unwrap().lock().connect().await?; - *conn = new_conn; - } + Ok(conn) + } - Ok(conn.clone()) + pub async fn connect(&self, addr: Address) -> Result { + Ok(self.model.connect(addr).await?) } fn is_closed(&self) -> bool { @@ -144,9 +149,11 @@ impl Connection { async fn handle_uni_stream(self, recv: RecvStream, _reg: StreamRegister) { let res = match self.model.accept_uni_stream(recv).await { Err(err) => Err(Error::from(err)), - Ok(Task::Packet(Some((pkt, addr, assoc_id)))) => { - socks5::recv_pkt(pkt, addr, assoc_id).await - } + Ok(Task::Packet(pkt)) => match pkt.accept().await { + Ok(Some((pkt, addr, assoc_id))) => socks5::recv_pkt(pkt, addr, assoc_id).await, + Ok(None) => Ok(()), + Err(err) => Err(Error::from(err)), + }, _ => unreachable!(), }; @@ -171,9 +178,11 @@ impl Connection { async fn handle_datagram(self, dg: Bytes) { let res = match self.model.accept_datagram(dg) { Err(err) => Err(Error::from(err)), - Ok(Task::Packet(Some((pkt, addr, assoc_id)))) => { - socks5::recv_pkt(pkt, addr, assoc_id).await - } + Ok(Task::Packet(pkt)) => match pkt.accept().await { + Ok(Some((pkt, addr, assoc_id))) => socks5::recv_pkt(pkt, addr, assoc_id).await, + Ok(None) => Ok(()), + Err(err) => Err(Error::from(err)), + }, _ => unreachable!(), }; diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index 8910e62..20cc0b6 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -1,7 +1,4 @@ -use self::{ - config::{Config, ConfigError}, - connection::Connection, -}; +use self::config::{Config, ConfigError}; use std::{env, process}; mod config; @@ -11,7 +8,7 @@ mod socks5; #[tokio::main] async fn main() { - let cfg = match Config::parse(env::args_os()) { + let _cfg = match Config::parse(env::args_os()) { Ok(cfg) => cfg, Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { println!("{msg}"); @@ -22,4 +19,9 @@ async fn main() { process::exit(1); } }; + + if let Err(err) = socks5::start().await { + eprintln!("{err}"); + process::exit(1); + } } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 8ae37ef..d52a010 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -1,7 +1,96 @@ -use crate::error::Error; +use crate::{connection::Connection as TuicConnection, error::Error}; use bytes::Bytes; -use tuic::Address; +use socks5_proto::{Address, Reply}; +use socks5_server::{ + auth::NoAuth, + connection::{associate, bind, connect}, + Associate, Bind, Connect, Connection, Server, +}; +use std::sync::Arc; +use tokio::io::{self, AsyncWriteExt}; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tuic::Address as TuicAddress; -pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { +pub async fn start() -> Result<(), Error> { + let server = Server::bind("127.0.0.1:5000", Arc::new(NoAuth)).await?; + + while let Ok((conn, _)) = server.accept().await { + tokio::spawn(async move { + let res = match conn.handshake().await { + Ok(Connection::Associate(associate, addr)) => { + handle_associate(associate, addr).await + } + Ok(Connection::Bind(bind, addr)) => handle_bind(bind, addr).await, + Ok(Connection::Connect(connect, addr)) => handle_connect(connect, addr).await, + Err(err) => Err(Error::from(err)), + }; + + match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + } + }); + } + + Ok(()) +} + +async fn handle_associate( + assoc: Associate, + addr: Address, +) -> Result<(), Error> { + todo!() +} + +async fn handle_bind(bind: Bind, _addr: Address) -> Result<(), Error> { + let mut conn = bind + .reply(Reply::CommandNotSupported, Address::unspecified()) + .await?; + let _ = conn.shutdown().await; + Ok(()) +} + +async fn handle_connect(conn: Connect, addr: Address) -> Result<(), Error> { + let target_addr = match addr { + Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), + Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), + }; + + let relay = match TuicConnection::get().await { + Ok(conn) => conn.connect(target_addr).await, + Err(err) => Err(err), + }; + + match relay { + Ok(relay) => { + let mut relay = relay.compat(); + let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await; + + match conn { + Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await { + Ok(_) => Ok(()), + Err(err) => { + let _ = conn.shutdown().await; + let _ = relay.shutdown().await; + Err(Error::from(err)) + } + }, + Err(err) => { + let _ = relay.shutdown().await; + Err(Error::from(err)) + } + } + } + Err(err) => { + let mut conn = conn + .reply(Reply::GeneralFailure, Address::unspecified()) + .await?; + let _ = conn.shutdown().await; + Err(err) + } + } +} + +pub async fn recv_pkt(pkt: Bytes, addr: TuicAddress, assoc_id: u16) -> Result<(), Error> { todo!() } From 51dde2737852c280fc69b24509ea1d49624e09ef Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 30 Jan 2023 02:08:32 +0900 Subject: [PATCH 075/103] add `Connection::dissociate` in `tuic-quinn` --- tuic-quinn/src/lib.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 4fa60b6..15d1fb9 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -117,6 +117,14 @@ impl Connection { Ok(Connect::new(Side::Client(model), send, recv)) } + pub async fn dissociate(&self, assoc_id: u16) -> Result<(), Error> { + let mut send = self.conn.open_uni().await?; + let model = self.model.send_dissociate(assoc_id); + model.header().async_marshal(&mut send).await?; + send.close().await?; + Ok(()) + } + pub async fn heartbeat(&self) -> Result<(), Error> { let model = self.model.send_heartbeat(); let mut buf = Vec::with_capacity(model.header().len()); From b85bf70e94c1cec3034ce15e20d137633131a783 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 30 Jan 2023 13:12:12 +0900 Subject: [PATCH 076/103] impl socks5 associate sending packets mechanism --- tuic-client/Cargo.toml | 2 +- tuic-client/src/connection.rs | 10 +++ tuic-client/src/socks5.rs | 117 ++++++++++++++++++++++++++++++++-- 3 files changed, 123 insertions(+), 6 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index 2238f24..8ceec2f 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -14,7 +14,7 @@ serde_json = { version = "1.0.91", default-features = false, features = ["std"] socks5-proto = { version = "0.3.3", default-features = false } socks5-server = { version = "0.8.3", default-features = false } thiserror = { version = "1.0.38", default-features = false } -tokio = { version = "1.24.2", default-features = false, features = ["macros", "parking_lot", "rt-multi-thread", "time"] } +tokio = { version = "1.24.2", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false } diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index bea7e1d..c6064a4 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -106,6 +106,16 @@ impl Connection { Ok(self.model.connect(addr).await?) } + pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { + self.model.packet_quic(pkt, addr, assoc_id).await?; // TODO + Ok(()) + } + + pub async fn dissociate(&self, assoc_id: u16) -> Result<(), Error> { + self.model.dissociate(assoc_id).await?; + Ok(()) + } + fn is_closed(&self) -> bool { self.conn.close_reason().is_some() } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index d52a010..a5bb2ac 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -1,16 +1,33 @@ use crate::{connection::Connection as TuicConnection, error::Error}; use bytes::Bytes; +use once_cell::sync::Lazy; +use parking_lot::Mutex; use socks5_proto::{Address, Reply}; use socks5_server::{ auth::NoAuth, connection::{associate, bind, connect}, - Associate, Bind, Connect, Connection, Server, + Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server, +}; +use std::{ + collections::HashMap, + io::{Error as IoError, ErrorKind}, + net::SocketAddr, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, + }, +}; +use tokio::{ + io::{self, AsyncWriteExt}, + net::UdpSocket, }; -use std::sync::Arc; -use tokio::io::{self, AsyncWriteExt}; use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address as TuicAddress; +static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0); +static UDP_SESSIONS: Lazy>>> = + Lazy::new(|| Mutex::new(HashMap::new())); + pub async fn start() -> Result<(), Error> { let server = Server::bind("127.0.0.1:5000", Arc::new(NoAuth)).await?; @@ -37,9 +54,31 @@ pub async fn start() -> Result<(), Error> { async fn handle_associate( assoc: Associate, - addr: Address, + _addr: Address, ) -> Result<(), Error> { - todo!() + let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0))) + .await + .and_then(|socket| { + socket + .local_addr() + .map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr)) + }); + + match assoc_socket { + Ok((assoc_socket, assoc_addr)) => { + let assoc = assoc + .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) + .await?; + send_pkt(assoc, assoc_socket).await + } + Err(err) => { + let mut assoc = assoc + .reply(Reply::GeneralFailure, Address::unspecified()) + .await?; + let _ = assoc.shutdown().await; + Err(Error::from(err)) + } + } } async fn handle_bind(bind: Bind, _addr: Address) -> Result<(), Error> { @@ -91,6 +130,74 @@ async fn handle_connect(conn: Connect, addr: Address) -> Res } } +async fn send_pkt( + mut assoc: Associate, + assoc_socket: Arc, +) -> Result<(), Error> { + let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel); + UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone()); + let mut connected = None; + + async fn accept_pkt( + assoc_socket: &AssociatedUdpSocket, + connected: &mut Option, + assoc_id: u16, + ) -> Result<(), Error> { + let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; + + if let Some(connected) = connected { + if connected != &src_addr { + Err(IoError::new( + ErrorKind::Other, + format!("invalid source address: {src_addr}"), + ))?; + } + } else { + assoc_socket.connect(src_addr).await?; + *connected = Some(src_addr); + } + + if frag != 0 { + Err(IoError::new( + ErrorKind::Other, + format!("fragmented packet is not supported"), + ))?; + } + + let target_addr = match dst_addr { + Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), + Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), + }; + + TuicConnection::get() + .await? + .packet(pkt, target_addr, assoc_id) + .await + } + + let res = tokio::select! { + res = assoc.wait_until_closed() => res, + _ = async { loop { + if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { + eprintln!("{err}"); + } + }} => unreachable!(), + }; + + let _ = assoc.shutdown().await; + UDP_SESSIONS.lock().remove(&assoc_id); + + match TuicConnection::get().await { + Ok(conn) => match conn.dissociate(assoc_id).await { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + Err(err) => eprintln!("{err}"), + } + + Ok(res?) +} + pub async fn recv_pkt(pkt: Bytes, addr: TuicAddress, assoc_id: u16) -> Result<(), Error> { todo!() } From a3d5b31807ab54d0d035b70bd5b4cec645949673 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 30 Jan 2023 19:56:42 +0900 Subject: [PATCH 077/103] restrict packet resource on client side --- tuic-quinn/src/lib.rs | 32 +++++++++++++++++++++----------- tuic/src/model/mod.rs | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 15d1fb9..c742cd0 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -143,8 +143,12 @@ impl Connection { Header::Authenticate(_) => Err(Error::BadCommandUniStream("authenticate", recv)), Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)), Header::Packet(pkt) => { - let model = self.model.recv_packet(pkt); - Ok(Task::Packet(Packet::new(model, PacketSource::Quic(recv)))) + let assoc_id = pkt.assoc_id(); + self.model + .recv_packet(pkt) + .map_or(Err(Error::InvalidUdpSession(assoc_id)), |pkt| { + Ok(Task::Packet(Packet::new(pkt, PacketSource::Quic(recv)))) + }) } Header::Dissociate(_) => Err(Error::BadCommandUniStream("dissociate", recv)), Header::Heartbeat(_) => Err(Error::BadCommandUniStream("heartbeat", recv)), @@ -186,14 +190,18 @@ impl Connection { } Header::Connect(_) => Err(Error::BadCommandDatagram("connect", dg.into_inner())), Header::Packet(pkt) => { - let model = self.model.recv_packet(pkt); - let pos = dg.position() as usize; - let mut buf = dg.into_inner(); - if (pos + model.size() as usize) < buf.len() { - buf = buf.slice(pos..pos + model.size() as usize); - Ok(Task::Packet(Packet::new(model, PacketSource::Native(buf)))) + let assoc_id = pkt.assoc_id(); + if let Some(pkt) = self.model.recv_packet(pkt) { + let pos = dg.position() as usize; + let mut buf = dg.into_inner(); + if (pos + pkt.size() as usize) < buf.len() { + buf = buf.slice(pos..pos + pkt.size() as usize); + Ok(Task::Packet(Packet::new(pkt, PacketSource::Native(buf)))) + } else { + Err(Error::PayloadLength(pkt.size() as usize, buf.len() - pos)) + } } else { - Err(Error::PayloadLength(model.size() as usize, buf.len() - pos)) + Err(Error::InvalidUdpSession(assoc_id)) } } Header::Dissociate(_) => Err(Error::BadCommandDatagram("dissociate", dg.into_inner())), @@ -225,7 +233,7 @@ impl Connection { } Header::Connect(_) => Err(Error::BadCommandUniStream("connect", recv)), Header::Packet(pkt) => { - let model = self.model.recv_packet(pkt); + let model = self.model.recv_packet_unrestricted(pkt); Ok(Task::Packet(Packet::new(model, PacketSource::Quic(recv)))) } Header::Dissociate(dissoc) => { @@ -274,7 +282,7 @@ impl Connection { } Header::Connect(_) => Err(Error::BadCommandDatagram("connect", dg.into_inner())), Header::Packet(pkt) => { - let model = self.model.recv_packet(pkt); + let model = self.model.recv_packet_unrestricted(pkt); let pos = dg.position() as usize; let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(Packet::new(model, PacketSource::Native(buf)))) @@ -397,6 +405,8 @@ pub enum Error { SendDatagram(#[from] SendDatagramError), #[error("expecting payload length {0} but got {1}")] PayloadLength(usize, usize), + #[error("invalid udp session {0}")] + InvalidUdpSession(u16), #[error(transparent)] Assemble(#[from] AssembleError), #[error("error unmarshaling uni_stream: {0}")] diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index a4486dc..a174e82 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -78,7 +78,7 @@ where .send_packet(assoc_id, addr, max_pkt_size) } - pub fn recv_packet(&self, header: PacketHeader) -> Packet { + pub fn recv_packet(&self, header: PacketHeader) -> Option> { let (assoc_id, pkt_id, frag_total, frag_id, size, addr) = header.into(); self.udp_sessions.lock().recv_packet( self.udp_sessions.clone(), @@ -91,6 +91,19 @@ where ) } + pub fn recv_packet_unrestricted(&self, header: PacketHeader) -> Packet { + let (assoc_id, pkt_id, frag_total, frag_id, size, addr) = header.into(); + self.udp_sessions.lock().recv_packet_unrestricted( + self.udp_sessions.clone(), + assoc_id, + pkt_id, + frag_total, + frag_id, + size, + addr, + ) + } + pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate { self.udp_sessions.lock().send_dissociate(assoc_id) } @@ -166,7 +179,7 @@ where } } - fn send_packet<'a>( + fn send_packet( &mut self, assoc_id: u16, addr: Address, @@ -178,7 +191,22 @@ where .send_packet(assoc_id, addr, max_pkt_size) } - fn recv_packet<'a>( + fn recv_packet( + &mut self, + sessions: Arc>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Option> { + self.sessions.get_mut(&assoc_id).map(|session| { + session.recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) + }) + } + + fn recv_packet_unrestricted( &mut self, sessions: Arc>, assoc_id: u16, From 43d81c5eca304d2cb2ec914c11553b0182301e6f Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 31 Jan 2023 15:25:18 +0900 Subject: [PATCH 078/103] destruct local UDP session first when dissociating --- tuic-quinn/src/lib.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index c742cd0..8c94025 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -103,23 +103,23 @@ impl Connection { } pub async fn authenticate(&self, token: [u8; 8]) -> Result<(), Error> { - let mut send = self.conn.open_uni().await?; let model = self.model.send_authenticate(token); + let mut send = self.conn.open_uni().await?; model.header().async_marshal(&mut send).await?; send.close().await?; Ok(()) } pub async fn connect(&self, addr: Address) -> Result { - let (mut send, recv) = self.conn.open_bi().await?; let model = self.model.send_connect(addr); + let (mut send, recv) = self.conn.open_bi().await?; model.header().async_marshal(&mut send).await?; Ok(Connect::new(Side::Client(model), send, recv)) } pub async fn dissociate(&self, assoc_id: u16) -> Result<(), Error> { - let mut send = self.conn.open_uni().await?; let model = self.model.send_dissociate(assoc_id); + let mut send = self.conn.open_uni().await?; model.header().async_marshal(&mut send).await?; send.close().await?; Ok(()) From 96ddb4884a345ed3d3c5219b35fd79ff889a277e Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 31 Jan 2023 15:25:51 +0900 Subject: [PATCH 079/103] implement client packet receiving mechanism --- tuic-client/src/connection.rs | 23 +++++++++++++++++++++-- tuic-client/src/socks5.rs | 7 +++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index c6064a4..e726190 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -5,6 +5,7 @@ use parking_lot::Mutex; use quinn::{ Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt, }; +use socks5_proto::Address as Socks5Address; use std::{ net::SocketAddr, sync::{ @@ -160,7 +161,16 @@ impl Connection { let res = match self.model.accept_uni_stream(recv).await { Err(err) => Err(Error::from(err)), Ok(Task::Packet(pkt)) => match pkt.accept().await { - Ok(Some((pkt, addr, assoc_id))) => socks5::recv_pkt(pkt, addr, assoc_id).await, + Ok(Some((pkt, addr, assoc_id))) => { + let addr = match addr { + Address::None => unreachable!(), + Address::DomainAddress(domain, port) => { + Socks5Address::DomainAddress(domain, port) + } + Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), + }; + socks5::recv_pkt(pkt, addr, assoc_id).await + } Ok(None) => Ok(()), Err(err) => Err(Error::from(err)), }, @@ -189,7 +199,16 @@ impl Connection { let res = match self.model.accept_datagram(dg) { Err(err) => Err(Error::from(err)), Ok(Task::Packet(pkt)) => match pkt.accept().await { - Ok(Some((pkt, addr, assoc_id))) => socks5::recv_pkt(pkt, addr, assoc_id).await, + Ok(Some((pkt, addr, assoc_id))) => { + let addr = match addr { + Address::None => unreachable!(), + Address::DomainAddress(domain, port) => { + Socks5Address::DomainAddress(domain, port) + } + Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), + }; + socks5::recv_pkt(pkt, addr, assoc_id).await + } Ok(None) => Ok(()), Err(err) => Err(Error::from(err)), }, diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index a5bb2ac..70ed148 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -198,6 +198,9 @@ async fn send_pkt( Ok(res?) } -pub async fn recv_pkt(pkt: Bytes, addr: TuicAddress, assoc_id: u16) -> Result<(), Error> { - todo!() +pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { + let sessions = UDP_SESSIONS.lock(); + let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; + assoc_socket.send(pkt, 0, addr).await?; + Ok(()) } From 59773d20513e8ca408414e5e0773da33215712f1 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 31 Jan 2023 15:36:15 +0900 Subject: [PATCH 080/103] implement client auth and heartbeat --- tuic-client/src/connection.rs | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index e726190..3c3345c 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -43,7 +43,7 @@ impl Endpoint { .await .map(Connection::new)?; - tokio::spawn(conn.clone().accept()); + tokio::spawn(conn.clone().init()); Ok(conn) } @@ -221,7 +221,32 @@ impl Connection { } } - async fn accept(self) { + async fn authenticate(self) { + match self.model.authenticate([0; 8]).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + + async fn heartbeat(self) { + loop { + time::sleep(Duration::from_secs(5)).await; + + if self.is_closed() { + break; + } + + match self.model.heartbeat().await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + } + + async fn init(self) { + tokio::spawn(self.clone().authenticate()); + tokio::spawn(self.clone().heartbeat()); + let err = loop { tokio::select! { res = self.accept_uni_stream() => match res { From 1065e3e938a61485521b5af243a999ccedd6c6f0 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Tue, 31 Jan 2023 16:49:53 +0900 Subject: [PATCH 081/103] use `register-count` --- tuic-client/Cargo.toml | 1 + tuic-client/src/connection.rs | 45 ++++++++++--------------------- tuic/Cargo.toml | 3 ++- tuic/src/model/connect.rs | 14 +++++----- tuic/src/model/mod.rs | 51 ++++++++++++----------------------- 5 files changed, 40 insertions(+), 74 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index 8ceec2f..3af8623 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -9,6 +9,7 @@ lexopt = { version = "0.3.0", default-features = false } once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } +register-count = { version = "0.1.0", default-features = false, features = ["std"] } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] } socks5-proto = { version = "0.3.3", default-features = false } diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 3c3345c..159af6e 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -5,12 +5,13 @@ use parking_lot::Mutex; use quinn::{ Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt, }; +use register_count::{Counter, Register}; use socks5_proto::Address as Socks5Address; use std::{ net::SocketAddr, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, Weak, + Arc, }, time::Duration, }; @@ -53,8 +54,8 @@ impl Endpoint { pub struct Connection { conn: QuinnConnection, model: Model, - remote_uni_stream_cnt: StreamCount, - remote_bi_stream_cnt: StreamCount, + remote_uni_stream_cnt: Counter, + remote_bi_stream_cnt: Counter, max_concurrent_uni_streams: Arc, max_concurrent_bi_streams: Arc, } @@ -64,8 +65,8 @@ impl Connection { Self { conn: conn.clone(), model: Model::::new(conn), - remote_uni_stream_cnt: StreamCount::new(), - remote_bi_stream_cnt: StreamCount::new(), + remote_uni_stream_cnt: Counter::new(), + remote_bi_stream_cnt: Counter::new(), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), max_concurrent_bi_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), } @@ -121,10 +122,10 @@ impl Connection { self.conn.close_reason().is_some() } - async fn accept_uni_stream(&self) -> Result<(RecvStream, StreamRegister), Error> { + async fn accept_uni_stream(&self) -> Result<(RecvStream, Register), Error> { let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed); - if self.remote_uni_stream_cnt.get() == max { + if self.remote_uni_stream_cnt.count() == max { self.max_concurrent_uni_streams .store(max * 2, Ordering::Relaxed); @@ -133,14 +134,14 @@ impl Connection { } let recv = self.conn.accept_uni().await?; - let reg = self.remote_uni_stream_cnt.register(); + let reg = self.remote_uni_stream_cnt.reg(); Ok((recv, reg)) } - async fn accept_bi_stream(&self) -> Result<(SendStream, RecvStream, StreamRegister), Error> { + async fn accept_bi_stream(&self) -> Result<(SendStream, RecvStream, Register), Error> { let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed); - if self.remote_bi_stream_cnt.get() == max { + if self.remote_bi_stream_cnt.count() == max { self.max_concurrent_bi_streams .store(max * 2, Ordering::Relaxed); @@ -149,7 +150,7 @@ impl Connection { } let (send, recv) = self.conn.accept_bi().await?; - let reg = self.remote_bi_stream_cnt.register(); + let reg = self.remote_bi_stream_cnt.reg(); Ok((send, recv, reg)) } @@ -157,7 +158,7 @@ impl Connection { Ok(self.conn.read_datagram().await?) } - async fn handle_uni_stream(self, recv: RecvStream, _reg: StreamRegister) { + async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { let res = match self.model.accept_uni_stream(recv).await { Err(err) => Err(Error::from(err)), Ok(Task::Packet(pkt)) => match pkt.accept().await { @@ -183,7 +184,7 @@ impl Connection { } } - async fn handle_bi_stream(self, send: SendStream, recv: RecvStream, _reg: StreamRegister) { + async fn handle_bi_stream(self, send: SendStream, recv: RecvStream, _reg: Register) { let res = match self.model.accept_bi_stream(send, recv).await { Err(err) => Err(Error::from(err)), _ => unreachable!(), @@ -267,21 +268,3 @@ impl Connection { eprintln!("{err}"); } } - -#[derive(Clone)] -struct StreamCount(Arc<()>); -struct StreamRegister(Weak<()>); - -impl StreamCount { - fn new() -> Self { - Self(Arc::new(())) - } - - fn register(&self) -> StreamRegister { - StreamRegister(Arc::downgrade(&self.0)) - } - - fn get(&self) -> usize { - Arc::weak_count(&self.0) - } -} diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 59cc861..7aff4d1 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -6,12 +6,13 @@ edition = "2021" [features] async_marshal = ["bytes", "futures-util"] marshal = ["bytes"] -model = ["parking_lot", "thiserror"] +model = ["parking_lot", "register-count", "thiserror"] [dependencies] bytes = { version = "1.3.0", default-features = false, features = ["std"], optional = true } futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"], optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true } +register-count = { version = "0.1.0", default-features = false, features = ["std"], optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] diff --git a/tuic/src/model/connect.rs b/tuic/src/model/connect.rs index 0e40479..672cb59 100644 --- a/tuic/src/model/connect.rs +++ b/tuic/src/model/connect.rs @@ -1,8 +1,6 @@ -use super::{ - side::{self, Side}, - TaskRegister, -}; +use super::side::{self, Side}; use crate::protocol::{Address, Connect as ConnectHeader, Header}; +use register_count::Register; pub struct Connect { inner: Side, @@ -11,11 +9,11 @@ pub struct Connect { struct Tx { header: Header, - _task_reg: TaskRegister, + _task_reg: Register, } impl Connect { - pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { + pub(super) fn new(task_reg: Register, addr: Address) -> Self { Self { inner: Side::Tx(Tx { header: Header::Connect(ConnectHeader::new(addr)), @@ -33,11 +31,11 @@ impl Connect { struct Rx { addr: Address, - _task_reg: TaskRegister, + _task_reg: Register, } impl Connect { - pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { + pub(super) fn new(task_reg: Register, addr: Address) -> Self { Self { inner: Side::Rx(Rx { addr, diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index a174e82..5ce3c71 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -3,12 +3,13 @@ use crate::protocol::{ Dissociate as DissociateHeader, Heartbeat as HeartbeatHeader, Packet as PacketHeader, }; use parking_lot::Mutex; +use register_count::{Counter, Register}; use std::{ collections::HashMap, mem, sync::{ atomic::{AtomicU16, Ordering}, - Arc, Weak, + Arc, }, time::{Duration, Instant}, }; @@ -31,8 +32,8 @@ pub use self::{ #[derive(Clone)] pub struct Connection { udp_sessions: Arc>>, - task_connect_count: TaskCount, - task_associate_count: TaskCount, + task_connect_count: Counter, + task_associate_count: Counter, } impl Connection @@ -40,11 +41,11 @@ where B: AsRef<[u8]>, { pub fn new() -> Self { - let task_associate_count = TaskCount::new(); + let task_associate_count = Counter::new(); Self { udp_sessions: Arc::new(Mutex::new(UdpSessions::new(task_associate_count.clone()))), - task_connect_count: TaskCount::new(), + task_connect_count: Counter::new(), task_associate_count, } } @@ -59,12 +60,12 @@ where } pub fn send_connect(&self, addr: Address) -> Connect { - Connect::::new(self.task_connect_count.register(), addr) + Connect::::new(self.task_connect_count.reg(), addr) } pub fn recv_connect(&self, header: ConnectHeader) -> Connect { let (addr,) = header.into(); - Connect::::new(self.task_connect_count.register(), addr) + Connect::::new(self.task_connect_count.reg(), addr) } pub fn send_packet( @@ -123,11 +124,11 @@ where } pub fn task_connect_count(&self) -> usize { - self.task_connect_count.get() + self.task_connect_count.count() } pub fn task_associate_count(&self) -> usize { - self.task_associate_count.get() + self.task_associate_count.count() } pub fn collect_garbage(&self, timeout: Duration) { @@ -135,24 +136,6 @@ where } } -#[derive(Clone)] -struct TaskCount(Arc<()>); -struct TaskRegister(Weak<()>); - -impl TaskCount { - fn new() -> Self { - Self(Arc::new(())) - } - - fn register(&self) -> TaskRegister { - TaskRegister(Arc::downgrade(&self.0)) - } - - fn get(&self) -> usize { - Arc::weak_count(&self.0) - } -} - pub mod side { pub struct Tx; pub struct Rx; @@ -165,14 +148,14 @@ pub mod side { struct UdpSessions { sessions: HashMap>, - task_associate_count: TaskCount, + task_associate_count: Counter, } impl UdpSessions where B: AsRef<[u8]>, { - fn new(task_associate_count: TaskCount) -> Self { + fn new(task_associate_count: Counter) -> Self { Self { sessions: HashMap::new(), task_associate_count, @@ -187,7 +170,7 @@ where ) -> Packet { self.sessions .entry(assoc_id) - .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) .send_packet(assoc_id, addr, max_pkt_size) } @@ -218,7 +201,7 @@ where ) -> Packet { self.sessions .entry(assoc_id) - .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) .recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) } @@ -244,7 +227,7 @@ where ) -> Result>, AssembleError> { self.sessions .entry(assoc_id) - .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) .insert(assoc_id, pkt_id, frag_total, frag_id, size, addr, data) } @@ -258,14 +241,14 @@ where struct UdpSession { pkt_buf: HashMap>, next_pkt_id: AtomicU16, - _task_reg: TaskRegister, + _task_reg: Register, } impl UdpSession where B: AsRef<[u8]>, { - fn new(task_reg: TaskRegister) -> Self { + fn new(task_reg: Register) -> Self { Self { pkt_buf: HashMap::new(), next_pkt_id: AtomicU16::new(0), From bcc79d7c5d100f10ece09ce1071e5cfc798d628f Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 13:45:36 +0900 Subject: [PATCH 082/103] implement client config reading --- tuic-client/src/config.rs | 136 ++++++++++++-- tuic-client/src/connection.rs | 46 ++++- tuic-client/src/main.rs | 29 ++- tuic-client/src/socks5.rs | 343 ++++++++++++++++++---------------- tuic-client/src/utils.rs | 65 +++++++ 5 files changed, 437 insertions(+), 182 deletions(-) create mode 100644 tuic-client/src/utils.rs diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs index f590ee6..81f35db 100644 --- a/tuic-client/src/config.rs +++ b/tuic-client/src/config.rs @@ -1,19 +1,115 @@ +use crate::utils::{self, CongestionControl, UdpRelayMode}; use lexopt::{Arg, Error as ArgumentError, Parser}; -use serde::Deserialize; +use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as SerdeError; -use std::{ffi::OsString, fs::File, io::Error as IoError}; +use std::{ + env::ArgsOs, + fs::File, + io::Error as IoError, + net::{IpAddr, SocketAddr}, + time::Duration, +}; use thiserror::Error; +const HELP_MSG: &str = r#" +Usage tuic-client [arguments] + +Arguments: + -c, --config Path to the config file (required) + -v, --version Print the version + -h, --help Print this help message +"#; + #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub struct Config {} +pub struct Config { + pub relay: Relay, + pub local: Local, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Relay { + #[serde(deserialize_with = "deserialize_server")] + pub server: (String, u16), + pub token: String, + pub ip: Option, + #[serde(default = "default::relay::certificates")] + pub certificates: Vec, + #[serde( + default = "default::relay::udp_relay_mode", + deserialize_with = "utils::deserialize_from_str" + )] + pub udp_relay_mode: UdpRelayMode, + #[serde( + default = "default::relay::congestion_control", + deserialize_with = "utils::deserialize_from_str" + )] + pub congestion_control: CongestionControl, + #[serde(default = "default::relay::alpn")] + pub alpn: Vec, + #[serde(default = "default::relay::zero_rtt_handshake")] + pub zero_rtt_handshake: bool, + #[serde(default = "default::relay::timeout")] + pub timeout: Duration, + #[serde(default = "default::relay::heartbeat")] + pub heartbeat: Duration, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Local { + pub server: SocketAddr, + pub username: Option, + pub password: Option, + pub dual_stack: Option, + #[serde(default = "default::local::max_packet_size")] + pub max_packet_size: usize, +} + +mod default { + pub mod relay { + use crate::utils::{CongestionControl, UdpRelayMode}; + use std::time::Duration; + + pub const fn certificates() -> Vec { + Vec::new() + } + + pub const fn udp_relay_mode() -> UdpRelayMode { + UdpRelayMode::Native + } + + pub const fn congestion_control() -> CongestionControl { + CongestionControl::Cubic + } + + pub const fn alpn() -> Vec { + Vec::new() + } + + pub const fn zero_rtt_handshake() -> bool { + false + } + + pub const fn timeout() -> Duration { + Duration::from_secs(8) + } + + pub const fn heartbeat() -> Duration { + Duration::from_secs(3) + } + } + + pub mod local { + pub const fn max_packet_size() -> usize { + 1500 + } + } +} impl Config { - pub fn parse(args: A) -> Result - where - A: IntoIterator, - A::Item: Into, - { + pub fn parse(args: ArgsOs) -> Result { let mut parser = Parser::from_iter(args); let mut path = None; @@ -29,7 +125,7 @@ impl Config { Arg::Short('v') | Arg::Long("version") => { return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))) } - Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(todo!())), + Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)), _ => return Err(ConfigError::Argument(arg.unexpected())), } } @@ -44,9 +140,25 @@ impl Config { } } +pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error> +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + let mut parts = s.split(':'); + + match (parts.next(), parts.next(), parts.next()) { + (Some(domain), Some(port), None) => port.parse().map_or_else( + |e| Err(DeError::custom(e)), + |port| Ok((domain.to_owned(), port)), + ), + _ => Err(DeError::custom("invalid server address")), + } +} + #[derive(Debug, Error)] pub enum ConfigError { - #[error("transparent")] + #[error(transparent)] Argument(#[from] ArgumentError), #[error("no config file specified")] NoConfig, @@ -54,8 +166,8 @@ pub enum ConfigError { Version(&'static str), #[error("{0}")] Help(&'static str), - #[error("transparent")] + #[error(transparent)] Io(#[from] IoError), - #[error("transparent")] + #[error(transparent)] Serde(#[from] SerdeError), } diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 159af6e..7d14e8c 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -1,4 +1,9 @@ -use crate::{error::Error, socks5}; +use crate::{ + config::Relay, + error::Error, + socks5::Server as Socks5Server, + utils::{ServerAddr, UdpRelayMode}, +}; use bytes::Bytes; use once_cell::sync::OnceCell; use parking_lot::Mutex; @@ -27,14 +32,36 @@ static CONNECTION: AsyncOnceCell> = AsyncOnceCell::const_ const DEFAULT_CONCURRENT_STREAMS: usize = 32; -struct Endpoint { +pub struct Endpoint { ep: QuinnEndpoint, + server: ServerAddr, + token: Vec, + udp_relay_mode: UdpRelayMode, + zero_rtt_handshake: bool, + timeout: Duration, + heartbeat: Duration, } impl Endpoint { - fn new() -> Result { - let ep = QuinnEndpoint::client(SocketAddr::from(([0, 0, 0, 0], 0)))?; - Ok(Self { ep }) + pub fn set_config(cfg: Relay) -> Result<(), Error> { + let ep = todo!(); + + let ep = Self { + ep, + server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip), + token: cfg.token.into_bytes(), + udp_relay_mode: cfg.udp_relay_mode, + zero_rtt_handshake: cfg.zero_rtt_handshake, + timeout: cfg.timeout, + heartbeat: cfg.heartbeat, + }; + + ENDPOINT + .set(Mutex::new(ep)) + .map_err(|_| "endpoint already initialized") + .unwrap(); + + Ok(()) } async fn connect(&self) -> Result { @@ -75,8 +102,9 @@ impl Connection { pub async fn get() -> Result { let try_init_conn = async { ENDPOINT - .get_or_try_init(|| Endpoint::new().map(Mutex::new)) - .map(|ep| ep.lock())? + .get() + .unwrap() + .lock() .connect() .await .map(AsyncMutex::new) @@ -170,7 +198,7 @@ impl Connection { } Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), }; - socks5::recv_pkt(pkt, addr, assoc_id).await + Socks5Server::recv_pkt(pkt, addr, assoc_id).await } Ok(None) => Ok(()), Err(err) => Err(Error::from(err)), @@ -208,7 +236,7 @@ impl Connection { } Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), }; - socks5::recv_pkt(pkt, addr, assoc_id).await + Socks5Server::recv_pkt(pkt, addr, assoc_id).await } Ok(None) => Ok(()), Err(err) => Err(Error::from(err)), diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index 20cc0b6..e1d83e5 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -1,14 +1,20 @@ -use self::config::{Config, ConfigError}; +use socks5::Server; + +use self::{ + config::{Config, ConfigError}, + connection::Endpoint, +}; use std::{env, process}; mod config; mod connection; mod error; mod socks5; +mod utils; #[tokio::main] async fn main() { - let _cfg = match Config::parse(env::args_os()) { + let cfg = match Config::parse(env::args_os()) { Ok(cfg) => cfg, Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { println!("{msg}"); @@ -20,8 +26,21 @@ async fn main() { } }; - if let Err(err) = socks5::start().await { - eprintln!("{err}"); - process::exit(1); + match Endpoint::set_config(cfg.relay) { + Ok(()) => {} + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } } + + match Server::set_config(cfg.local).await { + Ok(()) => {} + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + } + + Server::start().await; } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 70ed148..00af2b5 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -1,12 +1,12 @@ -use crate::{connection::Connection as TuicConnection, error::Error}; +use crate::{config::Local, connection::Connection as TuicConnection, error::Error}; use bytes::Bytes; -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; use parking_lot::Mutex; use socks5_proto::{Address, Reply}; use socks5_server::{ auth::NoAuth, connection::{associate, bind, connect}, - Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server, + Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server as Socks5Server, }; use std::{ collections::HashMap, @@ -24,183 +24,214 @@ use tokio::{ use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address as TuicAddress; +static SERVER: OnceCell = OnceCell::new(); static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0); static UDP_SESSIONS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); -pub async fn start() -> Result<(), Error> { - let server = Server::bind("127.0.0.1:5000", Arc::new(NoAuth)).await?; +pub struct Server { + inner: Socks5Server, + dual_stack: Option, + max_packet_size: usize, +} - while let Ok((conn, _)) = server.accept().await { - tokio::spawn(async move { - let res = match conn.handshake().await { - Ok(Connection::Associate(associate, addr)) => { - handle_associate(associate, addr).await +impl Server { + pub async fn set_config(cfg: Local) -> Result<(), Error> { + let server = Socks5Server::bind(cfg.server, Arc::new(NoAuth)).await?; + + let server = Self { + inner: server, + dual_stack: cfg.dual_stack, + max_packet_size: cfg.max_packet_size, + }; + + SERVER + .set(server) + .map_err(|_| "socks5 server already initialized") + .unwrap(); + + Ok(()) + } + + pub async fn start() { + let server = SERVER.get().unwrap(); + + loop { + match server.inner.accept().await { + Ok((conn, _)) => { + tokio::spawn(async move { + let res = match conn.handshake().await { + Ok(Connection::Associate(associate, addr)) => { + Self::handle_associate(associate, addr).await + } + Ok(Connection::Bind(bind, addr)) => Self::handle_bind(bind, addr).await, + Ok(Connection::Connect(connect, addr)) => { + Self::handle_connect(connect, addr).await + } + Err(err) => Err(Error::from(err)), + }; + + match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + } + }); } - Ok(Connection::Bind(bind, addr)) => handle_bind(bind, addr).await, - Ok(Connection::Connect(connect, addr)) => handle_connect(connect, addr).await, - Err(err) => Err(Error::from(err)), - }; - - match res { - Ok(_) => {} Err(err) => eprintln!("{err}"), } - }); - } - - Ok(()) -} - -async fn handle_associate( - assoc: Associate, - _addr: Address, -) -> Result<(), Error> { - let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0))) - .await - .and_then(|socket| { - socket - .local_addr() - .map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr)) - }); - - match assoc_socket { - Ok((assoc_socket, assoc_addr)) => { - let assoc = assoc - .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) - .await?; - send_pkt(assoc, assoc_socket).await - } - Err(err) => { - let mut assoc = assoc - .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - let _ = assoc.shutdown().await; - Err(Error::from(err)) } } -} -async fn handle_bind(bind: Bind, _addr: Address) -> Result<(), Error> { - let mut conn = bind - .reply(Reply::CommandNotSupported, Address::unspecified()) - .await?; - let _ = conn.shutdown().await; - Ok(()) -} - -async fn handle_connect(conn: Connect, addr: Address) -> Result<(), Error> { - let target_addr = match addr { - Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), - Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), - }; - - let relay = match TuicConnection::get().await { - Ok(conn) => conn.connect(target_addr).await, - Err(err) => Err(err), - }; - - match relay { - Ok(relay) => { - let mut relay = relay.compat(); - let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await; - - match conn { - Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await { - Ok(_) => Ok(()), - Err(err) => { - let _ = conn.shutdown().await; - let _ = relay.shutdown().await; - Err(Error::from(err)) - } - }, - Err(err) => { - let _ = relay.shutdown().await; - Err(Error::from(err)) - } - } - } - Err(err) => { - let mut conn = conn - .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - let _ = conn.shutdown().await; - Err(err) - } - } -} - -async fn send_pkt( - mut assoc: Associate, - assoc_socket: Arc, -) -> Result<(), Error> { - let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel); - UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone()); - let mut connected = None; - - async fn accept_pkt( - assoc_socket: &AssociatedUdpSocket, - connected: &mut Option, - assoc_id: u16, + async fn handle_associate( + assoc: Associate, + _addr: Address, ) -> Result<(), Error> { - let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; + let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0))) + .await + .and_then(|socket| { + socket + .local_addr() + .map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr)) + }); - if let Some(connected) = connected { - if connected != &src_addr { - Err(IoError::new( - ErrorKind::Other, - format!("invalid source address: {src_addr}"), - ))?; + match assoc_socket { + Ok((assoc_socket, assoc_addr)) => { + let assoc = assoc + .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) + .await?; + Self::send_pkt(assoc, assoc_socket).await + } + Err(err) => { + let mut assoc = assoc + .reply(Reply::GeneralFailure, Address::unspecified()) + .await?; + let _ = assoc.shutdown().await; + Err(Error::from(err)) } - } else { - assoc_socket.connect(src_addr).await?; - *connected = Some(src_addr); } + } - if frag != 0 { - Err(IoError::new( - ErrorKind::Other, - format!("fragmented packet is not supported"), - ))?; - } + async fn handle_bind(bind: Bind, _addr: Address) -> Result<(), Error> { + let mut conn = bind + .reply(Reply::CommandNotSupported, Address::unspecified()) + .await?; + let _ = conn.shutdown().await; + Ok(()) + } - let target_addr = match dst_addr { + async fn handle_connect(conn: Connect, addr: Address) -> Result<(), Error> { + let target_addr = match addr { Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), }; - TuicConnection::get() - .await? - .packet(pkt, target_addr, assoc_id) - .await - } + let relay = match TuicConnection::get().await { + Ok(conn) => conn.connect(target_addr).await, + Err(err) => Err(err), + }; - let res = tokio::select! { - res = assoc.wait_until_closed() => res, - _ = async { loop { - if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { - eprintln!("{err}"); + match relay { + Ok(relay) => { + let mut relay = relay.compat(); + let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await; + + match conn { + Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await { + Ok(_) => Ok(()), + Err(err) => { + let _ = conn.shutdown().await; + let _ = relay.shutdown().await; + Err(Error::from(err)) + } + }, + Err(err) => { + let _ = relay.shutdown().await; + Err(Error::from(err)) + } + } } - }} => unreachable!(), - }; - - let _ = assoc.shutdown().await; - UDP_SESSIONS.lock().remove(&assoc_id); - - match TuicConnection::get().await { - Ok(conn) => match conn.dissociate(assoc_id).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), - }, - Err(err) => eprintln!("{err}"), + Err(err) => { + let mut conn = conn + .reply(Reply::GeneralFailure, Address::unspecified()) + .await?; + let _ = conn.shutdown().await; + Err(err) + } + } } - Ok(res?) -} + async fn send_pkt( + mut assoc: Associate, + assoc_socket: Arc, + ) -> Result<(), Error> { + let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel); + UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone()); + let mut connected = None; -pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { - let sessions = UDP_SESSIONS.lock(); - let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; - assoc_socket.send(pkt, 0, addr).await?; - Ok(()) + async fn accept_pkt( + assoc_socket: &AssociatedUdpSocket, + connected: &mut Option, + assoc_id: u16, + ) -> Result<(), Error> { + let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; + + if let Some(connected) = connected { + if connected != &src_addr { + Err(IoError::new( + ErrorKind::Other, + format!("invalid source address: {src_addr}"), + ))?; + } + } else { + assoc_socket.connect(src_addr).await?; + *connected = Some(src_addr); + } + + if frag != 0 { + Err(IoError::new( + ErrorKind::Other, + format!("fragmented packet is not supported"), + ))?; + } + + let target_addr = match dst_addr { + Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), + Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), + }; + + TuicConnection::get() + .await? + .packet(pkt, target_addr, assoc_id) + .await + } + + let res = tokio::select! { + res = assoc.wait_until_closed() => res, + _ = async { loop { + if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { + eprintln!("{err}"); + } + }} => unreachable!(), + }; + + let _ = assoc.shutdown().await; + UDP_SESSIONS.lock().remove(&assoc_id); + + match TuicConnection::get().await { + Ok(conn) => match conn.dissociate(assoc_id).await { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + Err(err) => eprintln!("{err}"), + } + + Ok(res?) + } + + pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { + let sessions = UDP_SESSIONS.lock(); + let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; + assoc_socket.send(pkt, 0, addr).await?; + Ok(()) + } } diff --git a/tuic-client/src/utils.rs b/tuic-client/src/utils.rs new file mode 100644 index 0000000..30ee389 --- /dev/null +++ b/tuic-client/src/utils.rs @@ -0,0 +1,65 @@ +use serde::{de::Error as DeError, Deserialize, Deserializer}; +use std::{fmt::Display, net::IpAddr, str::FromStr}; + +pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + ::Err: Display, + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + T::from_str(&s).map_err(DeError::custom) +} + +pub struct ServerAddr { + domain: String, + port: u16, + ip: Option, +} + +impl ServerAddr { + pub fn new(domain: String, port: u16, ip: Option) -> Self { + Self { domain, port, ip } + } +} + +pub enum UdpRelayMode { + Native, + Quic, +} + +impl FromStr for UdpRelayMode { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if s.eq_ignore_ascii_case("native") { + Ok(Self::Native) + } else if s.eq_ignore_ascii_case("quic") { + Ok(Self::Quic) + } else { + Err("invalid UDP relay mode") + } + } +} + +pub enum CongestionControl { + Cubic, + NewReno, + Bbr, +} + +impl FromStr for CongestionControl { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if s.eq_ignore_ascii_case("cubic") { + Ok(Self::Cubic) + } else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") { + Ok(Self::NewReno) + } else if s.eq_ignore_ascii_case("bbr") { + Ok(Self::Bbr) + } else { + Err("invalid congestion control") + } + } +} From 7d395ca8259ecb58b40ec7fb93b52aba896e3873 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 20:44:47 +0900 Subject: [PATCH 083/103] reading client config from socks5 server --- tuic-client/Cargo.toml | 1 + tuic-client/src/error.rs | 2 ++ tuic-client/src/main.rs | 2 +- tuic-client/src/socks5.rs | 72 +++++++++++++++++++++++++++++++-------- 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index 3af8623..1a614f1 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -12,6 +12,7 @@ quinn = { version = "0.9.3", default-features = false, features = ["futures-io", register-count = { version = "0.1.0", default-features = false, features = ["std"] } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] } +socket2 = { version = "0.4.7", default-features = false } socks5-proto = { version = "0.3.3", default-features = false } socks5-server = { version = "0.8.3", default-features = false } thiserror = { version = "1.0.38", default-features = false } diff --git a/tuic-client/src/error.rs b/tuic-client/src/error.rs index c250262..74fa0c5 100644 --- a/tuic-client/src/error.rs +++ b/tuic-client/src/error.rs @@ -15,4 +15,6 @@ pub enum Error { Model(#[from] ModelError), #[error("timeout")] Timeout, + #[error("invalid authentication")] + InvalidAuth, } diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index e1d83e5..cc06f30 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -34,7 +34,7 @@ async fn main() { } } - match Server::set_config(cfg.local).await { + match Server::set_config(cfg.local) { Ok(()) => {} Err(err) => { eprintln!("{err}"); diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 00af2b5..bc82251 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -2,16 +2,17 @@ use crate::{config::Local, connection::Connection as TuicConnection, error::Erro use bytes::Bytes; use once_cell::sync::{Lazy, OnceCell}; use parking_lot::Mutex; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use socks5_proto::{Address, Reply}; use socks5_server::{ - auth::NoAuth, + auth::{NoAuth, Password}, connection::{associate, bind, connect}, - Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server as Socks5Server, + Associate, AssociatedUdpSocket, Auth, Bind, Connect, Connection, Server as Socks5Server, }; use std::{ collections::HashMap, io::{Error as IoError, ErrorKind}, - net::SocketAddr, + net::{IpAddr, SocketAddr, TcpListener as StdTcpListener, UdpSocket as StdUdpSocket}, sync::{ atomic::{AtomicU16, Ordering}, Arc, @@ -19,7 +20,7 @@ use std::{ }; use tokio::{ io::{self, AsyncWriteExt}, - net::UdpSocket, + net::{TcpListener, UdpSocket}, }; use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address as TuicAddress; @@ -31,16 +32,41 @@ static UDP_SESSIONS: Lazy>>> = pub struct Server { inner: Socks5Server, + addr: SocketAddr, dual_stack: Option, max_packet_size: usize, } impl Server { - pub async fn set_config(cfg: Local) -> Result<(), Error> { - let server = Socks5Server::bind(cfg.server, Arc::new(NoAuth)).await?; + pub fn set_config(cfg: Local) -> Result<(), Error> { + let socket = { + let domain = match cfg.server.ip() { + IpAddr::V4(_) => Domain::IPV4, + IpAddr::V6(_) => Domain::IPV6, + }; + + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + + if let Some(dual_stack) = cfg.dual_stack { + socket.set_only_v6(!dual_stack)?; + } + + socket.set_reuse_address(true)?; + socket.bind(&SockAddr::from(cfg.server))?; + TcpListener::from_std(StdTcpListener::from(socket))? + }; + + let auth: Arc = match (cfg.username, cfg.password) { + (Some(username), Some(password)) => { + Arc::new(Password::new(username.into_bytes(), password.into_bytes())) + } + (None, None) => Arc::new(NoAuth), + _ => return Err(Error::InvalidAuth), + }; let server = Self { - inner: server, + inner: Socks5Server::new(socket, auth), + addr: cfg.server, dual_stack: cfg.dual_stack, max_packet_size: cfg.max_packet_size, }; @@ -86,15 +112,31 @@ impl Server { assoc: Associate, _addr: Address, ) -> Result<(), Error> { - let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0))) - .await - .and_then(|socket| { - socket - .local_addr() - .map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr)) - }); + async fn get_assoc_socket() -> Result<(Arc, SocketAddr), IoError> { + let domain = match SERVER.get().unwrap().addr.ip() { + IpAddr::V4(_) => Domain::IPV4, + IpAddr::V6(_) => Domain::IPV6, + }; - match assoc_socket { + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + if let Some(dual_stack) = SERVER.get().unwrap().dual_stack { + socket.set_only_v6(!dual_stack)?; + } + + socket.set_reuse_address(true)?; + socket.bind(&SockAddr::from(SERVER.get().unwrap().addr))?; + + let socket = AssociatedUdpSocket::from(( + UdpSocket::from_std(StdUdpSocket::from(socket))?, + SERVER.get().unwrap().max_packet_size, + )); + + let addr = socket.local_addr()?; + Ok((Arc::new(socket), addr)) + } + + match get_assoc_socket().await { Ok((assoc_socket, assoc_addr)) => { let assoc = assoc .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) From 61477f5094cb45f72b672bb6a5a0ea5ff9af3105 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 20:55:13 +0900 Subject: [PATCH 084/103] moving `next_assoc_id`&`udp_sessions`into`Server` --- tuic-client/src/socks5.rs | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index bc82251..37d31aa 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -1,6 +1,6 @@ use crate::{config::Local, connection::Connection as TuicConnection, error::Error}; use bytes::Bytes; -use once_cell::sync::{Lazy, OnceCell}; +use once_cell::sync::OnceCell; use parking_lot::Mutex; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use socks5_proto::{Address, Reply}; @@ -26,15 +26,14 @@ use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address as TuicAddress; static SERVER: OnceCell = OnceCell::new(); -static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0); -static UDP_SESSIONS: Lazy>>> = - Lazy::new(|| Mutex::new(HashMap::new())); pub struct Server { inner: Socks5Server, addr: SocketAddr, dual_stack: Option, - max_packet_size: usize, + max_pkt_size: usize, + next_assoc_id: AtomicU16, + udp_sessions: Mutex>>, } impl Server { @@ -68,7 +67,9 @@ impl Server { inner: Socks5Server::new(socket, auth), addr: cfg.server, dual_stack: cfg.dual_stack, - max_packet_size: cfg.max_packet_size, + max_pkt_size: cfg.max_packet_size, + next_assoc_id: AtomicU16::new(0), + udp_sessions: Mutex::new(HashMap::new()), }; SERVER @@ -129,7 +130,7 @@ impl Server { let socket = AssociatedUdpSocket::from(( UdpSocket::from_std(StdUdpSocket::from(socket))?, - SERVER.get().unwrap().max_packet_size, + SERVER.get().unwrap().max_pkt_size, )); let addr = socket.local_addr()?; @@ -206,8 +207,19 @@ impl Server { mut assoc: Associate, assoc_socket: Arc, ) -> Result<(), Error> { - let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel); - UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone()); + let assoc_id = SERVER + .get() + .unwrap() + .next_assoc_id + .fetch_add(1, Ordering::AcqRel); + + SERVER + .get() + .unwrap() + .udp_sessions + .lock() + .insert(assoc_id, assoc_socket.clone()); + let mut connected = None; async fn accept_pkt( @@ -257,7 +269,7 @@ impl Server { }; let _ = assoc.shutdown().await; - UDP_SESSIONS.lock().remove(&assoc_id); + SERVER.get().unwrap().udp_sessions.lock().remove(&assoc_id); match TuicConnection::get().await { Ok(conn) => match conn.dissociate(assoc_id).await { @@ -271,7 +283,7 @@ impl Server { } pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { - let sessions = UDP_SESSIONS.lock(); + let sessions = SERVER.get().unwrap().udp_sessions.lock(); let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; assoc_socket.send(pkt, 0, addr).await?; Ok(()) From 0070112e23bc48b9e4429322fbcbe3554d54f2be Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 21:02:52 +0900 Subject: [PATCH 085/103] changing authentication token length to 32 --- tuic-client/src/connection.rs | 2 +- tuic-quinn/src/lib.rs | 4 ++-- tuic/src/model/authenticate.rs | 8 ++++---- tuic/src/model/mod.rs | 2 +- tuic/src/protocol/authenticate.rs | 12 ++++++------ tuic/src/unmarshal.rs | 4 ++-- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 7d14e8c..84bf348 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -251,7 +251,7 @@ impl Connection { } async fn authenticate(self) { - match self.model.authenticate([0; 8]).await { + match self.model.authenticate([0; 32]).await { Ok(()) => {} Err(err) => eprintln!("{err}"), } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 8c94025..29057ba 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -102,7 +102,7 @@ impl Connection { } } - pub async fn authenticate(&self, token: [u8; 8]) -> Result<(), Error> { + pub async fn authenticate(&self, token: [u8; 32]) -> Result<(), Error> { let model = self.model.send_authenticate(token); let mut send = self.conn.open_uni().await?; model.header().async_marshal(&mut send).await?; @@ -388,7 +388,7 @@ impl Packet { #[non_exhaustive] pub enum Task { - Authenticate([u8; 8]), + Authenticate([u8; 32]), Connect(Connect), Packet(Packet), Dissociate(u16), diff --git a/tuic/src/model/authenticate.rs b/tuic/src/model/authenticate.rs index e02940d..9dcb965 100644 --- a/tuic/src/model/authenticate.rs +++ b/tuic/src/model/authenticate.rs @@ -11,7 +11,7 @@ pub struct Tx { } impl Authenticate { - pub(super) fn new(token: [u8; 8]) -> Self { + pub(super) fn new(token: [u8; 32]) -> Self { Self { inner: Side::Tx(Tx { header: Header::Authenticate(AuthenticateHeader::new(token)), @@ -27,18 +27,18 @@ impl Authenticate { } pub struct Rx { - token: [u8; 8], + token: [u8; 32], } impl Authenticate { - pub(super) fn new(token: [u8; 8]) -> Self { + pub(super) fn new(token: [u8; 32]) -> Self { Self { inner: Side::Rx(Rx { token }), _marker: side::Rx, } } - pub fn token(&self) -> [u8; 8] { + pub fn token(&self) -> [u8; 32] { let Side::Rx(rx) = &self.inner else { unreachable!() }; rx.token } diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index 5ce3c71..9e66fe1 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -50,7 +50,7 @@ where } } - pub fn send_authenticate(&self, token: [u8; 8]) -> Authenticate { + pub fn send_authenticate(&self, token: [u8; 32]) -> Authenticate { Authenticate::::new(token) } diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index b17f791..9879241 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -1,21 +1,21 @@ // +-------+ // | TOKEN | // +-------+ -// | 8 | +// | 32 | // +-------+ #[derive(Clone, Debug)] pub struct Authenticate { - token: [u8; 8], + token: [u8; 32], } impl Authenticate { const TYPE_CODE: u8 = 0x00; - pub const fn new(token: [u8; 8]) -> Self { + pub const fn new(token: [u8; 32]) -> Self { Self { token } } - pub fn token(&self) -> [u8; 8] { + pub fn token(&self) -> [u8; 32] { self.token } @@ -24,11 +24,11 @@ impl Authenticate { } pub fn len(&self) -> usize { - 8 + 32 } } -impl From for ([u8; 8],) { +impl From for ([u8; 32],) { fn from(auth: Authenticate) -> Self { (auth.token,) } diff --git a/tuic/src/unmarshal.rs b/tuic/src/unmarshal.rs index 7b4b48a..ca037c2 100644 --- a/tuic/src/unmarshal.rs +++ b/tuic/src/unmarshal.rs @@ -164,14 +164,14 @@ impl Address { impl Authenticate { #[cfg(feature = "async_marshal")] async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { - let mut buf = [0; 8]; + let mut buf = [0; 32]; s.read_exact(&mut buf).await?; Ok(Self::new(buf)) } #[cfg(feature = "marshal")] fn read(s: &mut impl Read) -> Result { - let mut buf = [0; 8]; + let mut buf = [0; 32]; s.read_exact(&mut buf)?; Ok(Self::new(buf)) } From 9806c62fe70add6b286f40dd0ae04949acac42a9 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 23:42:53 +0900 Subject: [PATCH 086/103] reading client config to endpoint --- tuic-client/Cargo.toml | 5 ++ tuic-client/src/config.rs | 113 ++++++++++++++---------- tuic-client/src/connection.rs | 157 ++++++++++++++++++++++++++++------ tuic-client/src/error.rs | 11 ++- tuic-client/src/socks5.rs | 2 +- tuic-client/src/utils.rs | 62 +++++++++++--- 6 files changed, 265 insertions(+), 85 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index 1a614f1..b291e7f 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -5,11 +5,15 @@ edition = "2021" [dependencies] bytes = { version = "1.3.0", default-features = false, features = ["std"] } +crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } lexopt = { version = "0.3.0", default-features = false } once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } register-count = { version = "0.1.0", default-features = false, features = ["std"] } +rustls = { version = "0.20.8", default-features = false, features = ["quic"] } +rustls-native-certs = { version = "0.6.2", default-features = false } +rustls-pemfile = { version = "1.0.2", default-features = false } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] } socket2 = { version = "0.4.7", default-features = false } @@ -20,3 +24,4 @@ tokio = { version = "1.24.2", default-features = false, features = ["macros", "n tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false } +webpki = { version = "0.22.0", default-features = false } diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs index 81f35db..eb06b28 100644 --- a/tuic-client/src/config.rs +++ b/tuic-client/src/config.rs @@ -1,12 +1,14 @@ -use crate::utils::{self, CongestionControl, UdpRelayMode}; +use crate::utils::{CongestionControl, UdpRelayMode}; use lexopt::{Arg, Error as ArgumentError, Parser}; use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as SerdeError; use std::{ env::ArgsOs, + fmt::Display, fs::File, io::Error as IoError, net::{IpAddr, SocketAddr}, + str::FromStr, time::Duration, }; use thiserror::Error; @@ -38,22 +40,26 @@ pub struct Relay { pub certificates: Vec, #[serde( default = "default::relay::udp_relay_mode", - deserialize_with = "utils::deserialize_from_str" + deserialize_with = "deserialize_from_str" )] pub udp_relay_mode: UdpRelayMode, #[serde( default = "default::relay::congestion_control", - deserialize_with = "utils::deserialize_from_str" + deserialize_with = "deserialize_from_str" )] pub congestion_control: CongestionControl, #[serde(default = "default::relay::alpn")] pub alpn: Vec, #[serde(default = "default::relay::zero_rtt_handshake")] pub zero_rtt_handshake: bool, + #[serde(default = "default::relay::disable_sni")] + pub disable_sni: bool, #[serde(default = "default::relay::timeout")] pub timeout: Duration, #[serde(default = "default::relay::heartbeat")] pub heartbeat: Duration, + #[serde(default = "default::relay::disable_native_certificates")] + pub disable_native_certificates: bool, } #[derive(Deserialize)] @@ -67,47 +73,6 @@ pub struct Local { pub max_packet_size: usize, } -mod default { - pub mod relay { - use crate::utils::{CongestionControl, UdpRelayMode}; - use std::time::Duration; - - pub const fn certificates() -> Vec { - Vec::new() - } - - pub const fn udp_relay_mode() -> UdpRelayMode { - UdpRelayMode::Native - } - - pub const fn congestion_control() -> CongestionControl { - CongestionControl::Cubic - } - - pub const fn alpn() -> Vec { - Vec::new() - } - - pub const fn zero_rtt_handshake() -> bool { - false - } - - pub const fn timeout() -> Duration { - Duration::from_secs(8) - } - - pub const fn heartbeat() -> Duration { - Duration::from_secs(3) - } - } - - pub mod local { - pub const fn max_packet_size() -> usize { - 1500 - } - } -} - impl Config { pub fn parse(args: ArgsOs) -> Result { let mut parser = Parser::from_iter(args); @@ -135,11 +100,69 @@ impl Config { } let file = File::open(path.unwrap())?; - Ok(serde_json::from_reader(file)?) } } +mod default { + pub mod relay { + use crate::utils::{CongestionControl, UdpRelayMode}; + use std::time::Duration; + + pub fn certificates() -> Vec { + Vec::new() + } + + pub fn udp_relay_mode() -> UdpRelayMode { + UdpRelayMode::Native + } + + pub fn congestion_control() -> CongestionControl { + CongestionControl::Cubic + } + + pub fn alpn() -> Vec { + Vec::new() + } + + pub fn zero_rtt_handshake() -> bool { + false + } + + pub fn disable_sni() -> bool { + false + } + + pub fn timeout() -> Duration { + Duration::from_secs(8) + } + + pub fn heartbeat() -> Duration { + Duration::from_secs(3) + } + + pub fn disable_native_certificates() -> bool { + false + } + } + + pub mod local { + pub fn max_packet_size() -> usize { + 1500 + } + } +} + +pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + ::Err: Display, + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + T::from_str(&s).map_err(DeError::custom) +} + pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error> where D: Deserializer<'de>, diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 84bf348..2a62e6c 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -2,18 +2,22 @@ use crate::{ config::Relay, error::Error, socks5::Server as Socks5Server, - utils::{ServerAddr, UdpRelayMode}, + utils::{self, CongestionControl, ServerAddr, UdpRelayMode}, }; use bytes::Bytes; +use crossbeam_utils::atomic::AtomicCell; use once_cell::sync::OnceCell; use parking_lot::Mutex; use quinn::{ - Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt, + congestion::{BbrConfig, CubicConfig, NewRenoConfig}, + ClientConfig, Connection as QuinnConnection, Endpoint as QuinnEndpoint, EndpointConfig, + RecvStream, SendStream, TokioRuntime, TransportConfig, VarInt, }; use register_count::{Counter, Register}; +use rustls::{version, ClientConfig as RustlsClientConfig}; use socks5_proto::Address as Socks5Address; use std::{ - net::SocketAddr, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -29,30 +33,67 @@ use tuic_quinn::{side, Connect, Connection as Model, Task}; static ENDPOINT: OnceCell> = OnceCell::new(); static CONNECTION: AsyncOnceCell> = AsyncOnceCell::const_new(); +static TIMEOUT: AtomicCell = AtomicCell::new(Duration::from_secs(0)); const DEFAULT_CONCURRENT_STREAMS: usize = 32; pub struct Endpoint { ep: QuinnEndpoint, server: ServerAddr, - token: Vec, + token: Arc<[u8]>, udp_relay_mode: UdpRelayMode, zero_rtt_handshake: bool, - timeout: Duration, heartbeat: Duration, } impl Endpoint { pub fn set_config(cfg: Relay) -> Result<(), Error> { - let ep = todo!(); + let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certificates)?; + + let mut crypto = RustlsClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_root_certificates(certs) + .with_no_client_auth(); + + crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect(); + crypto.enable_early_data = true; + crypto.enable_sni = !cfg.disable_sni; + + let mut config = ClientConfig::new(Arc::new(crypto)); + let mut tp_cfg = TransportConfig::default(); + + tp_cfg + .max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_idle_timeout(None); + + match cfg.congestion_control { + CongestionControl::Cubic => { + tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default())) + } + CongestionControl::NewReno => { + tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default())) + } + CongestionControl::Bbr => { + tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default())) + } + }; + + config.transport_config(Arc::new(tp_cfg)); + + let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0)))?; + let mut ep = QuinnEndpoint::new(EndpointConfig::default(), None, socket, TokioRuntime)?; + ep.set_default_client_config(config); let ep = Self { ep, server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip), - token: cfg.token.into_bytes(), + token: Arc::from(cfg.token.into_bytes().into_boxed_slice()), udp_relay_mode: cfg.udp_relay_mode, zero_rtt_handshake: cfg.zero_rtt_handshake, - timeout: cfg.timeout, heartbeat: cfg.heartbeat, }; @@ -61,19 +102,67 @@ impl Endpoint { .map_err(|_| "endpoint already initialized") .unwrap(); + TIMEOUT.store(cfg.timeout); + Ok(()) } - async fn connect(&self) -> Result { - let conn = self - .ep - .connect(SocketAddr::from(([127, 0, 0, 1], 8080)), "localhost")? + async fn connect(&mut self) -> Result { + async fn connect_to( + ep: &mut QuinnEndpoint, + addr: SocketAddr, + server_name: &str, + udp_relay_mode: UdpRelayMode, + zero_rtt_handshake: bool, + ) -> Result { + let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4()); + let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6()); + + if !match_ipv4 && !match_ipv6 { + let bind_addr = if addr.is_ipv4() { + SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) + } else { + SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) + }; + + ep.rebind(UdpSocket::bind(bind_addr)?)?; + } + + let conn = ep.connect(addr, server_name)?; + + let conn = if zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => conn.await?, + } + } else { + conn.await? + }; + + Ok(Connection::new(conn, udp_relay_mode)) + } + + let mut last_err = None; + + for addr in self.server.resolve().await? { + match connect_to( + &mut self.ep, + addr, + self.server.server_name(), + self.udp_relay_mode, + self.zero_rtt_handshake, + ) .await - .map(Connection::new)?; + { + Ok(conn) => { + tokio::spawn(conn.clone().init(self.token.clone(), self.heartbeat)); + return Ok(conn); + } + Err(err) => last_err = Some(err), + } + } - tokio::spawn(conn.clone().init()); - - Ok(conn) + Err(last_err.unwrap_or(Error::DnsResolve)) } } @@ -81,6 +170,7 @@ impl Endpoint { pub struct Connection { conn: QuinnConnection, model: Model, + udp_relay_mode: UdpRelayMode, remote_uni_stream_cnt: Counter, remote_bi_stream_cnt: Counter, max_concurrent_uni_streams: Arc, @@ -88,10 +178,11 @@ pub struct Connection { } impl Connection { - fn new(conn: QuinnConnection) -> Self { + fn new(conn: QuinnConnection, udp_relay_mode: UdpRelayMode) -> Self { Self { conn: conn.clone(), model: Model::::new(conn), + udp_relay_mode, remote_uni_stream_cnt: Counter::new(), remote_bi_stream_cnt: Counter::new(), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), @@ -125,7 +216,7 @@ impl Connection { Ok::<_, Error>(conn.clone()) }; - let conn = time::timeout(Duration::from_secs(5), try_get_conn) + let conn = time::timeout(TIMEOUT.load(), try_get_conn) .await .map_err(|_| Error::Timeout)??; @@ -137,7 +228,11 @@ impl Connection { } pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { - self.model.packet_quic(pkt, addr, assoc_id).await?; // TODO + match self.udp_relay_mode { + UdpRelayMode::Native => self.model.packet_native(pkt, addr, assoc_id)?, + UdpRelayMode::Quic => self.model.packet_quic(pkt, addr, assoc_id).await?, + } + Ok(()) } @@ -250,16 +345,26 @@ impl Connection { } } - async fn authenticate(self) { - match self.model.authenticate([0; 32]).await { + async fn authenticate(self, token: Arc<[u8]>) { + let mut buf = [0; 32]; + + match self.conn.export_keying_material(&mut buf, &token, &token) { + Ok(()) => {} + Err(_) => { + eprintln!("token length too short"); + return; + } + } + + match self.model.authenticate(buf).await { Ok(()) => {} Err(err) => eprintln!("{err}"), } } - async fn heartbeat(self) { + async fn heartbeat(self, heartbeat: Duration) { loop { - time::sleep(Duration::from_secs(5)).await; + time::sleep(heartbeat).await; if self.is_closed() { break; @@ -272,9 +377,9 @@ impl Connection { } } - async fn init(self) { - tokio::spawn(self.clone().authenticate()); - tokio::spawn(self.clone().heartbeat()); + async fn init(self, token: Arc<[u8]>, heartbeat: Duration) { + tokio::spawn(self.clone().authenticate(token)); + tokio::spawn(self.clone().heartbeat(heartbeat)); let err = loop { tokio::select! { diff --git a/tuic-client/src/error.rs b/tuic-client/src/error.rs index 74fa0c5..8988b01 100644 --- a/tuic-client/src/error.rs +++ b/tuic-client/src/error.rs @@ -2,6 +2,7 @@ use quinn::{ConnectError, ConnectionError}; use std::io::Error as IoError; use thiserror::Error; use tuic_quinn::Error as ModelError; +use webpki::Error as WebpkiError; #[derive(Debug, Error)] pub enum Error { @@ -13,8 +14,12 @@ pub enum Error { Connection(#[from] ConnectionError), #[error(transparent)] Model(#[from] ModelError), - #[error("timeout")] + #[error(transparent)] + Webpki(#[from] WebpkiError), + #[error("timeout establishing connection")] Timeout, - #[error("invalid authentication")] - InvalidAuth, + #[error("cannot resolve the server name")] + DnsResolve, + #[error("invalid socks5 authentication")] + InvalidSocks5Auth, } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 37d31aa..d44b58a 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -60,7 +60,7 @@ impl Server { Arc::new(Password::new(username.into_bytes(), password.into_bytes())) } (None, None) => Arc::new(NoAuth), - _ => return Err(Error::InvalidAuth), + _ => return Err(Error::InvalidSocks5Auth), }; let server = Self { diff --git a/tuic-client/src/utils.rs b/tuic-client/src/utils.rs index 30ee389..afa7695 100644 --- a/tuic-client/src/utils.rs +++ b/tuic-client/src/utils.rs @@ -1,14 +1,40 @@ -use serde::{de::Error as DeError, Deserialize, Deserializer}; -use std::{fmt::Display, net::IpAddr, str::FromStr}; +use crate::error::Error; +use rustls::{Certificate, RootCertStore}; +use rustls_pemfile::Item; +use std::{ + fs::{self, File}, + io::BufReader, + net::{IpAddr, SocketAddr}, + str::FromStr, +}; +use tokio::net; -pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result -where - T: FromStr, - ::Err: Display, - D: Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - T::from_str(&s).map_err(DeError::custom) +pub fn load_certs(paths: Vec, disable_native: bool) -> Result { + let mut certs = RootCertStore::empty(); + + for path in &paths { + let mut file = BufReader::new(File::open(path)?); + + while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { + if let Item::X509Certificate(cert) = item { + certs.add(&Certificate(cert))?; + } + } + } + + if certs.is_empty() { + for path in &paths { + certs.add(&Certificate(fs::read(path)?))?; + } + } + + if !disable_native { + for cert in rustls_native_certs::load_native_certs()? { + certs.add(&Certificate(cert.0))?; + } + } + + Ok(certs) } pub struct ServerAddr { @@ -21,8 +47,24 @@ impl ServerAddr { pub fn new(domain: String, port: u16, ip: Option) -> Self { Self { domain, port, ip } } + + pub fn server_name(&self) -> &str { + &self.domain + } + + pub async fn resolve(&self) -> Result, Error> { + if let Some(ip) = self.ip { + Ok(vec![SocketAddr::from((ip, self.port))].into_iter()) + } else { + Ok(net::lookup_host((self.domain.as_str(), self.port)) + .await? + .collect::>() + .into_iter()) + } + } } +#[derive(Clone, Copy)] pub enum UdpRelayMode { Native, Quic, From 2eeaec66d42ca5afef1a4872c4c9b26dbaa7748f Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 3 Feb 2023 00:48:23 +0900 Subject: [PATCH 087/103] filtering packet source --- tuic-client/src/connection.rs | 58 +++++++++++++++++++---------------- tuic-client/src/error.rs | 2 ++ 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 2a62e6c..38787d7 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -284,19 +284,22 @@ impl Connection { async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { let res = match self.model.accept_uni_stream(recv).await { Err(err) => Err(Error::from(err)), - Ok(Task::Packet(pkt)) => match pkt.accept().await { - Ok(Some((pkt, addr, assoc_id))) => { - let addr = match addr { - Address::None => unreachable!(), - Address::DomainAddress(domain, port) => { - Socks5Address::DomainAddress(domain, port) - } - Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), - }; - Socks5Server::recv_pkt(pkt, addr, assoc_id).await - } - Ok(None) => Ok(()), - Err(err) => Err(Error::from(err)), + Ok(Task::Packet(pkt)) => match self.udp_relay_mode { + UdpRelayMode::Quic => match pkt.accept().await { + Ok(Some((pkt, addr, assoc_id))) => { + let addr = match addr { + Address::None => unreachable!(), + Address::DomainAddress(domain, port) => { + Socks5Address::DomainAddress(domain, port) + } + Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), + }; + Socks5Server::recv_pkt(pkt, addr, assoc_id).await + } + Ok(None) => Ok(()), + Err(err) => Err(Error::from(err)), + }, + UdpRelayMode::Native => Err(Error::WrongPacketSource), }, _ => unreachable!(), }; @@ -322,19 +325,22 @@ impl Connection { async fn handle_datagram(self, dg: Bytes) { let res = match self.model.accept_datagram(dg) { Err(err) => Err(Error::from(err)), - Ok(Task::Packet(pkt)) => match pkt.accept().await { - Ok(Some((pkt, addr, assoc_id))) => { - let addr = match addr { - Address::None => unreachable!(), - Address::DomainAddress(domain, port) => { - Socks5Address::DomainAddress(domain, port) - } - Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), - }; - Socks5Server::recv_pkt(pkt, addr, assoc_id).await - } - Ok(None) => Ok(()), - Err(err) => Err(Error::from(err)), + Ok(Task::Packet(pkt)) => match self.udp_relay_mode { + UdpRelayMode::Native => match pkt.accept().await { + Ok(Some((pkt, addr, assoc_id))) => { + let addr = match addr { + Address::None => unreachable!(), + Address::DomainAddress(domain, port) => { + Socks5Address::DomainAddress(domain, port) + } + Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), + }; + Socks5Server::recv_pkt(pkt, addr, assoc_id).await + } + Ok(None) => Ok(()), + Err(err) => Err(Error::from(err)), + }, + UdpRelayMode::Quic => Err(Error::WrongPacketSource), }, _ => unreachable!(), }; diff --git a/tuic-client/src/error.rs b/tuic-client/src/error.rs index 8988b01..c541e32 100644 --- a/tuic-client/src/error.rs +++ b/tuic-client/src/error.rs @@ -20,6 +20,8 @@ pub enum Error { Timeout, #[error("cannot resolve the server name")] DnsResolve, + #[error("received packet from an unexpected source")] + WrongPacketSource, #[error("invalid socks5 authentication")] InvalidSocks5Auth, } From d08945844b64a4f25f87434aab0f7ee001b81e93 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 3 Feb 2023 00:59:19 +0900 Subject: [PATCH 088/103] adding gc for fragmentary packet --- tuic-client/src/config.rs | 18 +++++++++++++++--- tuic-client/src/connection.rs | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs index eb06b28..8c2a0b6 100644 --- a/tuic-client/src/config.rs +++ b/tuic-client/src/config.rs @@ -58,8 +58,12 @@ pub struct Relay { pub timeout: Duration, #[serde(default = "default::relay::heartbeat")] pub heartbeat: Duration, - #[serde(default = "default::relay::disable_native_certificates")] - pub disable_native_certificates: bool, + #[serde(default = "default::relay::disable_native_certs")] + pub disable_native_certs: bool, + #[serde(default = "default::relay::gc_interval")] + pub gc_interval: Duration, + #[serde(default = "default::relay::gc_lifetime")] + pub gc_lifetime: Duration, } #[derive(Deserialize)] @@ -141,9 +145,17 @@ mod default { Duration::from_secs(3) } - pub fn disable_native_certificates() -> bool { + pub fn disable_native_certs() -> bool { false } + + pub fn gc_interval() -> Duration { + Duration::from_secs(3) + } + + pub fn gc_lifetime() -> Duration { + Duration::from_secs(15) + } } pub mod local { diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 38787d7..dfb7bb3 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -44,11 +44,13 @@ pub struct Endpoint { udp_relay_mode: UdpRelayMode, zero_rtt_handshake: bool, heartbeat: Duration, + gc_interval: Duration, + gc_lifetime: Duration, } impl Endpoint { pub fn set_config(cfg: Relay) -> Result<(), Error> { - let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certificates)?; + let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certs)?; let mut crypto = RustlsClientConfig::builder() .with_safe_default_cipher_suites() @@ -95,6 +97,8 @@ impl Endpoint { udp_relay_mode: cfg.udp_relay_mode, zero_rtt_handshake: cfg.zero_rtt_handshake, heartbeat: cfg.heartbeat, + gc_interval: cfg.gc_interval, + gc_lifetime: cfg.gc_lifetime, }; ENDPOINT @@ -155,7 +159,12 @@ impl Endpoint { .await { Ok(conn) => { - tokio::spawn(conn.clone().init(self.token.clone(), self.heartbeat)); + tokio::spawn(conn.clone().init( + self.token.clone(), + self.heartbeat, + self.gc_interval, + self.gc_lifetime, + )); return Ok(conn); } Err(err) => last_err = Some(err), @@ -383,9 +392,28 @@ impl Connection { } } - async fn init(self, token: Arc<[u8]>, heartbeat: Duration) { + async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) { + loop { + time::sleep(gc_interval).await; + + if self.is_closed() { + break; + } + + self.model.collect_garbage(gc_lifetime); + } + } + + async fn init( + self, + token: Arc<[u8]>, + heartbeat: Duration, + gc_interval: Duration, + gc_lifetime: Duration, + ) { tokio::spawn(self.clone().authenticate(token)); tokio::spawn(self.clone().heartbeat(heartbeat)); + tokio::spawn(self.clone().collect_garbage(gc_interval, gc_lifetime)); let err = loop { tokio::select! { From 443252c2cacb5790fe3dacb206a4d494999671cd Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 3 Feb 2023 15:23:04 +0900 Subject: [PATCH 089/103] moving `Error` to root --- tuic-client/src/connection.rs | 2 +- tuic-client/src/error.rs | 27 --------------------------- tuic-client/src/main.rs | 32 ++++++++++++++++++++++++++++---- tuic-client/src/socks5.rs | 2 +- tuic-client/src/utils.rs | 2 +- 5 files changed, 31 insertions(+), 34 deletions(-) delete mode 100644 tuic-client/src/error.rs diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index dfb7bb3..81f94c6 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -1,8 +1,8 @@ use crate::{ config::Relay, - error::Error, socks5::Server as Socks5Server, utils::{self, CongestionControl, ServerAddr, UdpRelayMode}, + Error, }; use bytes::Bytes; use crossbeam_utils::atomic::AtomicCell; diff --git a/tuic-client/src/error.rs b/tuic-client/src/error.rs deleted file mode 100644 index c541e32..0000000 --- a/tuic-client/src/error.rs +++ /dev/null @@ -1,27 +0,0 @@ -use quinn::{ConnectError, ConnectionError}; -use std::io::Error as IoError; -use thiserror::Error; -use tuic_quinn::Error as ModelError; -use webpki::Error as WebpkiError; - -#[derive(Debug, Error)] -pub enum Error { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Connect(#[from] ConnectError), - #[error(transparent)] - Connection(#[from] ConnectionError), - #[error(transparent)] - Model(#[from] ModelError), - #[error(transparent)] - Webpki(#[from] WebpkiError), - #[error("timeout establishing connection")] - Timeout, - #[error("cannot resolve the server name")] - DnsResolve, - #[error("received packet from an unexpected source")] - WrongPacketSource, - #[error("invalid socks5 authentication")] - InvalidSocks5Auth, -} diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index cc06f30..b1ef3a9 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -1,14 +1,16 @@ -use socks5::Server; - use self::{ config::{Config, ConfigError}, connection::Endpoint, + socks5::Server, }; -use std::{env, process}; +use quinn::{ConnectError, ConnectionError}; +use std::{env, io::Error as IoError, process}; +use thiserror::Error; +use tuic_quinn::Error as ModelError; +use webpki::Error as WebpkiError; mod config; mod connection; -mod error; mod socks5; mod utils; @@ -44,3 +46,25 @@ async fn main() { Server::start().await; } + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Connect(#[from] ConnectError), + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + Model(#[from] ModelError), + #[error(transparent)] + Webpki(#[from] WebpkiError), + #[error("timeout establishing connection")] + Timeout, + #[error("cannot resolve the server name")] + DnsResolve, + #[error("received packet from an unexpected source")] + WrongPacketSource, + #[error("invalid socks5 authentication")] + InvalidSocks5Auth, +} diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index d44b58a..2b88217 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -1,4 +1,4 @@ -use crate::{config::Local, connection::Connection as TuicConnection, error::Error}; +use crate::{config::Local, connection::Connection as TuicConnection, Error}; use bytes::Bytes; use once_cell::sync::OnceCell; use parking_lot::Mutex; diff --git a/tuic-client/src/utils.rs b/tuic-client/src/utils.rs index afa7695..752184a 100644 --- a/tuic-client/src/utils.rs +++ b/tuic-client/src/utils.rs @@ -1,4 +1,4 @@ -use crate::error::Error; +use crate::Error; use rustls::{Certificate, RootCertStore}; use rustls_pemfile::Item; use std::{ From a9ca312726d7e94d7545cea09190b93f24c2d88a Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 3 Feb 2023 17:59:50 +0900 Subject: [PATCH 090/103] re-import `Server` as `Socks5Server` --- tuic-client/src/main.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index b1ef3a9..81b7cb6 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -1,7 +1,7 @@ use self::{ config::{Config, ConfigError}, connection::Endpoint, - socks5::Server, + socks5::Server as Socks5Server, }; use quinn::{ConnectError, ConnectionError}; use std::{env, io::Error as IoError, process}; @@ -36,7 +36,7 @@ async fn main() { } } - match Server::set_config(cfg.local) { + match Socks5Server::set_config(cfg.local) { Ok(()) => {} Err(err) => { eprintln!("{err}"); @@ -44,7 +44,7 @@ async fn main() { } } - Server::start().await; + Socks5Server::start().await; } #[derive(Debug, Error)] From 5dde0278ce316958f4737e9142ee0c3fa2560c82 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 3 Feb 2023 19:49:53 +0900 Subject: [PATCH 091/103] implement server task receiving mechanism --- tuic-server/Cargo.toml | 23 +++- tuic-server/src/config.rs | 77 +++++++++++ tuic-server/src/main.rs | 54 +++++++- tuic-server/src/server.rs | 269 ++++++++++++++++++++++++++++++++++++++ tuic-server/src/utils.rs | 29 ++++ 5 files changed, 448 insertions(+), 4 deletions(-) create mode 100644 tuic-server/src/config.rs create mode 100644 tuic-server/src/server.rs create mode 100644 tuic-server/src/utils.rs diff --git a/tuic-server/Cargo.toml b/tuic-server/Cargo.toml index 23e9215..cb33da6 100644 --- a/tuic-server/Cargo.toml +++ b/tuic-server/Cargo.toml @@ -3,6 +3,25 @@ name = "tuic-server" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +bytes = { version = "1.3.0", default-features = false, features = ["std"] } +crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } +lexopt = { version = "0.3.0", default-features = false } +once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } +parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } +quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } +register-count = { version = "0.1.0", default-features = false, features = ["std"] } +rustls = { version = "0.20.8", default-features = false, features = ["quic"] } +rustls-native-certs = { version = "0.6.2", default-features = false } +rustls-pemfile = { version = "1.0.2", default-features = false } +serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } +serde_json = { version = "1.0.91", default-features = false, features = ["std"] } +socket2 = { version = "0.4.7", default-features = false } +socks5-proto = { version = "0.3.3", default-features = false } +socks5-server = { version = "0.8.3", default-features = false } +thiserror = { version = "1.0.38", default-features = false } +tokio = { version = "1.24.2", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } +tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } +tuic = { path = "../tuic", default-features = false } +tuic-quinn = { path = "../tuic-quinn", default-features = false } +webpki = { version = "0.22.0", default-features = false } \ No newline at end of file diff --git a/tuic-server/src/config.rs b/tuic-server/src/config.rs new file mode 100644 index 0000000..c08805d --- /dev/null +++ b/tuic-server/src/config.rs @@ -0,0 +1,77 @@ +use lexopt::{Arg, Error as ArgumentError, Parser}; +use serde::{de::Error as DeError, Deserialize, Deserializer}; +use serde_json::Error as SerdeError; +use std::{env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, str::FromStr}; +use thiserror::Error; + +const HELP_MSG: &str = r#" +Usage tuic-server [arguments] + +Arguments: + -c, --config Path to the config file (required) + -v, --version Print the version + -h, --help Print this help message +"#; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Config {} + +impl Config { + pub fn parse(args: ArgsOs) -> Result { + let mut parser = Parser::from_iter(args); + let mut path = None; + + while let Some(arg) = parser.next()? { + match arg { + Arg::Short('c') | Arg::Long("config") => { + if path.is_none() { + path = Some(parser.value()?); + } else { + return Err(ConfigError::Argument(arg.unexpected())); + } + } + Arg::Short('v') | Arg::Long("version") => { + return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))) + } + Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)), + _ => return Err(ConfigError::Argument(arg.unexpected())), + } + } + + if path.is_none() { + return Err(ConfigError::NoConfig); + } + + let file = File::open(path.unwrap())?; + Ok(serde_json::from_reader(file)?) + } +} + +mod default {} + +pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + ::Err: Display, + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + T::from_str(&s).map_err(DeError::custom) +} + +#[derive(Debug, Error)] +pub enum ConfigError { + #[error(transparent)] + Argument(#[from] ArgumentError), + #[error("no config file specified")] + NoConfig, + #[error("{0}")] + Version(&'static str), + #[error("{0}")] + Help(&'static str), + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Serde(#[from] SerdeError), +} diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index e7a11a9..d52e01a 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -1,3 +1,53 @@ -fn main() { - println!("Hello, world!"); +use self::{ + config::{Config, ConfigError}, + server::Server, +}; +use quinn::{crypto::ExportKeyingMaterialError, ConnectionError}; +use std::{env, io::Error as IoError, process}; +use thiserror::Error; +use tuic_quinn::Error as ModelError; + +mod config; +mod server; +mod utils; + +#[tokio::main] +async fn main() { + let cfg = match Config::parse(env::args_os()) { + Ok(cfg) => cfg, + Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { + println!("{msg}"); + process::exit(0); + } + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + }; + + match Server::init(cfg) { + Ok(server) => server.start().await, + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + } +} + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + Model(#[from] ModelError), + #[error("duplicated authentication")] + DuplicatedAuth, + #[error("token length too short")] + ExportKeyingMaterial, + #[error("authentication failed")] + AuthFailed, + #[error("received packet from unexpected source")] + UnexpectedPacketSource, } diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs new file mode 100644 index 0000000..9101202 --- /dev/null +++ b/tuic-server/src/server.rs @@ -0,0 +1,269 @@ +use crate::{config::Config, utils::UdpRelayMode, Error}; +use bytes::Bytes; +use crossbeam_utils::atomic::AtomicCell; +use parking_lot::Mutex; +use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; +use std::{ + future::Future, + pin::Pin, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + task::{Context, Poll, Waker}, +}; +use tuic_quinn::{side, Connection as Model, Task}; + +pub struct Server { + ep: Endpoint, + token: Arc<[u8]>, + zero_rtt_handshake: bool, +} + +impl Server { + pub fn init(cfg: Config) -> Result { + todo!() + } + + pub async fn start(&self) { + loop { + let conn = self.ep.accept().await.unwrap(); + tokio::spawn(Connection::init( + conn, + self.token.clone(), + self.zero_rtt_handshake, + )); + } + } +} + +#[derive(Clone)] +struct Connection { + inner: QuinnConnection, + model: Model, + token: Arc<[u8]>, + is_authed: IsAuthed, + udp_relay_mode: Arc>>, +} + +impl Connection { + pub async fn init(conn: Connecting, token: Arc<[u8]>, zero_rtt_handshake: bool) { + match Self::handshake(conn, token, zero_rtt_handshake).await { + Ok(conn) => loop { + if conn.is_closed() { + break; + } + + match conn.accept().await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + }, + Err(err) => eprintln!("{err}"), + } + } + + async fn handshake( + conn: Connecting, + token: Arc<[u8]>, + zero_rtt_handshake: bool, + ) -> Result { + let conn = if zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => { + eprintln!("0-RTT handshake failed, fallback to 1-RTT handshake"); + conn.await? + } + } + } else { + conn.await? + }; + + Ok(Self { + inner: conn.clone(), + model: Model::::new(conn), + token, + is_authed: IsAuthed::new(), + udp_relay_mode: Arc::new(AtomicCell::new(None)), + }) + } + + async fn accept(&self) -> Result<(), Error> { + tokio::select! { + res = self.inner.accept_uni() => tokio::spawn(self.clone().handle_uni_stream(res?)), + res = self.inner.accept_bi() => tokio::spawn(self.clone().handle_bi_stream(res?)), + res = self.inner.read_datagram() => tokio::spawn(self.clone().handle_datagram(res?)), + }; + + Ok(()) + } + + async fn handle_uni_stream(self, recv: RecvStream) { + async fn pre_process(conn: &Connection, recv: RecvStream) -> Result { + let task = conn.model.accept_uni_stream(recv).await?; + + if let Task::Authenticate(token) = &task { + if conn.is_authed() { + return Err(Error::DuplicatedAuth); + } else { + let mut buf = [0; 32]; + conn.inner + .export_keying_material(&mut buf, &conn.token, &conn.token) + .map_err(|_| Error::ExportKeyingMaterial)?; + + if token == &buf { + conn.set_authed(); + } else { + return Err(Error::AuthFailed); + } + } + } + + tokio::select! { + () = conn.authed() => {} + err = conn.inner.closed() => Err(err)?, + }; + + let same_pkt_src = matches!(task, Task::Packet(_)) + && matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Native)); + if same_pkt_src { + return Err(Error::UnexpectedPacketSource); + } + + Ok(task) + } + + match pre_process(&self, recv).await { + Ok(Task::Packet(pkt)) => todo!(), + Ok(Task::Dissociate(assoc_id)) => todo!(), + Ok(_) => unreachable!(), + Err(err) => { + eprintln!("{err}"); + self.inner.close(VarInt::from_u32(0), b""); + return; + } + } + } + + async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream)) { + async fn pre_process( + conn: &Connection, + send: SendStream, + recv: RecvStream, + ) -> Result { + let task = conn.model.accept_bi_stream(send, recv).await?; + + tokio::select! { + () = conn.authed() => {} + err = conn.inner.closed() => Err(err)?, + }; + + Ok(task) + } + + match pre_process(&self, send, recv).await { + Ok(Task::Connect(conn)) => todo!(), + Ok(_) => unreachable!(), + Err(err) => { + eprintln!("{err}"); + self.inner.close(VarInt::from_u32(0), b""); + return; + } + } + } + + async fn handle_datagram(self, dg: Bytes) { + async fn pre_process(conn: &Connection, dg: Bytes) -> Result { + let task = conn.model.accept_datagram(dg)?; + + tokio::select! { + () = conn.authed() => {} + err = conn.inner.closed() => Err(err)?, + }; + + let same_pkt_src = matches!(task, Task::Packet(_)) + && matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Quic)); + if same_pkt_src { + return Err(Error::UnexpectedPacketSource); + } + + Ok(task) + } + + match pre_process(&self, dg).await { + Ok(Task::Packet(pkt)) => todo!(), + Ok(Task::Heartbeat) => todo!(), + Ok(_) => unreachable!(), + Err(err) => { + eprintln!("{err}"); + self.inner.close(VarInt::from_u32(0), b""); + return; + } + } + } + + fn set_authed(&self) { + self.is_authed.set_authed(); + } + + fn is_authed(&self) -> bool { + self.is_authed.is_authed() + } + + fn authed(&self) -> IsAuthed { + self.is_authed.clone() + } + + fn set_udp_relay_mode(&self, mode: UdpRelayMode) { + self.udp_relay_mode.store(Some(mode)); + } + + fn get_udp_relay_mode(&self) -> Option { + self.udp_relay_mode.load() + } + + fn is_closed(&self) -> bool { + self.inner.close_reason().is_some() + } +} + +#[derive(Clone)] +struct IsAuthed { + is_authed: Arc, + broadcast: Arc>>, +} + +impl IsAuthed { + fn new() -> Self { + Self { + is_authed: Arc::new(AtomicBool::new(false)), + broadcast: Arc::new(Mutex::new(Vec::new())), + } + } + + fn set_authed(&self) { + self.is_authed.store(true, Ordering::Release); + + for waker in self.broadcast.lock().drain(..) { + waker.wake(); + } + } + + fn is_authed(&self) -> bool { + self.is_authed.load(Ordering::Relaxed) + } +} + +impl Future for IsAuthed { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.is_authed.load(Ordering::Relaxed) { + Poll::Ready(()) + } else { + self.broadcast.lock().push(cx.waker().clone()); + Poll::Pending + } + } +} diff --git a/tuic-server/src/utils.rs b/tuic-server/src/utils.rs new file mode 100644 index 0000000..d6849a2 --- /dev/null +++ b/tuic-server/src/utils.rs @@ -0,0 +1,29 @@ +use std::str::FromStr; + +#[derive(Clone, Copy)] +pub enum UdpRelayMode { + Native, + Quic, +} + +pub enum CongestionControl { + Cubic, + NewReno, + Bbr, +} + +impl FromStr for CongestionControl { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if s.eq_ignore_ascii_case("cubic") { + Ok(Self::Cubic) + } else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") { + Ok(Self::NewReno) + } else if s.eq_ignore_ascii_case("bbr") { + Ok(Self::Bbr) + } else { + Err("invalid congestion control") + } + } +} From 65b3df1a2ffbecbd08dabca9319210244363dacf Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 01:02:48 +0900 Subject: [PATCH 092/103] implement UDP packet sending on server --- tuic-server/src/main.rs | 7 +- tuic-server/src/server.rs | 193 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 190 insertions(+), 10 deletions(-) diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index d52e01a..26c7dc5 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -2,9 +2,10 @@ use self::{ config::{Config, ConfigError}, server::Server, }; -use quinn::{crypto::ExportKeyingMaterialError, ConnectionError}; -use std::{env, io::Error as IoError, process}; +use quinn::ConnectionError; +use std::{env, io::Error as IoError, net::SocketAddr, process}; use thiserror::Error; +use tuic::Address; use tuic_quinn::Error as ModelError; mod config; @@ -50,4 +51,6 @@ pub enum Error { AuthFailed, #[error("received packet from unexpected source")] UnexpectedPacketSource, + #[error("{0} resolved to {1} but IPv6 UDP relay disabled")] + UdpRelayIpv6Disabled(Address, SocketAddr), } diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 9101202..6969bc4 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -4,7 +4,10 @@ use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; use std::{ + collections::{hash_map::Entry, HashMap}, future::Future, + io::{Error as IoError, ErrorKind}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, @@ -12,11 +15,22 @@ use std::{ }, task::{Context, Poll, Waker}, }; -use tuic_quinn::{side, Connection as Model, Task}; +use tokio::{ + io::{self, AsyncWriteExt}, + net::{self, TcpStream, UdpSocket}, + sync::{ + oneshot::{self, Receiver, Sender}, + Mutex as AsyncMutex, + }, +}; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tuic::Address; +use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; pub struct Server { ep: Endpoint, token: Arc<[u8]>, + udp_relay_ipv6: bool, zero_rtt_handshake: bool, } @@ -31,6 +45,7 @@ impl Server { tokio::spawn(Connection::init( conn, self.token.clone(), + self.udp_relay_ipv6, self.zero_rtt_handshake, )); } @@ -42,13 +57,20 @@ struct Connection { inner: QuinnConnection, model: Model, token: Arc<[u8]>, + udp_relay_ipv6: bool, is_authed: IsAuthed, + udp_sessions: Arc>>, udp_relay_mode: Arc>>, } impl Connection { - pub async fn init(conn: Connecting, token: Arc<[u8]>, zero_rtt_handshake: bool) { - match Self::handshake(conn, token, zero_rtt_handshake).await { + pub async fn init( + conn: Connecting, + token: Arc<[u8]>, + udp_relay_ipv6: bool, + zero_rtt_handshake: bool, + ) { + match Self::handshake(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { Ok(conn) => loop { if conn.is_closed() { break; @@ -66,6 +88,7 @@ impl Connection { async fn handshake( conn: Connecting, token: Arc<[u8]>, + udp_relay_ipv6: bool, zero_rtt_handshake: bool, ) -> Result { let conn = if zero_rtt_handshake { @@ -84,7 +107,9 @@ impl Connection { inner: conn.clone(), model: Model::::new(conn), token, + udp_relay_ipv6, is_authed: IsAuthed::new(), + udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), }) } @@ -135,8 +160,17 @@ impl Connection { } match pre_process(&self, recv).await { - Ok(Task::Packet(pkt)) => todo!(), - Ok(Task::Dissociate(assoc_id)) => todo!(), + Ok(Task::Packet(pkt)) => { + self.set_udp_relay_mode(UdpRelayMode::Quic); + match self.handle_packet(pkt).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + Ok(Task::Dissociate(assoc_id)) => match self.handle_dissociate(assoc_id).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + }, Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); @@ -163,7 +197,10 @@ impl Connection { } match pre_process(&self, send, recv).await { - Ok(Task::Connect(conn)) => todo!(), + Ok(Task::Connect(conn)) => match self.handle_connect(conn).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + }, Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); @@ -192,8 +229,14 @@ impl Connection { } match pre_process(&self, dg).await { - Ok(Task::Packet(pkt)) => todo!(), - Ok(Task::Heartbeat) => todo!(), + Ok(Task::Packet(pkt)) => { + self.set_udp_relay_mode(UdpRelayMode::Native); + match self.handle_packet(pkt).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + Ok(Task::Heartbeat) => {} Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); @@ -203,6 +246,77 @@ impl Connection { } } + async fn handle_connect(&self, conn: Connect) -> Result<(), Error> { + let mut stream = None; + let mut last_err = None; + + match resolve_dns(conn.addr()).await { + Ok(addrs) => { + for addr in addrs { + match TcpStream::connect(addr).await { + Ok(s) => { + stream = Some(s); + break; + } + Err(err) => last_err = Some(err), + } + } + } + Err(err) => last_err = Some(err), + } + + if let Some(mut stream) = stream { + let mut conn = conn.compat(); + let res = io::copy_bidirectional(&mut conn, &mut stream).await; + let _ = conn.shutdown().await; + let _ = stream.shutdown().await; + res?; + Ok(()) + } else { + let _ = conn.compat().shutdown().await; + Err(last_err + .unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))? + } + } + + async fn handle_packet(&self, pkt: Packet) -> Result<(), Error> { + let Some((pkt, addr, assoc_id)) = pkt.accept().await? else { + return Ok(()); + }; + + let (socket_v4, socket_v6) = match self.udp_sessions.lock().await.entry(assoc_id) { + Entry::Occupied(mut entry) => { + let session = entry.get_mut(); + (session.socket_v4.clone(), session.socket_v6.clone()) + } + Entry::Vacant(entry) => { + let session = entry + .insert(UdpSession::new(assoc_id, self.clone(), self.udp_relay_ipv6).await?); + (session.socket_v4.clone(), session.socket_v6.clone()) + } + }; + + let Some(socket_addr) = resolve_dns(&addr).await?.next() else { + Err(IoError::new(ErrorKind::NotFound, "no address resolved"))? + }; + + let socket = match socket_addr { + SocketAddr::V4(_) => socket_v4, + SocketAddr::V6(_) => { + socket_v6.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr, socket_addr))? + } + }; + + socket.send_to(&pkt, socket_addr).await?; + + Ok(()) + } + + async fn handle_dissociate(&self, assoc_id: u16) -> Result<(), Error> { + self.udp_sessions.lock().await.remove(&assoc_id); + Ok(()) + } + fn set_authed(&self) { self.is_authed.set_authed(); } @@ -228,6 +342,69 @@ impl Connection { } } +async fn resolve_dns(addr: &Address) -> Result, IoError> { + match addr { + Address::None => Err(IoError::new(ErrorKind::InvalidInput, "empty address")), + Address::DomainAddress(domain, port) => Ok(net::lookup_host((domain.as_str(), *port)) + .await? + .collect::>() + .into_iter()), + Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()), + } +} + +struct UdpSession { + socket_v4: Arc, + socket_v6: Option>, + cancel: Option>, +} + +impl UdpSession { + async fn new(assoc_id: u16, conn: Connection, udp_relay_ipv6: bool) -> Result { + let socket_v4 = + Arc::new(UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))).await?); + let socket_v6 = if udp_relay_ipv6 { + Some(Arc::new( + UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))).await?, + )) + } else { + None + }; + + let (tx, rx) = oneshot::channel(); + + tokio::spawn(Self::listen_incoming( + assoc_id, + conn, + socket_v4.clone(), + socket_v6.clone(), + rx, + )); + + Ok(Self { + socket_v4, + socket_v6, + cancel: Some(tx), + }) + } + + async fn listen_incoming( + assoc_id: u16, + conn: Connection, + socket_v4: Arc, + socket_v6: Option>, + cancel: Receiver<()>, + ) { + todo!() + } +} + +impl Drop for UdpSession { + fn drop(&mut self) { + let _ = self.cancel.take().unwrap().send(()); + } +} + #[derive(Clone)] struct IsAuthed { is_authed: Arc, From 8306a7f0612dcf26984e921653249f7eb2071679 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 01:59:44 +0900 Subject: [PATCH 093/103] implement UDP packet receiving on server --- tuic-server/src/server.rs | 59 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 55 insertions(+), 4 deletions(-) diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 6969bc4..baf176c 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -174,7 +174,7 @@ impl Connection { Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); - self.inner.close(VarInt::from_u32(0), b""); + self.close(); return; } } @@ -204,7 +204,7 @@ impl Connection { Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); - self.inner.close(VarInt::from_u32(0), b""); + self.close(); return; } } @@ -240,7 +240,7 @@ impl Connection { Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); - self.inner.close(VarInt::from_u32(0), b""); + self.close(); return; } } @@ -340,6 +340,10 @@ impl Connection { fn is_closed(&self) -> bool { self.inner.close_reason().is_some() } + + fn close(&self) { + self.inner.close(VarInt::from_u32(0), b""); + } } async fn resolve_dns(addr: &Address) -> Result, IoError> { @@ -395,7 +399,54 @@ impl UdpSession { socket_v6: Option>, cancel: Receiver<()>, ) { - todo!() + async fn send_pkt(conn: Connection, pkt: Bytes, addr: SocketAddr, assoc_id: u16) { + let addr = Address::SocketAddress(addr); + + let res = match conn.get_udp_relay_mode() { + Some(UdpRelayMode::Native) => conn.model.packet_native(pkt, addr, assoc_id), + Some(UdpRelayMode::Quic) => conn.model.packet_quic(pkt, addr, assoc_id).await, + None => unreachable!(), + }; + + if let Err(err) = res { + eprintln!("{err}"); + } + } + + tokio::select! { + _ = cancel => {} + () = async { + loop { + match Self::accept(&socket_v4, socket_v6.as_deref()).await { + Ok((pkt, addr)) => { + tokio::spawn(send_pkt(conn.clone(), pkt, addr, assoc_id)); + } + Err(err) => eprintln!("{err}"), + } + } + } => unreachable!(), + } + } + + async fn accept( + socket_v4: &UdpSocket, + socket_v6: Option<&UdpSocket>, + ) -> Result<(Bytes, SocketAddr), IoError> { + async fn read_packet(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { + let mut buf = vec![0u8; 65535]; + let (n, addr) = socket.recv_from(&mut buf).await?; + buf.truncate(n); + Ok((Bytes::from(buf), addr)) + } + + if let Some(socket_v6) = socket_v6 { + tokio::select! { + res = read_packet(socket_v4) => res, + res = read_packet(socket_v6) => res, + } + } else { + read_packet(socket_v4).await + } } } From 7781f5c62a9ef8af2dd81db68a5df6144c3cf9ca Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 16:17:54 +0900 Subject: [PATCH 094/103] auto increase max concurrent stream count --- tuic-server/src/server.rs | 55 ++++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 10 deletions(-) diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index baf176c..70c6f8b 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -3,6 +3,7 @@ use bytes::Bytes; use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; +use register_count::{Counter, Register}; use std::{ collections::{hash_map::Entry, HashMap}, future::Future, @@ -10,7 +11,7 @@ use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, }, task::{Context, Poll, Waker}, @@ -27,6 +28,8 @@ use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address; use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; +const DEFAULT_CONCURRENT_STREAMS: usize = 32; + pub struct Server { ep: Endpoint, token: Arc<[u8]>, @@ -61,6 +64,10 @@ struct Connection { is_authed: IsAuthed, udp_sessions: Arc>>, udp_relay_mode: Arc>>, + remote_uni_stream_cnt: Counter, + remote_bi_stream_cnt: Counter, + max_concurrent_uni_streams: Arc, + max_concurrent_bi_streams: Arc, } impl Connection { @@ -111,20 +118,37 @@ impl Connection { is_authed: IsAuthed::new(), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), + remote_uni_stream_cnt: Counter::new(), + remote_bi_stream_cnt: Counter::new(), + max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), + max_concurrent_bi_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), }) } async fn accept(&self) -> Result<(), Error> { tokio::select! { - res = self.inner.accept_uni() => tokio::spawn(self.clone().handle_uni_stream(res?)), - res = self.inner.accept_bi() => tokio::spawn(self.clone().handle_bi_stream(res?)), - res = self.inner.read_datagram() => tokio::spawn(self.clone().handle_datagram(res?)), + res = self.inner.accept_uni() => + tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())), + res = self.inner.accept_bi() => + tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())), + res = self.inner.read_datagram() => + tokio::spawn(self.clone().handle_datagram(res?)), }; Ok(()) } - async fn handle_uni_stream(self, recv: RecvStream) { + async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { + let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed); + + if self.remote_uni_stream_cnt.count() == max { + self.max_concurrent_uni_streams + .store(max * 2, Ordering::Relaxed); + + self.inner + .set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32)); + } + async fn pre_process(conn: &Connection, recv: RecvStream) -> Result { let task = conn.model.accept_uni_stream(recv).await?; @@ -133,6 +157,7 @@ impl Connection { return Err(Error::DuplicatedAuth); } else { let mut buf = [0; 32]; + conn.inner .export_keying_material(&mut buf, &conn.token, &conn.token) .map_err(|_| Error::ExportKeyingMaterial)?; @@ -180,7 +205,17 @@ impl Connection { } } - async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream)) { + async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream), _reg: Register) { + let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed); + + if self.remote_bi_stream_cnt.count() == max { + self.max_concurrent_bi_streams + .store(max * 2, Ordering::Relaxed); + + self.inner + .set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32)); + } + async fn pre_process( conn: &Connection, send: SendStream, @@ -432,7 +467,7 @@ impl UdpSession { socket_v4: &UdpSocket, socket_v6: Option<&UdpSocket>, ) -> Result<(Bytes, SocketAddr), IoError> { - async fn read_packet(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { + async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { let mut buf = vec![0u8; 65535]; let (n, addr) = socket.recv_from(&mut buf).await?; buf.truncate(n); @@ -441,11 +476,11 @@ impl UdpSession { if let Some(socket_v6) = socket_v6 { tokio::select! { - res = read_packet(socket_v4) => res, - res = read_packet(socket_v6) => res, + res = read_pkt(socket_v4) => res, + res = read_pkt(socket_v6) => res, } } else { - read_packet(socket_v4).await + read_pkt(socket_v4).await } } } From 8cf357012d19947cb9b7ad82533134fe1a890d75 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 16:29:02 +0900 Subject: [PATCH 095/103] add auth timeout & gc on server --- tuic-server/src/server.rs | 61 +++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 70c6f8b..093585f 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -15,6 +15,7 @@ use std::{ Arc, }, task::{Context, Poll, Waker}, + time::Duration, }; use tokio::{ io::{self, AsyncWriteExt}, @@ -23,6 +24,7 @@ use tokio::{ oneshot::{self, Receiver, Sender}, Mutex as AsyncMutex, }, + time, }; use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address; @@ -35,6 +37,9 @@ pub struct Server { token: Arc<[u8]>, udp_relay_ipv6: bool, zero_rtt_handshake: bool, + auth_timeout: Duration, + gc_interval: Duration, + gc_lifetime: Duration, } impl Server { @@ -45,11 +50,15 @@ impl Server { pub async fn start(&self) { loop { let conn = self.ep.accept().await.unwrap(); - tokio::spawn(Connection::init( + + tokio::spawn(Connection::new( conn, self.token.clone(), self.udp_relay_ipv6, self.zero_rtt_handshake, + self.auth_timeout, + self.gc_interval, + self.gc_lifetime, )); } } @@ -71,28 +80,36 @@ struct Connection { } impl Connection { - pub async fn init( + async fn new( conn: Connecting, token: Arc<[u8]>, udp_relay_ipv6: bool, zero_rtt_handshake: bool, + auth_timeout: Duration, + gc_interval: Duration, + gc_lifetime: Duration, ) { - match Self::handshake(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { - Ok(conn) => loop { - if conn.is_closed() { - break; - } + match Self::init(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { + Ok(conn) => { + tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout)); + tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); - match conn.accept().await { - Ok(()) => {} - Err(err) => eprintln!("{err}"), + loop { + if conn.is_closed() { + break; + } + + match conn.accept().await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } } - }, + } Err(err) => eprintln!("{err}"), } } - async fn handshake( + async fn init( conn: Connecting, token: Arc<[u8]>, udp_relay_ipv6: bool, @@ -352,6 +369,26 @@ impl Connection { Ok(()) } + async fn handle_auth_timeout(self, timeout: Duration) { + time::sleep(timeout).await; + + if !self.is_authed() { + self.close(); + } + } + + async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) { + loop { + time::sleep(gc_interval).await; + + if self.is_closed() { + break; + } + + self.model.collect_garbage(gc_lifetime); + } + } + fn set_authed(&self) { self.is_authed.set_authed(); } From 41e9489f3055d793cfc185a7d8827af4c6bab128 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 16:35:38 +0900 Subject: [PATCH 096/103] prevent system from rebinding UDP socket too early --- tuic-client/src/connection.rs | 5 ++++- tuic-client/src/socks5.rs | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 81f94c6..fe375d4 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -137,7 +137,10 @@ impl Endpoint { let conn = if zero_rtt_handshake { match conn.into_0rtt() { Ok((conn, _)) => conn, - Err(conn) => conn.await?, + Err(conn) => { + eprintln!("0-RTT handshake failed, fallback to 1-RTT handshake"); + conn.await? + } } } else { conn.await? diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 2b88217..c14ef47 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -125,7 +125,6 @@ impl Server { socket.set_only_v6(!dual_stack)?; } - socket.set_reuse_address(true)?; socket.bind(&SockAddr::from(SERVER.get().unwrap().addr))?; let socket = AssociatedUdpSocket::from(( From f4dfa75e4c42ae99bbdafd8bf496538b9ed6c30c Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 16:41:56 +0900 Subject: [PATCH 097/103] mark outbound UDP socket IPv6 only --- tuic-server/src/server.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 093585f..88537e1 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -4,11 +4,12 @@ use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; use register_count::{Counter, Register}; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::{ collections::{hash_map::Entry, HashMap}, future::Future, io::{Error as IoError, ErrorKind}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, pin::Pin, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -440,9 +441,13 @@ impl UdpSession { let socket_v4 = Arc::new(UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))).await?); let socket_v6 = if udp_relay_ipv6 { - Some(Arc::new( - UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))).await?, - )) + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(true)?; + socket.bind(&SockAddr::from(SocketAddr::from(( + Ipv6Addr::UNSPECIFIED, + 0, + ))))?; + Some(Arc::new(UdpSocket::from_std(StdUdpSocket::from(socket))?)) } else { None }; From 53c95d9860e8ad1560fa5ec616636f55a25fc4a9 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 17:03:52 +0900 Subject: [PATCH 098/103] deserde certs & priv_keys path as `PathBuf` --- tuic-client/src/config.rs | 7 ++++--- tuic-client/src/utils.rs | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs index 8c2a0b6..14c4845 100644 --- a/tuic-client/src/config.rs +++ b/tuic-client/src/config.rs @@ -8,6 +8,7 @@ use std::{ fs::File, io::Error as IoError, net::{IpAddr, SocketAddr}, + path::PathBuf, str::FromStr, time::Duration, }; @@ -37,7 +38,7 @@ pub struct Relay { pub token: String, pub ip: Option, #[serde(default = "default::relay::certificates")] - pub certificates: Vec, + pub certificates: Vec, #[serde( default = "default::relay::udp_relay_mode", deserialize_with = "deserialize_from_str" @@ -111,9 +112,9 @@ impl Config { mod default { pub mod relay { use crate::utils::{CongestionControl, UdpRelayMode}; - use std::time::Duration; + use std::{path::PathBuf, time::Duration}; - pub fn certificates() -> Vec { + pub fn certificates() -> Vec { Vec::new() } diff --git a/tuic-client/src/utils.rs b/tuic-client/src/utils.rs index 752184a..e84a0cc 100644 --- a/tuic-client/src/utils.rs +++ b/tuic-client/src/utils.rs @@ -5,11 +5,12 @@ use std::{ fs::{self, File}, io::BufReader, net::{IpAddr, SocketAddr}, + path::PathBuf, str::FromStr, }; use tokio::net; -pub fn load_certs(paths: Vec, disable_native: bool) -> Result { +pub fn load_certs(paths: Vec, disable_native: bool) -> Result { let mut certs = RootCertStore::empty(); for path in &paths { From 71bc8e9b2bd2e47f98bcf4da0673f9730503cb6d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 18:00:20 +0900 Subject: [PATCH 099/103] implement server config reading --- tuic-server/src/config.rs | 76 +++++++++++++++++++++++- tuic-server/src/main.rs | 5 ++ tuic-server/src/server.rs | 122 ++++++++++++++++++++++++++++++++++---- tuic-server/src/utils.rs | 42 ++++++++++++- 4 files changed, 230 insertions(+), 15 deletions(-) diff --git a/tuic-server/src/config.rs b/tuic-server/src/config.rs index c08805d..89ff4e1 100644 --- a/tuic-server/src/config.rs +++ b/tuic-server/src/config.rs @@ -1,7 +1,11 @@ +use crate::utils::CongestionControl; use lexopt::{Arg, Error as ArgumentError, Parser}; use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as SerdeError; -use std::{env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, str::FromStr}; +use std::{ + env::ArgsOs, fmt::Display, fs::File, io::Error as IoError, net::SocketAddr, path::PathBuf, + str::FromStr, time::Duration, +}; use thiserror::Error; const HELP_MSG: &str = r#" @@ -15,7 +19,34 @@ Arguments: #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub struct Config {} +pub struct Config { + pub server: SocketAddr, + pub token: String, + pub certificate: PathBuf, + pub private_key: PathBuf, + #[serde( + default = "default::congestion_control", + deserialize_with = "deserialize_from_str" + )] + pub congestion_control: CongestionControl, + #[serde(default = "default::alpn")] + pub alpn: Vec, + #[serde(default = "default::udp_relay_ipv6")] + pub udp_relay_ipv6: bool, + #[serde(default = "default::zero_rtt_handshake")] + pub zero_rtt_handshake: bool, + pub dual_stack: Option, + #[serde(default = "default::auth_timeout")] + pub auth_timeout: Duration, + #[serde(default = "default::max_idle_time")] + pub max_idle_time: Duration, + #[serde(default = "default::max_external_packet_size")] + pub max_external_packet_size: usize, + #[serde(default = "default::gc_interval")] + pub gc_interval: Duration, + #[serde(default = "default::gc_lifetime")] + pub gc_lifetime: Duration, +} impl Config { pub fn parse(args: ArgsOs) -> Result { @@ -48,7 +79,46 @@ impl Config { } } -mod default {} +mod default { + use crate::utils::CongestionControl; + use std::time::Duration; + + pub fn congestion_control() -> CongestionControl { + CongestionControl::Cubic + } + + pub fn alpn() -> Vec { + Vec::new() + } + + pub fn udp_relay_ipv6() -> bool { + true + } + + pub fn zero_rtt_handshake() -> bool { + false + } + + pub fn auth_timeout() -> Duration { + Duration::from_secs(3) + } + + pub fn max_idle_time() -> Duration { + Duration::from_secs(15) + } + + pub fn max_external_packet_size() -> usize { + 1500 + } + + pub fn gc_interval() -> Duration { + Duration::from_secs(3) + } + + pub fn gc_lifetime() -> Duration { + Duration::from_secs(15) + } +} pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result where diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index 26c7dc5..31745c2 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -3,6 +3,7 @@ use self::{ server::Server, }; use quinn::ConnectionError; +use rustls::Error as RustlsError; use std::{env, io::Error as IoError, net::SocketAddr, process}; use thiserror::Error; use tuic::Address; @@ -40,6 +41,10 @@ pub enum Error { #[error(transparent)] Io(#[from] IoError), #[error(transparent)] + Rustls(#[from] RustlsError), + #[error("invalid max idle time")] + InvalidMaxIdleTime, + #[error(transparent)] Connection(#[from] ConnectionError), #[error(transparent)] Model(#[from] ModelError), diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 88537e1..892693f 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -1,15 +1,24 @@ -use crate::{config::Config, utils::UdpRelayMode, Error}; +use crate::{ + config::Config, + utils::{self, CongestionControl, UdpRelayMode}, + Error, +}; use bytes::Bytes; use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; -use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; +use quinn::{ + congestion::{BbrConfig, CubicConfig, NewRenoConfig}, + Connecting, Connection as QuinnConnection, Endpoint, EndpointConfig, IdleTimeout, RecvStream, + SendStream, ServerConfig, TokioRuntime, TransportConfig, VarInt, +}; use register_count::{Counter, Register}; +use rustls::{version, ServerConfig as RustlsServerConfig}; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::{ collections::{hash_map::Entry, HashMap}, future::Future, io::{Error as IoError, ErrorKind}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, pin::Pin, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -39,13 +48,83 @@ pub struct Server { udp_relay_ipv6: bool, zero_rtt_handshake: bool, auth_timeout: Duration, + max_external_pkt_size: usize, gc_interval: Duration, gc_lifetime: Duration, } impl Server { pub fn init(cfg: Config) -> Result { - todo!() + let certs = utils::load_certs(cfg.certificate)?; + let priv_key = utils::load_priv_key(cfg.private_key)?; + + let mut crypto = RustlsServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(certs, priv_key)?; + + crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect(); + crypto.max_early_data_size = u32::MAX; + crypto.send_half_rtt_data = cfg.zero_rtt_handshake; + + let mut config = ServerConfig::with_crypto(Arc::new(crypto)); + let mut tp_cfg = TransportConfig::default(); + + tp_cfg + .max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_idle_timeout(Some( + IdleTimeout::try_from(cfg.max_idle_time).map_err(|_| Error::InvalidMaxIdleTime)?, + )); + + match cfg.congestion_control { + CongestionControl::Cubic => { + tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default())) + } + CongestionControl::NewReno => { + tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default())) + } + CongestionControl::Bbr => { + tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default())) + } + }; + + config.transport_config(Arc::new(tp_cfg)); + + let domain = match cfg.server.ip() { + IpAddr::V4(_) => Domain::IPV4, + IpAddr::V6(_) => Domain::IPV6, + }; + + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + + if let Some(dual_stack) = cfg.dual_stack { + socket.set_only_v6(!dual_stack)?; + } + + socket.bind(&SockAddr::from(cfg.server))?; + let socket = StdUdpSocket::from(socket); + + let ep = Endpoint::new( + EndpointConfig::default(), + Some(config), + socket, + TokioRuntime, + )?; + + Ok(Self { + ep, + token: Arc::from(cfg.token.into_bytes().into_boxed_slice()), + udp_relay_ipv6: cfg.udp_relay_ipv6, + zero_rtt_handshake: cfg.zero_rtt_handshake, + auth_timeout: cfg.auth_timeout, + max_external_pkt_size: cfg.max_external_packet_size, + gc_interval: cfg.gc_interval, + gc_lifetime: cfg.gc_lifetime, + }) } pub async fn start(&self) { @@ -58,6 +137,7 @@ impl Server { self.udp_relay_ipv6, self.zero_rtt_handshake, self.auth_timeout, + self.max_external_pkt_size, self.gc_interval, self.gc_lifetime, )); @@ -74,6 +154,7 @@ struct Connection { is_authed: IsAuthed, udp_sessions: Arc>>, udp_relay_mode: Arc>>, + max_external_pkt_size: usize, remote_uni_stream_cnt: Counter, remote_bi_stream_cnt: Counter, max_concurrent_uni_streams: Arc, @@ -87,10 +168,19 @@ impl Connection { udp_relay_ipv6: bool, zero_rtt_handshake: bool, auth_timeout: Duration, + max_external_pkt_size: usize, gc_interval: Duration, gc_lifetime: Duration, ) { - match Self::init(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { + match Self::init( + conn, + token, + udp_relay_ipv6, + zero_rtt_handshake, + max_external_pkt_size, + ) + .await + { Ok(conn) => { tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout)); tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); @@ -115,6 +205,7 @@ impl Connection { token: Arc<[u8]>, udp_relay_ipv6: bool, zero_rtt_handshake: bool, + max_external_pkt_size: usize, ) -> Result { let conn = if zero_rtt_handshake { match conn.into_0rtt() { @@ -136,6 +227,7 @@ impl Connection { is_authed: IsAuthed::new(), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), + max_external_pkt_size, remote_uni_stream_cnt: Counter::new(), remote_bi_stream_cnt: Counter::new(), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), @@ -494,7 +586,11 @@ impl UdpSession { _ = cancel => {} () = async { loop { - match Self::accept(&socket_v4, socket_v6.as_deref()).await { + match Self::accept( + &socket_v4, + socket_v6.as_deref(), + conn.max_external_pkt_size, + ).await { Ok((pkt, addr)) => { tokio::spawn(send_pkt(conn.clone(), pkt, addr, assoc_id)); } @@ -508,9 +604,13 @@ impl UdpSession { async fn accept( socket_v4: &UdpSocket, socket_v6: Option<&UdpSocket>, + max_pkt_size: usize, ) -> Result<(Bytes, SocketAddr), IoError> { - async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { - let mut buf = vec![0u8; 65535]; + async fn read_pkt( + socket: &UdpSocket, + max_pkt_size: usize, + ) -> Result<(Bytes, SocketAddr), IoError> { + let mut buf = vec![0u8; max_pkt_size]; let (n, addr) = socket.recv_from(&mut buf).await?; buf.truncate(n); Ok((Bytes::from(buf), addr)) @@ -518,11 +618,11 @@ impl UdpSession { if let Some(socket_v6) = socket_v6 { tokio::select! { - res = read_pkt(socket_v4) => res, - res = read_pkt(socket_v6) => res, + res = read_pkt(socket_v4, max_pkt_size) => res, + res = read_pkt(socket_v6, max_pkt_size) => res, } } else { - read_pkt(socket_v4).await + read_pkt(socket_v4, max_pkt_size).await } } } diff --git a/tuic-server/src/utils.rs b/tuic-server/src/utils.rs index d6849a2..31061c7 100644 --- a/tuic-server/src/utils.rs +++ b/tuic-server/src/utils.rs @@ -1,4 +1,44 @@ -use std::str::FromStr; +use rustls::{Certificate, PrivateKey}; +use rustls_pemfile::Item; +use std::{ + fs::{self, File}, + io::{BufReader, Error as IoError}, + path::PathBuf, + str::FromStr, +}; + +pub fn load_certs(path: PathBuf) -> Result, IoError> { + let mut file = BufReader::new(File::open(&path)?); + let mut certs = Vec::new(); + + while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { + if let Item::X509Certificate(cert) = item { + certs.push(Certificate(cert)); + } + } + + if certs.is_empty() { + certs = vec![Certificate(fs::read(&path)?)]; + } + + Ok(certs) +} + +pub fn load_priv_key(path: PathBuf) -> Result { + let mut file = BufReader::new(File::open(&path)?); + let mut priv_key = None; + + while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { + if let Item::RSAKey(key) | Item::PKCS8Key(key) | Item::ECKey(key) = item { + priv_key = Some(key); + } + } + + priv_key + .map(Ok) + .unwrap_or_else(|| fs::read(&path)) + .map(PrivateKey) +} #[derive(Clone, Copy)] pub enum UdpRelayMode { From 7b690e11963137294005ce99cf9722a83b641e03 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 18:17:27 +0900 Subject: [PATCH 100/103] remove unused dependencies --- tuic-server/Cargo.toml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tuic-server/Cargo.toml b/tuic-server/Cargo.toml index cb33da6..d527b5c 100644 --- a/tuic-server/Cargo.toml +++ b/tuic-server/Cargo.toml @@ -7,21 +7,16 @@ edition = "2021" bytes = { version = "1.3.0", default-features = false, features = ["std"] } crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } lexopt = { version = "0.3.0", default-features = false } -once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } -parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } +parking_lot = { version = "0.12.1", default-features = false } quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } register-count = { version = "0.1.0", default-features = false, features = ["std"] } rustls = { version = "0.20.8", default-features = false, features = ["quic"] } -rustls-native-certs = { version = "0.6.2", default-features = false } rustls-pemfile = { version = "1.0.2", default-features = false } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] } socket2 = { version = "0.4.7", default-features = false } -socks5-proto = { version = "0.3.3", default-features = false } -socks5-server = { version = "0.8.3", default-features = false } thiserror = { version = "1.0.38", default-features = false } tokio = { version = "1.24.2", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } -tuic-quinn = { path = "../tuic-quinn", default-features = false } -webpki = { version = "0.22.0", default-features = false } \ No newline at end of file +tuic-quinn = { path = "../tuic-quinn", default-features = false } \ No newline at end of file From 45011d498a82d7993ca76174afdd64b19f70038e Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 18:22:06 +0900 Subject: [PATCH 101/103] update dependencies --- tuic-client/Cargo.toml | 4 ++-- tuic-quinn/Cargo.toml | 4 ++-- tuic-server/Cargo.toml | 4 ++-- tuic/Cargo.toml | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index b291e7f..68c5ab5 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -bytes = { version = "1.3.0", default-features = false, features = ["std"] } +bytes = { version = "1.4.0", default-features = false, features = ["std"] } crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } lexopt = { version = "0.3.0", default-features = false } once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } @@ -20,7 +20,7 @@ socket2 = { version = "0.4.7", default-features = false } socks5-proto = { version = "0.3.3", default-features = false } socks5-server = { version = "0.8.3", default-features = false } thiserror = { version = "1.0.38", default-features = false } -tokio = { version = "1.24.2", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } +tokio = { version = "1.25.0", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false } diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml index fa901c1..651aebc 100644 --- a/tuic-quinn/Cargo.toml +++ b/tuic-quinn/Cargo.toml @@ -4,8 +4,8 @@ version = "0.1.0" edition = "2021" [dependencies] -bytes = { version = "1.3.0", default-features = false, features = ["std"] } -futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"] } +bytes = { version = "1.4.0", default-features = false, features = ["std"] } +futures-util = { version = "0.3.26", default-features = false, features = ["io", "std"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io"] } thiserror = { version = "1.0.38", default-features = false } tuic = { path = "../tuic", default-features = false, features = ["async_marshal", "marshal", "model"] } diff --git a/tuic-server/Cargo.toml b/tuic-server/Cargo.toml index d527b5c..b244f8a 100644 --- a/tuic-server/Cargo.toml +++ b/tuic-server/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] -bytes = { version = "1.3.0", default-features = false, features = ["std"] } +bytes = { version = "1.4.0", default-features = false, features = ["std"] } crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } lexopt = { version = "0.3.0", default-features = false } parking_lot = { version = "0.12.1", default-features = false } @@ -16,7 +16,7 @@ serde = { version = "1.0.152", default-features = false, features = ["derive", " serde_json = { version = "1.0.91", default-features = false, features = ["std"] } socket2 = { version = "0.4.7", default-features = false } thiserror = { version = "1.0.38", default-features = false } -tokio = { version = "1.24.2", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } +tokio = { version = "1.25.0", default-features = false, features = ["macros", "net", "parking_lot", "rt-multi-thread", "time"] } tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false } \ No newline at end of file diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 7aff4d1..d7f8606 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -9,8 +9,8 @@ marshal = ["bytes"] model = ["parking_lot", "register-count", "thiserror"] [dependencies] -bytes = { version = "1.3.0", default-features = false, features = ["std"], optional = true } -futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"], optional = true } +bytes = { version = "1.4.0", default-features = false, features = ["std"], optional = true } +futures-util = { version = "0.3.26", default-features = false, features = ["io", "std"], optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true } register-count = { version = "0.1.0", default-features = false, features = ["std"], optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } From b2dbff36112eddfcae28505b52d6a6ad20040203 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 18:39:25 +0900 Subject: [PATCH 102/103] make clippy happy (kinda) --- tuic-client/src/socks5.rs | 10 +++++++--- tuic-quinn/src/lib.rs | 2 +- tuic-server/src/server.rs | 8 +++----- tuic/src/model/mod.rs | 6 ++++++ tuic/src/protocol/authenticate.rs | 1 + tuic/src/protocol/connect.rs | 1 + tuic/src/protocol/dissociate.rs | 1 + tuic/src/protocol/heartbeat.rs | 5 ++--- tuic/src/protocol/mod.rs | 4 +++- tuic/src/protocol/packet.rs | 1 + 10 files changed, 26 insertions(+), 13 deletions(-) diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index c14ef47..51fd00e 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -243,7 +243,7 @@ impl Server { if frag != 0 { Err(IoError::new( ErrorKind::Other, - format!("fragmented packet is not supported"), + "fragmented packet is not supported", ))?; } @@ -282,8 +282,12 @@ impl Server { } pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { - let sessions = SERVER.get().unwrap().udp_sessions.lock(); - let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; + let assoc_socket = { + let sessions = SERVER.get().unwrap().udp_sessions.lock(); + let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; + assoc_socket.clone() + }; + assoc_socket.send(pkt, 0, addr).await?; Ok(()) } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 29057ba..bd86fa0 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -316,7 +316,7 @@ impl Connect { match &self.model { Side::Client(model) => { let Header::Connect(conn) = model.header() else { unreachable!() }; - &conn.addr() + conn.addr() } Side::Server(model) => model.addr(), } diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 892693f..8847248 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -131,7 +131,7 @@ impl Server { loop { let conn = self.ep.accept().await.unwrap(); - tokio::spawn(Connection::new( + tokio::spawn(Connection::handle( conn, self.token.clone(), self.udp_relay_ipv6, @@ -161,8 +161,9 @@ struct Connection { max_concurrent_bi_streams: Arc, } +#[allow(clippy::too_many_arguments)] impl Connection { - async fn new( + async fn handle( conn: Connecting, token: Arc<[u8]>, udp_relay_ipv6: bool, @@ -310,7 +311,6 @@ impl Connection { Err(err) => { eprintln!("{err}"); self.close(); - return; } } } @@ -350,7 +350,6 @@ impl Connection { Err(err) => { eprintln!("{err}"); self.close(); - return; } } } @@ -386,7 +385,6 @@ impl Connection { Err(err) => { eprintln!("{err}"); self.close(); - return; } } } diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index 9e66fe1..c65e6cb 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -40,6 +40,7 @@ impl Connection where B: AsRef<[u8]>, { + #[allow(clippy::new_without_default)] pub fn new() -> Self { let task_associate_count = Counter::new(); @@ -174,6 +175,7 @@ where .send_packet(assoc_id, addr, max_pkt_size) } + #[allow(clippy::too_many_arguments)] fn recv_packet( &mut self, sessions: Arc>, @@ -189,6 +191,7 @@ where }) } + #[allow(clippy::too_many_arguments)] fn recv_packet_unrestricted( &mut self, sessions: Arc>, @@ -215,6 +218,7 @@ where Dissociate::::new(assoc_id) } + #[allow(clippy::too_many_arguments)] fn insert( &mut self, assoc_id: u16, @@ -270,6 +274,7 @@ where ) } + #[allow(clippy::too_many_arguments)] fn recv_packet( &self, sessions: Arc>>, @@ -283,6 +288,7 @@ where Packet::::new(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) } + #[allow(clippy::too_many_arguments)] fn insert( &mut self, assoc_id: u16, diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index 9879241..7c86ae7 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -23,6 +23,7 @@ impl Authenticate { Self::TYPE_CODE } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { 32 } diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index 6139c83..6f00984 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -25,6 +25,7 @@ impl Connect { Self::TYPE_CODE } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { self.addr.len() } diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index 86caa19..eb29754 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -23,6 +23,7 @@ impl Dissociate { Self::TYPE_CODE } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { 2 } diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index dd8143a..ec665b3 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -17,13 +17,12 @@ impl Heartbeat { Self::TYPE_CODE } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { 0 } } impl From for () { - fn from(_: Heartbeat) -> Self { - () - } + fn from(_: Heartbeat) -> Self {} } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index 3ade4cb..a675def 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -53,6 +53,7 @@ impl Header { } } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { 2 + match self { Self::Authenticate(auth) => auth.len(), @@ -106,12 +107,13 @@ impl Address { } } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { 1 + match self { Address::None => 0, Address::DomainAddress(addr, _) => 1 + addr.len() + 2, Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 1 * 4 + 2, + SocketAddr::V4(_) => 4 + 2, SocketAddr::V6(_) => 2 * 8 + 2, }, } diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index d9c6a5d..f6703c7 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -64,6 +64,7 @@ impl Packet { Self::TYPE_CODE } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { Self::len_without_addr() + self.addr.len() } From 43b479eb9272c7b2d7b3e19f01a3bffc4fddee2b Mon Sep 17 00:00:00 2001 From: EAimTY Date: Sat, 4 Feb 2023 19:34:26 +0900 Subject: [PATCH 103/103] README.md --- README.md | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..d358a95 --- /dev/null +++ b/README.md @@ -0,0 +1,38 @@ +# TUIC + +Delicately-TUICed 0-RTT proxy protocol + +**Warning: TUIC's [dev](https://github.com/EAimTY/tuic/tree/dev) branch is under heavy development. For end-user, please check out the latest released branch** + +## Introduction + +TUIC is a proxy protocol focusing on the simplicity. It aims to minimize the additional handshake latency caused by relaying as much as possible + +TUIC is originally designed to be used on top of the [QUIC](https://en.wikipedia.org/wiki/QUIC) protocol, but you can use it with any other protocol, e.g. TCP, in theory + +When paired with QUIC, TUIC can achieve: + +- 0-RTT TCP proxying +- 0-RTT UDP proxying with NAT type [Full Cone](https://www.rfc-editor.org/rfc/rfc3489#section-5) +- 0-RTT authentication +- Two UDP proxying modes: + - `native`: Having characteristics of native UDP mechanism + - `quic`: Transferring UDP packets losslessly using QUIC streams +- Fully multiplexed +- All the advantages of QUIC: + - Bidirectional user-space congestion control + - Connection migration + - Optional 0-RTT connection handshake + +## Overview + +There are 4 crates provided in this repository: + +- **[tuic](https://github.com/EAimTY/tuic/tree/dev/tuic)** - Library. The protocol itself, protcol & model abstraction, synchronous / asynchronous marshalling +- **[tuic-quinn](https://github.com/EAimTY/tuic/tree/dev/tuic-quinn)** - Library. A thin layer on top of [quinn](https://github.com/quinn-rs/quinn) to provide functions for TUIC +- **[tuic-server](https://github.com/EAimTY/tuic/tree/dev/tuic-server)** - Binary. Minimalistic TUIC server implementation as a reference, focusing on the simplicity +- **[tuic-client](https://github.com/EAimTY/tuic/tree/dev/tuic-client)** - Binary. Minimalistic TUIC client implementation as a reference, focusing on the simplicity + +## License + +GNU General Public License v3.0 \ No newline at end of file

ExactSizeIterator for Fragment<'_, P> +where + P: AsRef<[u8]>, +{ fn len(&self) -> usize { self.frag_total as usize } From ceebbdb200bc149476e70a4d1a202249154c23c3 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 18:37:04 +0900 Subject: [PATCH 051/103] adding recv methods for auth & dissoc & heartbeat --- tuic/src/prototype/authenticate.rs | 20 +++++++++++++++-- tuic/src/prototype/dissociate.rs | 20 +++++++++++++++-- tuic/src/prototype/heartbeat.rs | 13 +++++++++-- tuic/src/prototype/mod.rs | 35 +++++++++++++++++++++++++----- tuic/src/prototype/packet.rs | 5 +---- 5 files changed, 77 insertions(+), 16 deletions(-) diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index 118ee20..ca54f39 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -10,8 +10,6 @@ pub struct Tx { header: Header, } -pub struct Rx; - impl Authenticate { pub(super) fn new(token: [u8; 8]) -> Self { Self { @@ -27,3 +25,21 @@ impl Authenticate { &tx.header } } + +pub struct Rx { + token: [u8; 8], +} + +impl Authenticate { + pub(super) fn new(token: [u8; 8]) -> Self { + Self { + inner: Side::Rx(Rx { token }), + _marker: side::Rx, + } + } + + pub fn token(&self) -> &[u8; 8] { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.token + } +} diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs index bf7f728..e6e4072 100644 --- a/tuic/src/prototype/dissociate.rs +++ b/tuic/src/prototype/dissociate.rs @@ -10,8 +10,6 @@ pub struct Tx { header: Header, } -pub struct Rx; - impl Dissociate { pub(super) fn new(assoc_id: u16) -> Self { Self { @@ -27,3 +25,21 @@ impl Dissociate { &tx.header } } + +pub struct Rx { + assoc_id: u16, +} + +impl Dissociate { + pub(super) fn new(assoc_id: u16) -> Self { + Self { + inner: Side::Rx(Rx { assoc_id }), + _marker: side::Rx, + } + } + + pub fn assoc_id(&self) -> &u16 { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.assoc_id + } +} diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs index b5b96eb..8fce8b0 100644 --- a/tuic/src/prototype/heartbeat.rs +++ b/tuic/src/prototype/heartbeat.rs @@ -10,8 +10,6 @@ pub struct Tx { header: Header, } -pub struct Rx; - impl Heartbeat { pub(super) fn new() -> Self { Self { @@ -27,3 +25,14 @@ impl Heartbeat { &tx.header } } + +pub struct Rx; + +impl Heartbeat { + pub(super) fn new() -> Self { + Self { + inner: Side::Rx(Rx), + _marker: side::Rx, + } + } +} diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index b79cd7e..358f456 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -1,4 +1,7 @@ -use crate::protocol::{Address, Connect as ConnectHeader, Packet as PacketHeader}; +use crate::protocol::{ + Address, Authenticate as AuthenticateHeader, Connect as ConnectHeader, + Dissociate as DissociateHeader, Heartbeat as HeartbeatHeader, Packet as PacketHeader, +}; use parking_lot::Mutex; use std::{ collections::HashMap, @@ -45,7 +48,12 @@ where } pub fn send_authenticate(&self, token: [u8; 8]) -> Authenticate { - Authenticate::new(token) + Authenticate::::new(token) + } + + pub fn recv_authenticate(&self, header: AuthenticateHeader) -> Authenticate { + let (token,) = header.into(); + Authenticate::::new(token) } pub fn send_connect(&self, addr: Address) -> Connect { @@ -82,11 +90,21 @@ where } pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate { - self.udp_sessions.lock().dissociate(assoc_id) + self.udp_sessions.lock().send_dissociate(assoc_id) + } + + pub fn recv_dissociate(&self, header: DissociateHeader) -> Dissociate { + let (assoc_id,) = header.into(); + self.udp_sessions.lock().recv_dissociate(assoc_id) } pub fn send_heartbeat(&self) -> Heartbeat { - Heartbeat::new() + Heartbeat::::new() + } + + pub fn recv_heartbeat(&self, header: HeartbeatHeader) -> Heartbeat { + let () = header.into(); + Heartbeat::::new() } pub fn task_connect_count(&self) -> usize { @@ -174,9 +192,14 @@ where .recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) } - fn dissociate(&mut self, assoc_id: u16) -> Dissociate { + fn send_dissociate(&mut self, assoc_id: u16) -> Dissociate { self.sessions.remove(&assoc_id); - Dissociate::new(assoc_id) + Dissociate::::new(assoc_id) + } + + fn recv_dissociate(&mut self, assoc_id: u16) -> Dissociate { + self.sessions.remove(&assoc_id); + Dissociate::::new(assoc_id) } fn insert( diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index 5afd443..7b757a8 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -18,10 +18,7 @@ pub struct Tx { max_pkt_size: usize, } -impl Packet -where - B: AsRef<[u8]>, -{ +impl Packet { pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self { Self { inner: Side::Tx(Tx { From 822da1ceafba7db7db630cd9da911e37ddf9ae6d Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 18:39:24 +0900 Subject: [PATCH 052/103] renaming `prototype` to `model` --- tuic/Cargo.toml | 4 ++-- tuic/src/lib.rs | 2 +- tuic/src/{prototype => model}/authenticate.rs | 0 tuic/src/{prototype => model}/connect.rs | 0 tuic/src/{prototype => model}/dissociate.rs | 0 tuic/src/{prototype => model}/heartbeat.rs | 0 tuic/src/{prototype => model}/mod.rs | 0 tuic/src/{prototype => model}/packet.rs | 0 8 files changed, 3 insertions(+), 3 deletions(-) rename tuic/src/{prototype => model}/authenticate.rs (100%) rename tuic/src/{prototype => model}/connect.rs (100%) rename tuic/src/{prototype => model}/dissociate.rs (100%) rename tuic/src/{prototype => model}/heartbeat.rs (100%) rename tuic/src/{prototype => model}/mod.rs (100%) rename tuic/src/{prototype => model}/packet.rs (100%) diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index 3aab101..490a448 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,11 +4,11 @@ version = "0.1.0" edition = "2021" [features] -prototype = ["parking_lot", "thiserror"] +model = ["parking_lot", "thiserror"] [dependencies] parking_lot = { version = "0.12.1", default-features = false, optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] -tuic = { path = ".", features = ["prototype"] } +tuic = { path = ".", features = ["model"] } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index de2d745..48cfb71 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -3,4 +3,4 @@ pub mod protocol; #[cfg(feature = "prototype")] -pub mod prototype; +pub mod model; diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/model/authenticate.rs similarity index 100% rename from tuic/src/prototype/authenticate.rs rename to tuic/src/model/authenticate.rs diff --git a/tuic/src/prototype/connect.rs b/tuic/src/model/connect.rs similarity index 100% rename from tuic/src/prototype/connect.rs rename to tuic/src/model/connect.rs diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/model/dissociate.rs similarity index 100% rename from tuic/src/prototype/dissociate.rs rename to tuic/src/model/dissociate.rs diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/model/heartbeat.rs similarity index 100% rename from tuic/src/prototype/heartbeat.rs rename to tuic/src/model/heartbeat.rs diff --git a/tuic/src/prototype/mod.rs b/tuic/src/model/mod.rs similarity index 100% rename from tuic/src/prototype/mod.rs rename to tuic/src/model/mod.rs diff --git a/tuic/src/prototype/packet.rs b/tuic/src/model/packet.rs similarity index 100% rename from tuic/src/prototype/packet.rs rename to tuic/src/model/packet.rs From 000f38da15eb8aa87883f0bb7db6429269bfa898 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 18:46:40 +0900 Subject: [PATCH 053/103] fix typo --- tuic/src/lib.rs | 2 +- tuic/src/model/mod.rs | 8 ++++---- tuic/src/model/packet.rs | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 48cfb71..4eaef2f 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -2,5 +2,5 @@ pub mod protocol; -#[cfg(feature = "prototype")] +#[cfg(feature = "model")] pub mod model; diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index 358f456..4cd8cb6 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -24,7 +24,7 @@ pub use self::{ connect::Connect, dissociate::Dissociate, heartbeat::Heartbeat, - packet::{Fragment, Packet}, + packet::{Fragments, Packet}, }; pub struct Connection { @@ -352,7 +352,7 @@ where } if self.buf[frag_id as usize].is_some() { - return Err(AssembleError::DuplicateFragment); + return Err(AssembleError::DuplicatedFragment); } self.buf[frag_id as usize] = Some(data); @@ -387,6 +387,6 @@ pub enum AssembleError { InvalidFragmentId, #[error("invalid address")] InvalidAddress, - #[error("duplicate fragment")] - DuplicateFragment, + #[error("duplicated fragment")] + DuplicatedFragment, } diff --git a/tuic/src/model/packet.rs b/tuic/src/model/packet.rs index 7b757a8..d4c4483 100644 --- a/tuic/src/model/packet.rs +++ b/tuic/src/model/packet.rs @@ -31,12 +31,12 @@ impl Packet { } } - pub fn into_fragments<'a, P>(self, payload: P) -> Fragment<'a, P> + pub fn into_fragments<'a, P>(self, payload: P) -> Fragments<'a, P> where P: AsRef<[u8]>, { let Side::Tx(tx) = self.inner else { unreachable!() }; - Fragment::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) + Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) } } @@ -96,7 +96,7 @@ where } } -pub struct Fragment<'a, P> +pub struct Fragments<'a, P> where P: 'a, { @@ -111,7 +111,7 @@ where _marker: PhantomData<&'a P>, } -impl<'a, P> Fragment<'a, P> +impl<'a, P> Fragments<'a, P> where P: AsRef<[u8]> + 'a, { @@ -140,7 +140,7 @@ where } } -impl<'a, P> Iterator for Fragment<'a, P> +impl<'a, P> Iterator for Fragments<'a, P> where P: AsRef<[u8]> + 'a, { @@ -176,7 +176,7 @@ where } } -impl