From bcc79d7c5d100f10ece09ce1071e5cfc798d628f Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 2 Feb 2023 13:45:36 +0900 Subject: [PATCH] implement client config reading --- tuic-client/src/config.rs | 136 ++++++++++++-- tuic-client/src/connection.rs | 46 ++++- tuic-client/src/main.rs | 29 ++- tuic-client/src/socks5.rs | 343 ++++++++++++++++++---------------- tuic-client/src/utils.rs | 65 +++++++ 5 files changed, 437 insertions(+), 182 deletions(-) create mode 100644 tuic-client/src/utils.rs diff --git a/tuic-client/src/config.rs b/tuic-client/src/config.rs index f590ee6..81f35db 100644 --- a/tuic-client/src/config.rs +++ b/tuic-client/src/config.rs @@ -1,19 +1,115 @@ +use crate::utils::{self, CongestionControl, UdpRelayMode}; use lexopt::{Arg, Error as ArgumentError, Parser}; -use serde::Deserialize; +use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde_json::Error as SerdeError; -use std::{ffi::OsString, fs::File, io::Error as IoError}; +use std::{ + env::ArgsOs, + fs::File, + io::Error as IoError, + net::{IpAddr, SocketAddr}, + time::Duration, +}; use thiserror::Error; +const HELP_MSG: &str = r#" +Usage tuic-client [arguments] + +Arguments: + -c, --config Path to the config file (required) + -v, --version Print the version + -h, --help Print this help message +"#; + #[derive(Deserialize)] #[serde(deny_unknown_fields)] -pub struct Config {} +pub struct Config { + pub relay: Relay, + pub local: Local, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Relay { + #[serde(deserialize_with = "deserialize_server")] + pub server: (String, u16), + pub token: String, + pub ip: Option, + #[serde(default = "default::relay::certificates")] + pub certificates: Vec, + #[serde( + default = "default::relay::udp_relay_mode", + deserialize_with = "utils::deserialize_from_str" + )] + pub udp_relay_mode: UdpRelayMode, + #[serde( + default = "default::relay::congestion_control", + deserialize_with = "utils::deserialize_from_str" + )] + pub congestion_control: CongestionControl, + #[serde(default = "default::relay::alpn")] + pub alpn: Vec, + #[serde(default = "default::relay::zero_rtt_handshake")] + pub zero_rtt_handshake: bool, + #[serde(default = "default::relay::timeout")] + pub timeout: Duration, + #[serde(default = "default::relay::heartbeat")] + pub heartbeat: Duration, +} + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub struct Local { + pub server: SocketAddr, + pub username: Option, + pub password: Option, + pub dual_stack: Option, + #[serde(default = "default::local::max_packet_size")] + pub max_packet_size: usize, +} + +mod default { + pub mod relay { + use crate::utils::{CongestionControl, UdpRelayMode}; + use std::time::Duration; + + pub const fn certificates() -> Vec { + Vec::new() + } + + pub const fn udp_relay_mode() -> UdpRelayMode { + UdpRelayMode::Native + } + + pub const fn congestion_control() -> CongestionControl { + CongestionControl::Cubic + } + + pub const fn alpn() -> Vec { + Vec::new() + } + + pub const fn zero_rtt_handshake() -> bool { + false + } + + pub const fn timeout() -> Duration { + Duration::from_secs(8) + } + + pub const fn heartbeat() -> Duration { + Duration::from_secs(3) + } + } + + pub mod local { + pub const fn max_packet_size() -> usize { + 1500 + } + } +} impl Config { - pub fn parse(args: A) -> Result - where - A: IntoIterator, - A::Item: Into, - { + pub fn parse(args: ArgsOs) -> Result { let mut parser = Parser::from_iter(args); let mut path = None; @@ -29,7 +125,7 @@ impl Config { Arg::Short('v') | Arg::Long("version") => { return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))) } - Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(todo!())), + Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)), _ => return Err(ConfigError::Argument(arg.unexpected())), } } @@ -44,9 +140,25 @@ impl Config { } } +pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error> +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + let mut parts = s.split(':'); + + match (parts.next(), parts.next(), parts.next()) { + (Some(domain), Some(port), None) => port.parse().map_or_else( + |e| Err(DeError::custom(e)), + |port| Ok((domain.to_owned(), port)), + ), + _ => Err(DeError::custom("invalid server address")), + } +} + #[derive(Debug, Error)] pub enum ConfigError { - #[error("transparent")] + #[error(transparent)] Argument(#[from] ArgumentError), #[error("no config file specified")] NoConfig, @@ -54,8 +166,8 @@ pub enum ConfigError { Version(&'static str), #[error("{0}")] Help(&'static str), - #[error("transparent")] + #[error(transparent)] Io(#[from] IoError), - #[error("transparent")] + #[error(transparent)] Serde(#[from] SerdeError), } diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 159af6e..7d14e8c 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -1,4 +1,9 @@ -use crate::{error::Error, socks5}; +use crate::{ + config::Relay, + error::Error, + socks5::Server as Socks5Server, + utils::{ServerAddr, UdpRelayMode}, +}; use bytes::Bytes; use once_cell::sync::OnceCell; use parking_lot::Mutex; @@ -27,14 +32,36 @@ static CONNECTION: AsyncOnceCell> = AsyncOnceCell::const_ const DEFAULT_CONCURRENT_STREAMS: usize = 32; -struct Endpoint { +pub struct Endpoint { ep: QuinnEndpoint, + server: ServerAddr, + token: Vec, + udp_relay_mode: UdpRelayMode, + zero_rtt_handshake: bool, + timeout: Duration, + heartbeat: Duration, } impl Endpoint { - fn new() -> Result { - let ep = QuinnEndpoint::client(SocketAddr::from(([0, 0, 0, 0], 0)))?; - Ok(Self { ep }) + pub fn set_config(cfg: Relay) -> Result<(), Error> { + let ep = todo!(); + + let ep = Self { + ep, + server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip), + token: cfg.token.into_bytes(), + udp_relay_mode: cfg.udp_relay_mode, + zero_rtt_handshake: cfg.zero_rtt_handshake, + timeout: cfg.timeout, + heartbeat: cfg.heartbeat, + }; + + ENDPOINT + .set(Mutex::new(ep)) + .map_err(|_| "endpoint already initialized") + .unwrap(); + + Ok(()) } async fn connect(&self) -> Result { @@ -75,8 +102,9 @@ impl Connection { pub async fn get() -> Result { let try_init_conn = async { ENDPOINT - .get_or_try_init(|| Endpoint::new().map(Mutex::new)) - .map(|ep| ep.lock())? + .get() + .unwrap() + .lock() .connect() .await .map(AsyncMutex::new) @@ -170,7 +198,7 @@ impl Connection { } Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), }; - socks5::recv_pkt(pkt, addr, assoc_id).await + Socks5Server::recv_pkt(pkt, addr, assoc_id).await } Ok(None) => Ok(()), Err(err) => Err(Error::from(err)), @@ -208,7 +236,7 @@ impl Connection { } Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), }; - socks5::recv_pkt(pkt, addr, assoc_id).await + Socks5Server::recv_pkt(pkt, addr, assoc_id).await } Ok(None) => Ok(()), Err(err) => Err(Error::from(err)), diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index 20cc0b6..e1d83e5 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -1,14 +1,20 @@ -use self::config::{Config, ConfigError}; +use socks5::Server; + +use self::{ + config::{Config, ConfigError}, + connection::Endpoint, +}; use std::{env, process}; mod config; mod connection; mod error; mod socks5; +mod utils; #[tokio::main] async fn main() { - let _cfg = match Config::parse(env::args_os()) { + let cfg = match Config::parse(env::args_os()) { Ok(cfg) => cfg, Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { println!("{msg}"); @@ -20,8 +26,21 @@ async fn main() { } }; - if let Err(err) = socks5::start().await { - eprintln!("{err}"); - process::exit(1); + match Endpoint::set_config(cfg.relay) { + Ok(()) => {} + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } } + + match Server::set_config(cfg.local).await { + Ok(()) => {} + Err(err) => { + eprintln!("{err}"); + process::exit(1); + } + } + + Server::start().await; } diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 70ed148..00af2b5 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -1,12 +1,12 @@ -use crate::{connection::Connection as TuicConnection, error::Error}; +use crate::{config::Local, connection::Connection as TuicConnection, error::Error}; use bytes::Bytes; -use once_cell::sync::Lazy; +use once_cell::sync::{Lazy, OnceCell}; use parking_lot::Mutex; use socks5_proto::{Address, Reply}; use socks5_server::{ auth::NoAuth, connection::{associate, bind, connect}, - Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server, + Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server as Socks5Server, }; use std::{ collections::HashMap, @@ -24,183 +24,214 @@ use tokio::{ use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address as TuicAddress; +static SERVER: OnceCell = OnceCell::new(); static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0); static UDP_SESSIONS: Lazy>>> = Lazy::new(|| Mutex::new(HashMap::new())); -pub async fn start() -> Result<(), Error> { - let server = Server::bind("127.0.0.1:5000", Arc::new(NoAuth)).await?; +pub struct Server { + inner: Socks5Server, + dual_stack: Option, + max_packet_size: usize, +} - while let Ok((conn, _)) = server.accept().await { - tokio::spawn(async move { - let res = match conn.handshake().await { - Ok(Connection::Associate(associate, addr)) => { - handle_associate(associate, addr).await +impl Server { + pub async fn set_config(cfg: Local) -> Result<(), Error> { + let server = Socks5Server::bind(cfg.server, Arc::new(NoAuth)).await?; + + let server = Self { + inner: server, + dual_stack: cfg.dual_stack, + max_packet_size: cfg.max_packet_size, + }; + + SERVER + .set(server) + .map_err(|_| "socks5 server already initialized") + .unwrap(); + + Ok(()) + } + + pub async fn start() { + let server = SERVER.get().unwrap(); + + loop { + match server.inner.accept().await { + Ok((conn, _)) => { + tokio::spawn(async move { + let res = match conn.handshake().await { + Ok(Connection::Associate(associate, addr)) => { + Self::handle_associate(associate, addr).await + } + Ok(Connection::Bind(bind, addr)) => Self::handle_bind(bind, addr).await, + Ok(Connection::Connect(connect, addr)) => { + Self::handle_connect(connect, addr).await + } + Err(err) => Err(Error::from(err)), + }; + + match res { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + } + }); } - Ok(Connection::Bind(bind, addr)) => handle_bind(bind, addr).await, - Ok(Connection::Connect(connect, addr)) => handle_connect(connect, addr).await, - Err(err) => Err(Error::from(err)), - }; - - match res { - Ok(_) => {} Err(err) => eprintln!("{err}"), } - }); - } - - Ok(()) -} - -async fn handle_associate( - assoc: Associate, - _addr: Address, -) -> Result<(), Error> { - let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0))) - .await - .and_then(|socket| { - socket - .local_addr() - .map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr)) - }); - - match assoc_socket { - Ok((assoc_socket, assoc_addr)) => { - let assoc = assoc - .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) - .await?; - send_pkt(assoc, assoc_socket).await - } - Err(err) => { - let mut assoc = assoc - .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - let _ = assoc.shutdown().await; - Err(Error::from(err)) } } -} -async fn handle_bind(bind: Bind, _addr: Address) -> Result<(), Error> { - let mut conn = bind - .reply(Reply::CommandNotSupported, Address::unspecified()) - .await?; - let _ = conn.shutdown().await; - Ok(()) -} - -async fn handle_connect(conn: Connect, addr: Address) -> Result<(), Error> { - let target_addr = match addr { - Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), - Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), - }; - - let relay = match TuicConnection::get().await { - Ok(conn) => conn.connect(target_addr).await, - Err(err) => Err(err), - }; - - match relay { - Ok(relay) => { - let mut relay = relay.compat(); - let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await; - - match conn { - Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await { - Ok(_) => Ok(()), - Err(err) => { - let _ = conn.shutdown().await; - let _ = relay.shutdown().await; - Err(Error::from(err)) - } - }, - Err(err) => { - let _ = relay.shutdown().await; - Err(Error::from(err)) - } - } - } - Err(err) => { - let mut conn = conn - .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - let _ = conn.shutdown().await; - Err(err) - } - } -} - -async fn send_pkt( - mut assoc: Associate, - assoc_socket: Arc, -) -> Result<(), Error> { - let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel); - UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone()); - let mut connected = None; - - async fn accept_pkt( - assoc_socket: &AssociatedUdpSocket, - connected: &mut Option, - assoc_id: u16, + async fn handle_associate( + assoc: Associate, + _addr: Address, ) -> Result<(), Error> { - let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; + let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0))) + .await + .and_then(|socket| { + socket + .local_addr() + .map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr)) + }); - if let Some(connected) = connected { - if connected != &src_addr { - Err(IoError::new( - ErrorKind::Other, - format!("invalid source address: {src_addr}"), - ))?; + match assoc_socket { + Ok((assoc_socket, assoc_addr)) => { + let assoc = assoc + .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) + .await?; + Self::send_pkt(assoc, assoc_socket).await + } + Err(err) => { + let mut assoc = assoc + .reply(Reply::GeneralFailure, Address::unspecified()) + .await?; + let _ = assoc.shutdown().await; + Err(Error::from(err)) } - } else { - assoc_socket.connect(src_addr).await?; - *connected = Some(src_addr); } + } - if frag != 0 { - Err(IoError::new( - ErrorKind::Other, - format!("fragmented packet is not supported"), - ))?; - } + async fn handle_bind(bind: Bind, _addr: Address) -> Result<(), Error> { + let mut conn = bind + .reply(Reply::CommandNotSupported, Address::unspecified()) + .await?; + let _ = conn.shutdown().await; + Ok(()) + } - let target_addr = match dst_addr { + async fn handle_connect(conn: Connect, addr: Address) -> Result<(), Error> { + let target_addr = match addr { Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), }; - TuicConnection::get() - .await? - .packet(pkt, target_addr, assoc_id) - .await - } + let relay = match TuicConnection::get().await { + Ok(conn) => conn.connect(target_addr).await, + Err(err) => Err(err), + }; - let res = tokio::select! { - res = assoc.wait_until_closed() => res, - _ = async { loop { - if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { - eprintln!("{err}"); + match relay { + Ok(relay) => { + let mut relay = relay.compat(); + let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await; + + match conn { + Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await { + Ok(_) => Ok(()), + Err(err) => { + let _ = conn.shutdown().await; + let _ = relay.shutdown().await; + Err(Error::from(err)) + } + }, + Err(err) => { + let _ = relay.shutdown().await; + Err(Error::from(err)) + } + } } - }} => unreachable!(), - }; - - let _ = assoc.shutdown().await; - UDP_SESSIONS.lock().remove(&assoc_id); - - match TuicConnection::get().await { - Ok(conn) => match conn.dissociate(assoc_id).await { - Ok(_) => {} - Err(err) => eprintln!("{err}"), - }, - Err(err) => eprintln!("{err}"), + Err(err) => { + let mut conn = conn + .reply(Reply::GeneralFailure, Address::unspecified()) + .await?; + let _ = conn.shutdown().await; + Err(err) + } + } } - Ok(res?) -} + async fn send_pkt( + mut assoc: Associate, + assoc_socket: Arc, + ) -> Result<(), Error> { + let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel); + UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone()); + let mut connected = None; -pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { - let sessions = UDP_SESSIONS.lock(); - let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; - assoc_socket.send(pkt, 0, addr).await?; - Ok(()) + async fn accept_pkt( + assoc_socket: &AssociatedUdpSocket, + connected: &mut Option, + assoc_id: u16, + ) -> Result<(), Error> { + let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; + + if let Some(connected) = connected { + if connected != &src_addr { + Err(IoError::new( + ErrorKind::Other, + format!("invalid source address: {src_addr}"), + ))?; + } + } else { + assoc_socket.connect(src_addr).await?; + *connected = Some(src_addr); + } + + if frag != 0 { + Err(IoError::new( + ErrorKind::Other, + format!("fragmented packet is not supported"), + ))?; + } + + let target_addr = match dst_addr { + Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), + Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), + }; + + TuicConnection::get() + .await? + .packet(pkt, target_addr, assoc_id) + .await + } + + let res = tokio::select! { + res = assoc.wait_until_closed() => res, + _ = async { loop { + if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { + eprintln!("{err}"); + } + }} => unreachable!(), + }; + + let _ = assoc.shutdown().await; + UDP_SESSIONS.lock().remove(&assoc_id); + + match TuicConnection::get().await { + Ok(conn) => match conn.dissociate(assoc_id).await { + Ok(_) => {} + Err(err) => eprintln!("{err}"), + }, + Err(err) => eprintln!("{err}"), + } + + Ok(res?) + } + + pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { + let sessions = UDP_SESSIONS.lock(); + let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; + assoc_socket.send(pkt, 0, addr).await?; + Ok(()) + } } diff --git a/tuic-client/src/utils.rs b/tuic-client/src/utils.rs new file mode 100644 index 0000000..30ee389 --- /dev/null +++ b/tuic-client/src/utils.rs @@ -0,0 +1,65 @@ +use serde::{de::Error as DeError, Deserialize, Deserializer}; +use std::{fmt::Display, net::IpAddr, str::FromStr}; + +pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result +where + T: FromStr, + ::Err: Display, + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + T::from_str(&s).map_err(DeError::custom) +} + +pub struct ServerAddr { + domain: String, + port: u16, + ip: Option, +} + +impl ServerAddr { + pub fn new(domain: String, port: u16, ip: Option) -> Self { + Self { domain, port, ip } + } +} + +pub enum UdpRelayMode { + Native, + Quic, +} + +impl FromStr for UdpRelayMode { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if s.eq_ignore_ascii_case("native") { + Ok(Self::Native) + } else if s.eq_ignore_ascii_case("quic") { + Ok(Self::Quic) + } else { + Err("invalid UDP relay mode") + } + } +} + +pub enum CongestionControl { + Cubic, + NewReno, + Bbr, +} + +impl FromStr for CongestionControl { + type Err = &'static str; + + fn from_str(s: &str) -> Result { + if s.eq_ignore_ascii_case("cubic") { + Ok(Self::Cubic) + } else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") { + Ok(Self::NewReno) + } else if s.eq_ignore_ascii_case("bbr") { + Ok(Self::Bbr) + } else { + Err("invalid congestion control") + } + } +}