1
0

implement client config reading

This commit is contained in:
EAimTY 2023-02-02 13:45:36 +09:00
parent 1065e3e938
commit bcc79d7c5d
5 changed files with 437 additions and 182 deletions

View File

@ -1,19 +1,115 @@
use crate::utils::{self, CongestionControl, UdpRelayMode};
use lexopt::{Arg, Error as ArgumentError, Parser}; use lexopt::{Arg, Error as ArgumentError, Parser};
use serde::Deserialize; use serde::{de::Error as DeError, Deserialize, Deserializer};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use std::{ffi::OsString, fs::File, io::Error as IoError}; use std::{
env::ArgsOs,
fs::File,
io::Error as IoError,
net::{IpAddr, SocketAddr},
time::Duration,
};
use thiserror::Error; use thiserror::Error;
const HELP_MSG: &str = r#"
Usage tuic-client [arguments]
Arguments:
-c, --config <path> Path to the config file (required)
-v, --version Print the version
-h, --help Print this help message
"#;
#[derive(Deserialize)] #[derive(Deserialize)]
#[serde(deny_unknown_fields)] #[serde(deny_unknown_fields)]
pub struct Config {} pub struct Config {
pub relay: Relay,
pub local: Local,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Relay {
#[serde(deserialize_with = "deserialize_server")]
pub server: (String, u16),
pub token: String,
pub ip: Option<IpAddr>,
#[serde(default = "default::relay::certificates")]
pub certificates: Vec<String>,
#[serde(
default = "default::relay::udp_relay_mode",
deserialize_with = "utils::deserialize_from_str"
)]
pub udp_relay_mode: UdpRelayMode,
#[serde(
default = "default::relay::congestion_control",
deserialize_with = "utils::deserialize_from_str"
)]
pub congestion_control: CongestionControl,
#[serde(default = "default::relay::alpn")]
pub alpn: Vec<String>,
#[serde(default = "default::relay::zero_rtt_handshake")]
pub zero_rtt_handshake: bool,
#[serde(default = "default::relay::timeout")]
pub timeout: Duration,
#[serde(default = "default::relay::heartbeat")]
pub heartbeat: Duration,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Local {
pub server: SocketAddr,
pub username: Option<String>,
pub password: Option<String>,
pub dual_stack: Option<bool>,
#[serde(default = "default::local::max_packet_size")]
pub max_packet_size: usize,
}
mod default {
pub mod relay {
use crate::utils::{CongestionControl, UdpRelayMode};
use std::time::Duration;
pub const fn certificates() -> Vec<String> {
Vec::new()
}
pub const fn udp_relay_mode() -> UdpRelayMode {
UdpRelayMode::Native
}
pub const fn congestion_control() -> CongestionControl {
CongestionControl::Cubic
}
pub const fn alpn() -> Vec<String> {
Vec::new()
}
pub const fn zero_rtt_handshake() -> bool {
false
}
pub const fn timeout() -> Duration {
Duration::from_secs(8)
}
pub const fn heartbeat() -> Duration {
Duration::from_secs(3)
}
}
pub mod local {
pub const fn max_packet_size() -> usize {
1500
}
}
}
impl Config { impl Config {
pub fn parse<A>(args: A) -> Result<Self, ConfigError> pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
where
A: IntoIterator,
A::Item: Into<OsString>,
{
let mut parser = Parser::from_iter(args); let mut parser = Parser::from_iter(args);
let mut path = None; let mut path = None;
@ -29,7 +125,7 @@ impl Config {
Arg::Short('v') | Arg::Long("version") => { Arg::Short('v') | Arg::Long("version") => {
return Err(ConfigError::Version(env!("CARGO_PKG_VERSION"))) return Err(ConfigError::Version(env!("CARGO_PKG_VERSION")))
} }
Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(todo!())), Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)),
_ => return Err(ConfigError::Argument(arg.unexpected())), _ => return Err(ConfigError::Argument(arg.unexpected())),
} }
} }
@ -44,9 +140,25 @@ impl Config {
} }
} }
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let mut parts = s.split(':');
match (parts.next(), parts.next(), parts.next()) {
(Some(domain), Some(port), None) => port.parse().map_or_else(
|e| Err(DeError::custom(e)),
|port| Ok((domain.to_owned(), port)),
),
_ => Err(DeError::custom("invalid server address")),
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum ConfigError { pub enum ConfigError {
#[error("transparent")] #[error(transparent)]
Argument(#[from] ArgumentError), Argument(#[from] ArgumentError),
#[error("no config file specified")] #[error("no config file specified")]
NoConfig, NoConfig,
@ -54,8 +166,8 @@ pub enum ConfigError {
Version(&'static str), Version(&'static str),
#[error("{0}")] #[error("{0}")]
Help(&'static str), Help(&'static str),
#[error("transparent")] #[error(transparent)]
Io(#[from] IoError), Io(#[from] IoError),
#[error("transparent")] #[error(transparent)]
Serde(#[from] SerdeError), Serde(#[from] SerdeError),
} }

View File

@ -1,4 +1,9 @@
use crate::{error::Error, socks5}; use crate::{
config::Relay,
error::Error,
socks5::Server as Socks5Server,
utils::{ServerAddr, UdpRelayMode},
};
use bytes::Bytes; use bytes::Bytes;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use parking_lot::Mutex; use parking_lot::Mutex;
@ -27,14 +32,36 @@ static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_
const DEFAULT_CONCURRENT_STREAMS: usize = 32; const DEFAULT_CONCURRENT_STREAMS: usize = 32;
struct Endpoint { pub struct Endpoint {
ep: QuinnEndpoint, ep: QuinnEndpoint,
server: ServerAddr,
token: Vec<u8>,
udp_relay_mode: UdpRelayMode,
zero_rtt_handshake: bool,
timeout: Duration,
heartbeat: Duration,
} }
impl Endpoint { impl Endpoint {
fn new() -> Result<Self, Error> { pub fn set_config(cfg: Relay) -> Result<(), Error> {
let ep = QuinnEndpoint::client(SocketAddr::from(([0, 0, 0, 0], 0)))?; let ep = todo!();
Ok(Self { ep })
let ep = Self {
ep,
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
token: cfg.token.into_bytes(),
udp_relay_mode: cfg.udp_relay_mode,
zero_rtt_handshake: cfg.zero_rtt_handshake,
timeout: cfg.timeout,
heartbeat: cfg.heartbeat,
};
ENDPOINT
.set(Mutex::new(ep))
.map_err(|_| "endpoint already initialized")
.unwrap();
Ok(())
} }
async fn connect(&self) -> Result<Connection, Error> { async fn connect(&self) -> Result<Connection, Error> {
@ -75,8 +102,9 @@ impl Connection {
pub async fn get() -> Result<Connection, Error> { pub async fn get() -> Result<Connection, Error> {
let try_init_conn = async { let try_init_conn = async {
ENDPOINT ENDPOINT
.get_or_try_init(|| Endpoint::new().map(Mutex::new)) .get()
.map(|ep| ep.lock())? .unwrap()
.lock()
.connect() .connect()
.await .await
.map(AsyncMutex::new) .map(AsyncMutex::new)
@ -170,7 +198,7 @@ impl Connection {
} }
Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr),
}; };
socks5::recv_pkt(pkt, addr, assoc_id).await Socks5Server::recv_pkt(pkt, addr, assoc_id).await
} }
Ok(None) => Ok(()), Ok(None) => Ok(()),
Err(err) => Err(Error::from(err)), Err(err) => Err(Error::from(err)),
@ -208,7 +236,7 @@ impl Connection {
} }
Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr),
}; };
socks5::recv_pkt(pkt, addr, assoc_id).await Socks5Server::recv_pkt(pkt, addr, assoc_id).await
} }
Ok(None) => Ok(()), Ok(None) => Ok(()),
Err(err) => Err(Error::from(err)), Err(err) => Err(Error::from(err)),

View File

@ -1,14 +1,20 @@
use self::config::{Config, ConfigError}; use socks5::Server;
use self::{
config::{Config, ConfigError},
connection::Endpoint,
};
use std::{env, process}; use std::{env, process};
mod config; mod config;
mod connection; mod connection;
mod error; mod error;
mod socks5; mod socks5;
mod utils;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let _cfg = match Config::parse(env::args_os()) { let cfg = match Config::parse(env::args_os()) {
Ok(cfg) => cfg, Ok(cfg) => cfg,
Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => { Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => {
println!("{msg}"); println!("{msg}");
@ -20,8 +26,21 @@ async fn main() {
} }
}; };
if let Err(err) = socks5::start().await { match Endpoint::set_config(cfg.relay) {
eprintln!("{err}"); Ok(()) => {}
process::exit(1); Err(err) => {
eprintln!("{err}");
process::exit(1);
}
} }
match Server::set_config(cfg.local).await {
Ok(()) => {}
Err(err) => {
eprintln!("{err}");
process::exit(1);
}
}
Server::start().await;
} }

View File

@ -1,12 +1,12 @@
use crate::{connection::Connection as TuicConnection, error::Error}; use crate::{config::Local, connection::Connection as TuicConnection, error::Error};
use bytes::Bytes; use bytes::Bytes;
use once_cell::sync::Lazy; use once_cell::sync::{Lazy, OnceCell};
use parking_lot::Mutex; use parking_lot::Mutex;
use socks5_proto::{Address, Reply}; use socks5_proto::{Address, Reply};
use socks5_server::{ use socks5_server::{
auth::NoAuth, auth::NoAuth,
connection::{associate, bind, connect}, connection::{associate, bind, connect},
Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server, Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server as Socks5Server,
}; };
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -24,183 +24,214 @@ use tokio::{
use tokio_util::compat::FuturesAsyncReadCompatExt; use tokio_util::compat::FuturesAsyncReadCompatExt;
use tuic::Address as TuicAddress; use tuic::Address as TuicAddress;
static SERVER: OnceCell<Server> = OnceCell::new();
static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0); static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0);
static UDP_SESSIONS: Lazy<Mutex<HashMap<u16, Arc<AssociatedUdpSocket>>>> = static UDP_SESSIONS: Lazy<Mutex<HashMap<u16, Arc<AssociatedUdpSocket>>>> =
Lazy::new(|| Mutex::new(HashMap::new())); Lazy::new(|| Mutex::new(HashMap::new()));
pub async fn start() -> Result<(), Error> { pub struct Server {
let server = Server::bind("127.0.0.1:5000", Arc::new(NoAuth)).await?; inner: Socks5Server,
dual_stack: Option<bool>,
max_packet_size: usize,
}
while let Ok((conn, _)) = server.accept().await { impl Server {
tokio::spawn(async move { pub async fn set_config(cfg: Local) -> Result<(), Error> {
let res = match conn.handshake().await { let server = Socks5Server::bind(cfg.server, Arc::new(NoAuth)).await?;
Ok(Connection::Associate(associate, addr)) => {
handle_associate(associate, addr).await let server = Self {
inner: server,
dual_stack: cfg.dual_stack,
max_packet_size: cfg.max_packet_size,
};
SERVER
.set(server)
.map_err(|_| "socks5 server already initialized")
.unwrap();
Ok(())
}
pub async fn start() {
let server = SERVER.get().unwrap();
loop {
match server.inner.accept().await {
Ok((conn, _)) => {
tokio::spawn(async move {
let res = match conn.handshake().await {
Ok(Connection::Associate(associate, addr)) => {
Self::handle_associate(associate, addr).await
}
Ok(Connection::Bind(bind, addr)) => Self::handle_bind(bind, addr).await,
Ok(Connection::Connect(connect, addr)) => {
Self::handle_connect(connect, addr).await
}
Err(err) => Err(Error::from(err)),
};
match res {
Ok(_) => {}
Err(err) => eprintln!("{err}"),
}
});
} }
Ok(Connection::Bind(bind, addr)) => handle_bind(bind, addr).await,
Ok(Connection::Connect(connect, addr)) => handle_connect(connect, addr).await,
Err(err) => Err(Error::from(err)),
};
match res {
Ok(_) => {}
Err(err) => eprintln!("{err}"), Err(err) => eprintln!("{err}"),
} }
});
}
Ok(())
}
async fn handle_associate(
assoc: Associate<associate::NeedReply>,
_addr: Address,
) -> Result<(), Error> {
let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0)))
.await
.and_then(|socket| {
socket
.local_addr()
.map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr))
});
match assoc_socket {
Ok((assoc_socket, assoc_addr)) => {
let assoc = assoc
.reply(Reply::Succeeded, Address::SocketAddress(assoc_addr))
.await?;
send_pkt(assoc, assoc_socket).await
}
Err(err) => {
let mut assoc = assoc
.reply(Reply::GeneralFailure, Address::unspecified())
.await?;
let _ = assoc.shutdown().await;
Err(Error::from(err))
} }
} }
}
async fn handle_bind(bind: Bind<bind::NeedFirstReply>, _addr: Address) -> Result<(), Error> { async fn handle_associate(
let mut conn = bind assoc: Associate<associate::NeedReply>,
.reply(Reply::CommandNotSupported, Address::unspecified()) _addr: Address,
.await?;
let _ = conn.shutdown().await;
Ok(())
}
async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) -> Result<(), Error> {
let target_addr = match addr {
Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port),
Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr),
};
let relay = match TuicConnection::get().await {
Ok(conn) => conn.connect(target_addr).await,
Err(err) => Err(err),
};
match relay {
Ok(relay) => {
let mut relay = relay.compat();
let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await;
match conn {
Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await {
Ok(_) => Ok(()),
Err(err) => {
let _ = conn.shutdown().await;
let _ = relay.shutdown().await;
Err(Error::from(err))
}
},
Err(err) => {
let _ = relay.shutdown().await;
Err(Error::from(err))
}
}
}
Err(err) => {
let mut conn = conn
.reply(Reply::GeneralFailure, Address::unspecified())
.await?;
let _ = conn.shutdown().await;
Err(err)
}
}
}
async fn send_pkt(
mut assoc: Associate<associate::Ready>,
assoc_socket: Arc<AssociatedUdpSocket>,
) -> Result<(), Error> {
let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel);
UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone());
let mut connected = None;
async fn accept_pkt(
assoc_socket: &AssociatedUdpSocket,
connected: &mut Option<SocketAddr>,
assoc_id: u16,
) -> Result<(), Error> { ) -> Result<(), Error> {
let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0)))
.await
.and_then(|socket| {
socket
.local_addr()
.map(|addr| (Arc::new(AssociatedUdpSocket::from((socket, 1500))), addr))
});
if let Some(connected) = connected { match assoc_socket {
if connected != &src_addr { Ok((assoc_socket, assoc_addr)) => {
Err(IoError::new( let assoc = assoc
ErrorKind::Other, .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr))
format!("invalid source address: {src_addr}"), .await?;
))?; Self::send_pkt(assoc, assoc_socket).await
}
Err(err) => {
let mut assoc = assoc
.reply(Reply::GeneralFailure, Address::unspecified())
.await?;
let _ = assoc.shutdown().await;
Err(Error::from(err))
} }
} else {
assoc_socket.connect(src_addr).await?;
*connected = Some(src_addr);
} }
}
if frag != 0 { async fn handle_bind(bind: Bind<bind::NeedFirstReply>, _addr: Address) -> Result<(), Error> {
Err(IoError::new( let mut conn = bind
ErrorKind::Other, .reply(Reply::CommandNotSupported, Address::unspecified())
format!("fragmented packet is not supported"), .await?;
))?; let _ = conn.shutdown().await;
} Ok(())
}
let target_addr = match dst_addr { async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) -> Result<(), Error> {
let target_addr = match addr {
Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port),
Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr),
}; };
TuicConnection::get() let relay = match TuicConnection::get().await {
.await? Ok(conn) => conn.connect(target_addr).await,
.packet(pkt, target_addr, assoc_id) Err(err) => Err(err),
.await };
}
let res = tokio::select! { match relay {
res = assoc.wait_until_closed() => res, Ok(relay) => {
_ = async { loop { let mut relay = relay.compat();
if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { let conn = conn.reply(Reply::Succeeded, Address::unspecified()).await;
eprintln!("{err}");
match conn {
Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await {
Ok(_) => Ok(()),
Err(err) => {
let _ = conn.shutdown().await;
let _ = relay.shutdown().await;
Err(Error::from(err))
}
},
Err(err) => {
let _ = relay.shutdown().await;
Err(Error::from(err))
}
}
} }
}} => unreachable!(), Err(err) => {
}; let mut conn = conn
.reply(Reply::GeneralFailure, Address::unspecified())
let _ = assoc.shutdown().await; .await?;
UDP_SESSIONS.lock().remove(&assoc_id); let _ = conn.shutdown().await;
Err(err)
match TuicConnection::get().await { }
Ok(conn) => match conn.dissociate(assoc_id).await { }
Ok(_) => {}
Err(err) => eprintln!("{err}"),
},
Err(err) => eprintln!("{err}"),
} }
Ok(res?) async fn send_pkt(
} mut assoc: Associate<associate::Ready>,
assoc_socket: Arc<AssociatedUdpSocket>,
) -> Result<(), Error> {
let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel);
UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone());
let mut connected = None;
pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { async fn accept_pkt(
let sessions = UDP_SESSIONS.lock(); assoc_socket: &AssociatedUdpSocket,
let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; connected: &mut Option<SocketAddr>,
assoc_socket.send(pkt, 0, addr).await?; assoc_id: u16,
Ok(()) ) -> Result<(), Error> {
let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?;
if let Some(connected) = connected {
if connected != &src_addr {
Err(IoError::new(
ErrorKind::Other,
format!("invalid source address: {src_addr}"),
))?;
}
} else {
assoc_socket.connect(src_addr).await?;
*connected = Some(src_addr);
}
if frag != 0 {
Err(IoError::new(
ErrorKind::Other,
format!("fragmented packet is not supported"),
))?;
}
let target_addr = match dst_addr {
Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port),
Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr),
};
TuicConnection::get()
.await?
.packet(pkt, target_addr, assoc_id)
.await
}
let res = tokio::select! {
res = assoc.wait_until_closed() => res,
_ = async { loop {
if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await {
eprintln!("{err}");
}
}} => unreachable!(),
};
let _ = assoc.shutdown().await;
UDP_SESSIONS.lock().remove(&assoc_id);
match TuicConnection::get().await {
Ok(conn) => match conn.dissociate(assoc_id).await {
Ok(_) => {}
Err(err) => eprintln!("{err}"),
},
Err(err) => eprintln!("{err}"),
}
Ok(res?)
}
pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
let sessions = UDP_SESSIONS.lock();
let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() };
assoc_socket.send(pkt, 0, addr).await?;
Ok(())
}
} }

65
tuic-client/src/utils.rs Normal file
View File

@ -0,0 +1,65 @@
use serde::{de::Error as DeError, Deserialize, Deserializer};
use std::{fmt::Display, net::IpAddr, str::FromStr};
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: FromStr,
<T as FromStr>::Err: Display,
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
T::from_str(&s).map_err(DeError::custom)
}
pub struct ServerAddr {
domain: String,
port: u16,
ip: Option<IpAddr>,
}
impl ServerAddr {
pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self {
Self { domain, port, ip }
}
}
pub enum UdpRelayMode {
Native,
Quic,
}
impl FromStr for UdpRelayMode {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.eq_ignore_ascii_case("native") {
Ok(Self::Native)
} else if s.eq_ignore_ascii_case("quic") {
Ok(Self::Quic)
} else {
Err("invalid UDP relay mode")
}
}
}
pub enum CongestionControl {
Cubic,
NewReno,
Bbr,
}
impl FromStr for CongestionControl {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.eq_ignore_ascii_case("cubic") {
Ok(Self::Cubic)
} else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") {
Ok(Self::NewReno)
} else if s.eq_ignore_ascii_case("bbr") {
Ok(Self::Bbr)
} else {
Err("invalid congestion control")
}
}
}