1
0

implement client config reading

This commit is contained in:
EAimTY 2023-02-02 13:45:36 +09:00
parent 1065e3e938
commit bcc79d7c5d
5 changed files with 437 additions and 182 deletions

View File

@ -1,19 +1,115 @@
use crate::utils::{self, CongestionControl, UdpRelayMode};
use lexopt::{Arg, Error as ArgumentError, Parser};
use serde::Deserialize;
use serde::{de::Error as DeError, Deserialize, Deserializer};
use serde_json::Error as SerdeError;
use std::{ffi::OsString, fs::File, io::Error as IoError};
use std::{
env::ArgsOs,
fs::File,
io::Error as IoError,
net::{IpAddr, SocketAddr},
time::Duration,
};
use thiserror::Error;
const HELP_MSG: &str = r#"
Usage tuic-client [arguments]
Arguments:
-c, --config <path> Path to the config file (required)
-v, --version Print the version
-h, --help Print this help message
"#;
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Config {}
pub struct Config {
pub relay: Relay,
pub local: Local,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Relay {
#[serde(deserialize_with = "deserialize_server")]
pub server: (String, u16),
pub token: String,
pub ip: Option<IpAddr>,
#[serde(default = "default::relay::certificates")]
pub certificates: Vec<String>,
#[serde(
default = "default::relay::udp_relay_mode",
deserialize_with = "utils::deserialize_from_str"
)]
pub udp_relay_mode: UdpRelayMode,
#[serde(
default = "default::relay::congestion_control",
deserialize_with = "utils::deserialize_from_str"
)]
pub congestion_control: CongestionControl,
#[serde(default = "default::relay::alpn")]
pub alpn: Vec<String>,
#[serde(default = "default::relay::zero_rtt_handshake")]
pub zero_rtt_handshake: bool,
#[serde(default = "default::relay::timeout")]
pub timeout: Duration,
#[serde(default = "default::relay::heartbeat")]
pub heartbeat: Duration,
}
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct Local {
pub server: SocketAddr,
pub username: Option<String>,
pub password: Option<String>,
pub dual_stack: Option<bool>,
#[serde(default = "default::local::max_packet_size")]
pub max_packet_size: usize,
}
mod default {
pub mod relay {
use crate::utils::{CongestionControl, UdpRelayMode};
use std::time::Duration;
pub const fn certificates() -> Vec<String> {
Vec::new()
}
pub const fn udp_relay_mode() -> UdpRelayMode {
UdpRelayMode::Native
}
pub const fn congestion_control() -> CongestionControl {
CongestionControl::Cubic
}
pub const fn alpn() -> Vec<String> {
Vec::new()
}
pub const fn zero_rtt_handshake() -> bool {
false
}
pub const fn timeout() -> Duration {
Duration::from_secs(8)
}
pub const fn heartbeat() -> Duration {
Duration::from_secs(3)
}
}
pub mod local {
pub const fn max_packet_size() -> usize {
1500
}
}
}
impl Config {
pub fn parse<A>(args: A) -> Result<Self, ConfigError>
where
A: IntoIterator,
A::Item: Into<OsString>,
{
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
let mut parser = Parser::from_iter(args);
let mut path = None;
@ -29,7 +125,7 @@ impl Config {
Arg::Short('v') | Arg::Long("version") => {
return Err(ConfigError::Version(env!("CARGO_PKG_VERSION")))
}
Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(todo!())),
Arg::Short('h') | Arg::Long("help") => return Err(ConfigError::Help(HELP_MSG)),
_ => return Err(ConfigError::Argument(arg.unexpected())),
}
}
@ -44,9 +140,25 @@ impl Config {
}
}
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let mut parts = s.split(':');
match (parts.next(), parts.next(), parts.next()) {
(Some(domain), Some(port), None) => port.parse().map_or_else(
|e| Err(DeError::custom(e)),
|port| Ok((domain.to_owned(), port)),
),
_ => Err(DeError::custom("invalid server address")),
}
}
#[derive(Debug, Error)]
pub enum ConfigError {
#[error("transparent")]
#[error(transparent)]
Argument(#[from] ArgumentError),
#[error("no config file specified")]
NoConfig,
@ -54,8 +166,8 @@ pub enum ConfigError {
Version(&'static str),
#[error("{0}")]
Help(&'static str),
#[error("transparent")]
#[error(transparent)]
Io(#[from] IoError),
#[error("transparent")]
#[error(transparent)]
Serde(#[from] SerdeError),
}

View File

@ -1,4 +1,9 @@
use crate::{error::Error, socks5};
use crate::{
config::Relay,
error::Error,
socks5::Server as Socks5Server,
utils::{ServerAddr, UdpRelayMode},
};
use bytes::Bytes;
use once_cell::sync::OnceCell;
use parking_lot::Mutex;
@ -27,14 +32,36 @@ static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_
const DEFAULT_CONCURRENT_STREAMS: usize = 32;
struct Endpoint {
pub struct Endpoint {
ep: QuinnEndpoint,
server: ServerAddr,
token: Vec<u8>,
udp_relay_mode: UdpRelayMode,
zero_rtt_handshake: bool,
timeout: Duration,
heartbeat: Duration,
}
impl Endpoint {
fn new() -> Result<Self, Error> {
let ep = QuinnEndpoint::client(SocketAddr::from(([0, 0, 0, 0], 0)))?;
Ok(Self { ep })
pub fn set_config(cfg: Relay) -> Result<(), Error> {
let ep = todo!();
let ep = Self {
ep,
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
token: cfg.token.into_bytes(),
udp_relay_mode: cfg.udp_relay_mode,
zero_rtt_handshake: cfg.zero_rtt_handshake,
timeout: cfg.timeout,
heartbeat: cfg.heartbeat,
};
ENDPOINT
.set(Mutex::new(ep))
.map_err(|_| "endpoint already initialized")
.unwrap();
Ok(())
}
async fn connect(&self) -> Result<Connection, Error> {
@ -75,8 +102,9 @@ impl Connection {
pub async fn get() -> Result<Connection, Error> {
let try_init_conn = async {
ENDPOINT
.get_or_try_init(|| Endpoint::new().map(Mutex::new))
.map(|ep| ep.lock())?
.get()
.unwrap()
.lock()
.connect()
.await
.map(AsyncMutex::new)
@ -170,7 +198,7 @@ impl Connection {
}
Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr),
};
socks5::recv_pkt(pkt, addr, assoc_id).await
Socks5Server::recv_pkt(pkt, addr, assoc_id).await
}
Ok(None) => Ok(()),
Err(err) => Err(Error::from(err)),
@ -208,7 +236,7 @@ impl Connection {
}
Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr),
};
socks5::recv_pkt(pkt, addr, assoc_id).await
Socks5Server::recv_pkt(pkt, addr, assoc_id).await
}
Ok(None) => Ok(()),
Err(err) => Err(Error::from(err)),

View File

@ -1,14 +1,20 @@
use self::config::{Config, ConfigError};
use socks5::Server;
use self::{
config::{Config, ConfigError},
connection::Endpoint,
};
use std::{env, process};
mod config;
mod connection;
mod error;
mod socks5;
mod utils;
#[tokio::main]
async fn main() {
let _cfg = match Config::parse(env::args_os()) {
let cfg = match Config::parse(env::args_os()) {
Ok(cfg) => cfg,
Err(ConfigError::Version(msg) | ConfigError::Help(msg)) => {
println!("{msg}");
@ -20,8 +26,21 @@ async fn main() {
}
};
if let Err(err) = socks5::start().await {
match Endpoint::set_config(cfg.relay) {
Ok(()) => {}
Err(err) => {
eprintln!("{err}");
process::exit(1);
}
}
match Server::set_config(cfg.local).await {
Ok(()) => {}
Err(err) => {
eprintln!("{err}");
process::exit(1);
}
}
Server::start().await;
}

View File

@ -1,12 +1,12 @@
use crate::{connection::Connection as TuicConnection, error::Error};
use crate::{config::Local, connection::Connection as TuicConnection, error::Error};
use bytes::Bytes;
use once_cell::sync::Lazy;
use once_cell::sync::{Lazy, OnceCell};
use parking_lot::Mutex;
use socks5_proto::{Address, Reply};
use socks5_server::{
auth::NoAuth,
connection::{associate, bind, connect},
Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server,
Associate, AssociatedUdpSocket, Bind, Connect, Connection, Server as Socks5Server,
};
use std::{
collections::HashMap,
@ -24,21 +24,50 @@ use tokio::{
use tokio_util::compat::FuturesAsyncReadCompatExt;
use tuic::Address as TuicAddress;
static SERVER: OnceCell<Server> = OnceCell::new();
static NEXT_ASSOCIATE_ID: AtomicU16 = AtomicU16::new(0);
static UDP_SESSIONS: Lazy<Mutex<HashMap<u16, Arc<AssociatedUdpSocket>>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
pub async fn start() -> Result<(), Error> {
let server = Server::bind("127.0.0.1:5000", Arc::new(NoAuth)).await?;
pub struct Server {
inner: Socks5Server,
dual_stack: Option<bool>,
max_packet_size: usize,
}
while let Ok((conn, _)) = server.accept().await {
impl Server {
pub async fn set_config(cfg: Local) -> Result<(), Error> {
let server = Socks5Server::bind(cfg.server, Arc::new(NoAuth)).await?;
let server = Self {
inner: server,
dual_stack: cfg.dual_stack,
max_packet_size: cfg.max_packet_size,
};
SERVER
.set(server)
.map_err(|_| "socks5 server already initialized")
.unwrap();
Ok(())
}
pub async fn start() {
let server = SERVER.get().unwrap();
loop {
match server.inner.accept().await {
Ok((conn, _)) => {
tokio::spawn(async move {
let res = match conn.handshake().await {
Ok(Connection::Associate(associate, addr)) => {
handle_associate(associate, addr).await
Self::handle_associate(associate, addr).await
}
Ok(Connection::Bind(bind, addr)) => Self::handle_bind(bind, addr).await,
Ok(Connection::Connect(connect, addr)) => {
Self::handle_connect(connect, addr).await
}
Ok(Connection::Bind(bind, addr)) => handle_bind(bind, addr).await,
Ok(Connection::Connect(connect, addr)) => handle_connect(connect, addr).await,
Err(err) => Err(Error::from(err)),
};
@ -48,14 +77,15 @@ pub async fn start() -> Result<(), Error> {
}
});
}
Err(err) => eprintln!("{err}"),
}
}
}
Ok(())
}
async fn handle_associate(
async fn handle_associate(
assoc: Associate<associate::NeedReply>,
_addr: Address,
) -> Result<(), Error> {
) -> Result<(), Error> {
let assoc_socket = UdpSocket::bind(SocketAddr::from((assoc.local_addr()?.ip(), 0)))
.await
.and_then(|socket| {
@ -69,7 +99,7 @@ async fn handle_associate(
let assoc = assoc
.reply(Reply::Succeeded, Address::SocketAddress(assoc_addr))
.await?;
send_pkt(assoc, assoc_socket).await
Self::send_pkt(assoc, assoc_socket).await
}
Err(err) => {
let mut assoc = assoc
@ -79,17 +109,17 @@ async fn handle_associate(
Err(Error::from(err))
}
}
}
}
async fn handle_bind(bind: Bind<bind::NeedFirstReply>, _addr: Address) -> Result<(), Error> {
async fn handle_bind(bind: Bind<bind::NeedFirstReply>, _addr: Address) -> Result<(), Error> {
let mut conn = bind
.reply(Reply::CommandNotSupported, Address::unspecified())
.await?;
let _ = conn.shutdown().await;
Ok(())
}
}
async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) -> Result<(), Error> {
async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) -> Result<(), Error> {
let target_addr = match addr {
Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port),
Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr),
@ -128,12 +158,12 @@ async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) -> Res
Err(err)
}
}
}
}
async fn send_pkt(
async fn send_pkt(
mut assoc: Associate<associate::Ready>,
assoc_socket: Arc<AssociatedUdpSocket>,
) -> Result<(), Error> {
) -> Result<(), Error> {
let assoc_id = NEXT_ASSOCIATE_ID.fetch_add(1, Ordering::AcqRel);
UDP_SESSIONS.lock().insert(assoc_id, assoc_socket.clone());
let mut connected = None;
@ -196,11 +226,12 @@ async fn send_pkt(
}
Ok(res?)
}
}
pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
let sessions = UDP_SESSIONS.lock();
let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() };
assoc_socket.send(pkt, 0, addr).await?;
Ok(())
}
}

65
tuic-client/src/utils.rs Normal file
View File

@ -0,0 +1,65 @@
use serde::{de::Error as DeError, Deserialize, Deserializer};
use std::{fmt::Display, net::IpAddr, str::FromStr};
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: FromStr,
<T as FromStr>::Err: Display,
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
T::from_str(&s).map_err(DeError::custom)
}
pub struct ServerAddr {
domain: String,
port: u16,
ip: Option<IpAddr>,
}
impl ServerAddr {
pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self {
Self { domain, port, ip }
}
}
pub enum UdpRelayMode {
Native,
Quic,
}
impl FromStr for UdpRelayMode {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.eq_ignore_ascii_case("native") {
Ok(Self::Native)
} else if s.eq_ignore_ascii_case("quic") {
Ok(Self::Quic)
} else {
Err("invalid UDP relay mode")
}
}
}
pub enum CongestionControl {
Cubic,
NewReno,
Bbr,
}
impl FromStr for CongestionControl {
type Err = &'static str;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.eq_ignore_ascii_case("cubic") {
Ok(Self::Cubic)
} else if s.eq_ignore_ascii_case("new_reno") || s.eq_ignore_ascii_case("newreno") {
Ok(Self::NewReno)
} else if s.eq_ignore_ascii_case("bbr") {
Ok(Self::Bbr)
} else {
Err("invalid congestion control")
}
}
}