From 9806c62fe70add6b286f40dd0ae04949acac42a9 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 23:42:53 +0900 Subject: [PATCH] reading client config to endpoint --- tuic-client/Cargo.toml | 5 ++ tuic-client/src/config.rs | 113 ++++++++++++++---------- tuic-client/src/connection.rs | 157 ++++++++++++++++++++++++++++------ tuic-client/src/error.rs | 11 ++- tuic-client/src/socks5.rs | 2 +- tuic-client/src/utils.rs | 62 +++++++++++--- 6 files changed, 265 insertions(+), 85 deletions(-) diff --git a/tuic-client/Cargo.toml b/tuic-client/Cargo.toml index 1a614f1..b291e7f 100644 --- a/tuic-client/Cargo.toml +++ b/tuic-client/Cargo.toml @@ -5,11 +5,15 @@ edition = "2021" [dependencies] bytes = { version = "1.3.0", default-features = false, features = ["std"] } +crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] } lexopt = { version = "0.3.0", default-features = false } once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } register-count = { version = "0.1.0", default-features = false, features = ["std"] } +rustls = { version = "0.20.8", default-features = false, features = ["quic"] } +rustls-native-certs = { version = "0.6.2", default-features = false } +rustls-pemfile = { version = "1.0.2", default-features = false } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] } socket2 = { version = "0.4.7", default-features = false } @@ -20,3 +24,4 @@ tokio = { version = "1.24.2", default-features = false, features = ["macros", "n tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tuic = { path = "../tuic", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false } +webpki = { version = "0.22.0", default-features = false } diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs index 81f35db..eb06b28 100644 --- a/tuic-client/src/config.rs +++ b/tuic-client/src/config.rs @@ -1,12 +1,14 @@ -use crate::utils::{self, CongestionControl, UdpRelayMode}; +use crate::utils::{CongestionControl, UdpRelayMode}; use lexopt::{Arg, Error as ArgumentError, Parser}; use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as SerdeError; use std::{ env::ArgsOs, + fmt::Display, fs::File, io::Error as IoError, net::{IpAddr, SocketAddr}, + str::FromStr, time::Duration, }; use thiserror::Error; @@ -38,22 +40,26 @@ pub struct Relay { pub certificates: Vec, #[serde( default = "default::relay::udp_relay_mode", - deserialize_with = "utils::deserialize_from_str" + deserialize_with = "deserialize_from_str" )] pub udp_relay_mode: UdpRelayMode, #[serde( default = "default::relay::congestion_control", - deserialize_with = "utils::deserialize_from_str" + deserialize_with = "deserialize_from_str" )] pub congestion_control: CongestionControl, #[serde(default = "default::relay::alpn")] pub alpn: Vec, #[serde(default = "default::relay::zero_rtt_handshake")] pub zero_rtt_handshake: bool, + #[serde(default = "default::relay::disable_sni")] + pub disable_sni: bool, #[serde(default = "default::relay::timeout")] pub timeout: Duration, #[serde(default = "default::relay::heartbeat")] pub heartbeat: Duration, + #[serde(default = "default::relay::disable_native_certificates")] + pub disable_native_certificates: bool, } #[derive(Deserialize)] @@ -67,47 +73,6 @@ pub struct Local { pub max_packet_size: usize, } -mod default { - pub mod relay { - use crate::utils::{CongestionControl, UdpRelayMode}; - use std::time::Duration; - - pub const fn certificates() -> Vec { - Vec::new() - } - - pub const fn udp_relay_mode() -> UdpRelayMode { - UdpRelayMode::Native - } - - pub const fn congestion_control() -> CongestionControl { - CongestionControl::Cubic - } - - pub const fn alpn() -> Vec { - Vec::new() - } - - pub const fn zero_rtt_handshake() -> bool { - false - } - - pub const fn timeout() -> Duration { - Duration::from_secs(8) - } - - pub const fn heartbeat() -> Duration { - Duration::from_secs(3) - } - } - - pub mod local { - pub const fn max_packet_size() -> usize { - 1500 - } - } -} - impl Config { pub fn parse(args: ArgsOs) -> Result { let mut parser = Parser::from_iter(args); @@ -135,11 +100,69 @@ impl Config { } let file = File::open(path.unwrap())?; - Ok(serde_json::from_reader(file)?) } } +mod default { + pub mod relay { + use crate::utils::{CongestionControl, UdpRelayMode}; + use std::time::Duration; + + pub fn certificates() -> Vec { + Vec::new() + } + + pub fn udp_relay_mode() -> UdpRelayMode { + UdpRelayMode::Native + } + + pub fn congestion_control() -> CongestionControl { + CongestionControl::Cubic + } + + pub fn alpn() -> Vec { + Vec::new() + } + + pub fn zero_rtt_handshake() -> bool { + false + } + + pub fn disable_sni() -> bool { + false + } + + pub fn timeout() -> Duration { + Duration::from_secs(8) + } + + pub fn heartbeat() -> Duration { + Duration::from_secs(3) + } + + pub fn disable_native_certificates() -> bool { + false + } + } + + pub mod local { + pub fn max_packet_size() -> usize { + 1500 + } + } +} + +pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + ::Err: Display, + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + T::from_str(&s).map_err(DeError::custom) +} + pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error> where D: Deserializer<'de>, diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 84bf348..2a62e6c 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -2,18 +2,22 @@ use crate::{ config::Relay, error::Error, socks5::Server as Socks5Server, - utils::{ServerAddr, UdpRelayMode}, + utils::{self, CongestionControl, ServerAddr, UdpRelayMode}, }; use bytes::Bytes; +use crossbeam_utils::atomic::AtomicCell; use once_cell::sync::OnceCell; use parking_lot::Mutex; use quinn::{ - Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt, + congestion::{BbrConfig, CubicConfig, NewRenoConfig}, + ClientConfig, Connection as QuinnConnection, Endpoint as QuinnEndpoint, EndpointConfig, + RecvStream, SendStream, TokioRuntime, TransportConfig, VarInt, }; use register_count::{Counter, Register}; +use rustls::{version, ClientConfig as RustlsClientConfig}; use socks5_proto::Address as Socks5Address; use std::{ - net::SocketAddr, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -29,30 +33,67 @@ use tuic_quinn::{side, Connect, Connection as Model, Task}; static ENDPOINT: OnceCell> = OnceCell::new(); static CONNECTION: AsyncOnceCell> = AsyncOnceCell::const_new(); +static TIMEOUT: AtomicCell = AtomicCell::new(Duration::from_secs(0)); const DEFAULT_CONCURRENT_STREAMS: usize = 32; pub struct Endpoint { ep: QuinnEndpoint, server: ServerAddr, - token: Vec, + token: Arc<[u8]>, udp_relay_mode: UdpRelayMode, zero_rtt_handshake: bool, - timeout: Duration, heartbeat: Duration, } impl Endpoint { pub fn set_config(cfg: Relay) -> Result<(), Error> { - let ep = todo!(); + let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certificates)?; + + let mut crypto = RustlsClientConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&version::TLS13]) + .unwrap() + .with_root_certificates(certs) + .with_no_client_auth(); + + crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect(); + crypto.enable_early_data = true; + crypto.enable_sni = !cfg.disable_sni; + + let mut config = ClientConfig::new(Arc::new(crypto)); + let mut tp_cfg = TransportConfig::default(); + + tp_cfg + .max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32)) + .max_idle_timeout(None); + + match cfg.congestion_control { + CongestionControl::Cubic => { + tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default())) + } + CongestionControl::NewReno => { + tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default())) + } + CongestionControl::Bbr => { + tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default())) + } + }; + + config.transport_config(Arc::new(tp_cfg)); + + let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0)))?; + let mut ep = QuinnEndpoint::new(EndpointConfig::default(), None, socket, TokioRuntime)?; + ep.set_default_client_config(config); let ep = Self { ep, server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip), - token: cfg.token.into_bytes(), + token: Arc::from(cfg.token.into_bytes().into_boxed_slice()), udp_relay_mode: cfg.udp_relay_mode, zero_rtt_handshake: cfg.zero_rtt_handshake, - timeout: cfg.timeout, heartbeat: cfg.heartbeat, }; @@ -61,19 +102,67 @@ impl Endpoint { .map_err(|_| "endpoint already initialized") .unwrap(); + TIMEOUT.store(cfg.timeout); + Ok(()) } - async fn connect(&self) -> Result { - let conn = self - .ep - .connect(SocketAddr::from(([127, 0, 0, 1], 8080)), "localhost")? + async fn connect(&mut self) -> Result { + async fn connect_to( + ep: &mut QuinnEndpoint, + addr: SocketAddr, + server_name: &str, + udp_relay_mode: UdpRelayMode, + zero_rtt_handshake: bool, + ) -> Result { + let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4()); + let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6()); + + if !match_ipv4 && !match_ipv6 { + let bind_addr = if addr.is_ipv4() { + SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) + } else { + SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) + }; + + ep.rebind(UdpSocket::bind(bind_addr)?)?; + } + + let conn = ep.connect(addr, server_name)?; + + let conn = if zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => conn.await?, + } + } else { + conn.await? + }; + + Ok(Connection::new(conn, udp_relay_mode)) + } + + let mut last_err = None; + + for addr in self.server.resolve().await? { + match connect_to( + &mut self.ep, + addr, + self.server.server_name(), + self.udp_relay_mode, + self.zero_rtt_handshake, + ) .await - .map(Connection::new)?; + { + Ok(conn) => { + tokio::spawn(conn.clone().init(self.token.clone(), self.heartbeat)); + return Ok(conn); + } + Err(err) => last_err = Some(err), + } + } - tokio::spawn(conn.clone().init()); - - Ok(conn) + Err(last_err.unwrap_or(Error::DnsResolve)) } } @@ -81,6 +170,7 @@ impl Endpoint { pub struct Connection { conn: QuinnConnection, model: Model, + udp_relay_mode: UdpRelayMode, remote_uni_stream_cnt: Counter, remote_bi_stream_cnt: Counter, max_concurrent_uni_streams: Arc, @@ -88,10 +178,11 @@ pub struct Connection { } impl Connection { - fn new(conn: QuinnConnection) -> Self { + fn new(conn: QuinnConnection, udp_relay_mode: UdpRelayMode) -> Self { Self { conn: conn.clone(), model: Model::::new(conn), + udp_relay_mode, remote_uni_stream_cnt: Counter::new(), remote_bi_stream_cnt: Counter::new(), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), @@ -125,7 +216,7 @@ impl Connection { Ok::<_, Error>(conn.clone()) }; - let conn = time::timeout(Duration::from_secs(5), try_get_conn) + let conn = time::timeout(TIMEOUT.load(), try_get_conn) .await .map_err(|_| Error::Timeout)??; @@ -137,7 +228,11 @@ impl Connection { } pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { - self.model.packet_quic(pkt, addr, assoc_id).await?; // TODO + match self.udp_relay_mode { + UdpRelayMode::Native => self.model.packet_native(pkt, addr, assoc_id)?, + UdpRelayMode::Quic => self.model.packet_quic(pkt, addr, assoc_id).await?, + } + Ok(()) } @@ -250,16 +345,26 @@ impl Connection { } } - async fn authenticate(self) { - match self.model.authenticate([0; 32]).await { + async fn authenticate(self, token: Arc<[u8]>) { + let mut buf = [0; 32]; + + match self.conn.export_keying_material(&mut buf, &token, &token) { + Ok(()) => {} + Err(_) => { + eprintln!("token length too short"); + return; + } + } + + match self.model.authenticate(buf).await { Ok(()) => {} Err(err) => eprintln!("{err}"), } } - async fn heartbeat(self) { + async fn heartbeat(self, heartbeat: Duration) { loop { - time::sleep(Duration::from_secs(5)).await; + time::sleep(heartbeat).await; if self.is_closed() { break; @@ -272,9 +377,9 @@ impl Connection { } } - async fn init(self) { - tokio::spawn(self.clone().authenticate()); - tokio::spawn(self.clone().heartbeat()); + async fn init(self, token: Arc<[u8]>, heartbeat: Duration) { + tokio::spawn(self.clone().authenticate(token)); + tokio::spawn(self.clone().heartbeat(heartbeat)); let err = loop { tokio::select! { diff --git a/tuic-client/src/error.rs b/tuic-client/src/error.rs index 74fa0c5..8988b01 100644 --- a/tuic-client/src/error.rs +++ b/tuic-client/src/error.rs @@ -2,6 +2,7 @@ use quinn::{ConnectError, ConnectionError}; use std::io::Error as IoError; use thiserror::Error; use tuic_quinn::Error as ModelError; +use webpki::Error as WebpkiError; #[derive(Debug, Error)] pub enum Error { @@ -13,8 +14,12 @@ pub enum Error { Connection(#[from] ConnectionError), #[error(transparent)] Model(#[from] ModelError), - #[error("timeout")] + #[error(transparent)] + Webpki(#[from] WebpkiError), + #[error("timeout establishing connection")] Timeout, - #[error("invalid authentication")] - InvalidAuth, + #[error("cannot resolve the server name")] + DnsResolve, + #[error("invalid socks5 authentication")] + InvalidSocks5Auth, } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 37d31aa..d44b58a 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -60,7 +60,7 @@ impl Server { Arc::new(Password::new(username.into_bytes(), password.into_bytes())) } (None, None) => Arc::new(NoAuth), - _ => return Err(Error::InvalidAuth), + _ => return Err(Error::InvalidSocks5Auth), }; let server = Self { diff --git a/tuic-client/src/utils.rs b/tuic-client/src/utils.rs index 30ee389..afa7695 100644 --- a/tuic-client/src/utils.rs +++ b/tuic-client/src/utils.rs @@ -1,14 +1,40 @@ -use serde::{de::Error as DeError, Deserialize, Deserializer}; -use std::{fmt::Display, net::IpAddr, str::FromStr}; +use crate::error::Error; +use rustls::{Certificate, RootCertStore}; +use rustls_pemfile::Item; +use std::{ + fs::{self, File}, + io::BufReader, + net::{IpAddr, SocketAddr}, + str::FromStr, +}; +use tokio::net; -pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result -where - T: FromStr, - ::Err: Display, - D: Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - T::from_str(&s).map_err(DeError::custom) +pub fn load_certs(paths: Vec, disable_native: bool) -> Result { + let mut certs = RootCertStore::empty(); + + for path in &paths { + let mut file = BufReader::new(File::open(path)?); + + while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) { + if let Item::X509Certificate(cert) = item { + certs.add(&Certificate(cert))?; + } + } + } + + if certs.is_empty() { + for path in &paths { + certs.add(&Certificate(fs::read(path)?))?; + } + } + + if !disable_native { + for cert in rustls_native_certs::load_native_certs()? { + certs.add(&Certificate(cert.0))?; + } + } + + Ok(certs) } pub struct ServerAddr { @@ -21,8 +47,24 @@ impl ServerAddr { pub fn new(domain: String, port: u16, ip: Option) -> Self { Self { domain, port, ip } } + + pub fn server_name(&self) -> &str { + &self.domain + } + + pub async fn resolve(&self) -> Result, Error> { + if let Some(ip) = self.ip { + Ok(vec![SocketAddr::from((ip, self.port))].into_iter()) + } else { + Ok(net::lookup_host((self.domain.as_str(), self.port)) + .await? + .collect::>() + .into_iter()) + } + } } +#[derive(Clone, Copy)] pub enum UdpRelayMode { Native, Quic,