1
0

reading client config to endpoint

This commit is contained in:
EAimTY 2023-02-02 23:42:53 +09:00
parent 0070112e23
commit 9806c62fe7
6 changed files with 265 additions and 85 deletions

View File

@ -5,11 +5,15 @@ edition = "2021"
[dependencies] [dependencies]
bytes = { version = "1.3.0", default-features = false, features = ["std"] } bytes = { version = "1.3.0", default-features = false, features = ["std"] }
crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] }
lexopt = { version = "0.3.0", default-features = false } lexopt = { version = "0.3.0", default-features = false }
once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] } once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] }
parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] } parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] }
quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] } quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] }
register-count = { version = "0.1.0", default-features = false, features = ["std"] } register-count = { version = "0.1.0", default-features = false, features = ["std"] }
rustls = { version = "0.20.8", default-features = false, features = ["quic"] }
rustls-native-certs = { version = "0.6.2", default-features = false }
rustls-pemfile = { version = "1.0.2", default-features = false }
serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] } serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] }
serde_json = { version = "1.0.91", default-features = false, features = ["std"] } serde_json = { version = "1.0.91", default-features = false, features = ["std"] }
socket2 = { version = "0.4.7", default-features = false } socket2 = { version = "0.4.7", default-features = false }
@ -20,3 +24,4 @@ tokio = { version = "1.24.2", default-features = false, features = ["macros", "n
tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] } tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] }
tuic = { path = "../tuic", default-features = false } tuic = { path = "../tuic", default-features = false }
tuic-quinn = { path = "../tuic-quinn", default-features = false } tuic-quinn = { path = "../tuic-quinn", default-features = false }
webpki = { version = "0.22.0", default-features = false }

View File

@ -1,12 +1,14 @@
use crate::utils::{self, CongestionControl, UdpRelayMode}; use crate::utils::{CongestionControl, UdpRelayMode};
use lexopt::{Arg, Error as ArgumentError, Parser}; use lexopt::{Arg, Error as ArgumentError, Parser};
use serde::{de::Error as DeError, Deserialize, Deserializer}; use serde::{de::Error as DeError, Deserialize, Deserializer};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use std::{ use std::{
env::ArgsOs, env::ArgsOs,
fmt::Display,
fs::File, fs::File,
io::Error as IoError, io::Error as IoError,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
str::FromStr,
time::Duration, time::Duration,
}; };
use thiserror::Error; use thiserror::Error;
@ -38,22 +40,26 @@ pub struct Relay {
pub certificates: Vec<String>, pub certificates: Vec<String>,
#[serde( #[serde(
default = "default::relay::udp_relay_mode", default = "default::relay::udp_relay_mode",
deserialize_with = "utils::deserialize_from_str" deserialize_with = "deserialize_from_str"
)] )]
pub udp_relay_mode: UdpRelayMode, pub udp_relay_mode: UdpRelayMode,
#[serde( #[serde(
default = "default::relay::congestion_control", default = "default::relay::congestion_control",
deserialize_with = "utils::deserialize_from_str" deserialize_with = "deserialize_from_str"
)] )]
pub congestion_control: CongestionControl, pub congestion_control: CongestionControl,
#[serde(default = "default::relay::alpn")] #[serde(default = "default::relay::alpn")]
pub alpn: Vec<String>, pub alpn: Vec<String>,
#[serde(default = "default::relay::zero_rtt_handshake")] #[serde(default = "default::relay::zero_rtt_handshake")]
pub zero_rtt_handshake: bool, pub zero_rtt_handshake: bool,
#[serde(default = "default::relay::disable_sni")]
pub disable_sni: bool,
#[serde(default = "default::relay::timeout")] #[serde(default = "default::relay::timeout")]
pub timeout: Duration, pub timeout: Duration,
#[serde(default = "default::relay::heartbeat")] #[serde(default = "default::relay::heartbeat")]
pub heartbeat: Duration, pub heartbeat: Duration,
#[serde(default = "default::relay::disable_native_certificates")]
pub disable_native_certificates: bool,
} }
#[derive(Deserialize)] #[derive(Deserialize)]
@ -67,47 +73,6 @@ pub struct Local {
pub max_packet_size: usize, pub max_packet_size: usize,
} }
mod default {
pub mod relay {
use crate::utils::{CongestionControl, UdpRelayMode};
use std::time::Duration;
pub const fn certificates() -> Vec<String> {
Vec::new()
}
pub const fn udp_relay_mode() -> UdpRelayMode {
UdpRelayMode::Native
}
pub const fn congestion_control() -> CongestionControl {
CongestionControl::Cubic
}
pub const fn alpn() -> Vec<String> {
Vec::new()
}
pub const fn zero_rtt_handshake() -> bool {
false
}
pub const fn timeout() -> Duration {
Duration::from_secs(8)
}
pub const fn heartbeat() -> Duration {
Duration::from_secs(3)
}
}
pub mod local {
pub const fn max_packet_size() -> usize {
1500
}
}
}
impl Config { impl Config {
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> { pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
let mut parser = Parser::from_iter(args); let mut parser = Parser::from_iter(args);
@ -135,11 +100,69 @@ impl Config {
} }
let file = File::open(path.unwrap())?; let file = File::open(path.unwrap())?;
Ok(serde_json::from_reader(file)?) Ok(serde_json::from_reader(file)?)
} }
} }
mod default {
pub mod relay {
use crate::utils::{CongestionControl, UdpRelayMode};
use std::time::Duration;
pub fn certificates() -> Vec<String> {
Vec::new()
}
pub fn udp_relay_mode() -> UdpRelayMode {
UdpRelayMode::Native
}
pub fn congestion_control() -> CongestionControl {
CongestionControl::Cubic
}
pub fn alpn() -> Vec<String> {
Vec::new()
}
pub fn zero_rtt_handshake() -> bool {
false
}
pub fn disable_sni() -> bool {
false
}
pub fn timeout() -> Duration {
Duration::from_secs(8)
}
pub fn heartbeat() -> Duration {
Duration::from_secs(3)
}
pub fn disable_native_certificates() -> bool {
false
}
}
pub mod local {
pub fn max_packet_size() -> usize {
1500
}
}
}
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
where
T: FromStr,
<T as FromStr>::Err: Display,
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
T::from_str(&s).map_err(DeError::custom)
}
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error> pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,

View File

@ -2,18 +2,22 @@ use crate::{
config::Relay, config::Relay,
error::Error, error::Error,
socks5::Server as Socks5Server, socks5::Server as Socks5Server,
utils::{ServerAddr, UdpRelayMode}, utils::{self, CongestionControl, ServerAddr, UdpRelayMode},
}; };
use bytes::Bytes; use bytes::Bytes;
use crossbeam_utils::atomic::AtomicCell;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use parking_lot::Mutex; use parking_lot::Mutex;
use quinn::{ use quinn::{
Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt, congestion::{BbrConfig, CubicConfig, NewRenoConfig},
ClientConfig, Connection as QuinnConnection, Endpoint as QuinnEndpoint, EndpointConfig,
RecvStream, SendStream, TokioRuntime, TransportConfig, VarInt,
}; };
use register_count::{Counter, Register}; use register_count::{Counter, Register};
use rustls::{version, ClientConfig as RustlsClientConfig};
use socks5_proto::Address as Socks5Address; use socks5_proto::Address as Socks5Address;
use std::{ use std::{
net::SocketAddr, net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
sync::{ sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Arc, Arc,
@ -29,30 +33,67 @@ use tuic_quinn::{side, Connect, Connection as Model, Task};
static ENDPOINT: OnceCell<Mutex<Endpoint>> = OnceCell::new(); static ENDPOINT: OnceCell<Mutex<Endpoint>> = OnceCell::new();
static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_new(); static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_new();
static TIMEOUT: AtomicCell<Duration> = AtomicCell::new(Duration::from_secs(0));
const DEFAULT_CONCURRENT_STREAMS: usize = 32; const DEFAULT_CONCURRENT_STREAMS: usize = 32;
pub struct Endpoint { pub struct Endpoint {
ep: QuinnEndpoint, ep: QuinnEndpoint,
server: ServerAddr, server: ServerAddr,
token: Vec<u8>, token: Arc<[u8]>,
udp_relay_mode: UdpRelayMode, udp_relay_mode: UdpRelayMode,
zero_rtt_handshake: bool, zero_rtt_handshake: bool,
timeout: Duration,
heartbeat: Duration, heartbeat: Duration,
} }
impl Endpoint { impl Endpoint {
pub fn set_config(cfg: Relay) -> Result<(), Error> { pub fn set_config(cfg: Relay) -> Result<(), Error> {
let ep = todo!(); let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certificates)?;
let mut crypto = RustlsClientConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
.with_protocol_versions(&[&version::TLS13])
.unwrap()
.with_root_certificates(certs)
.with_no_client_auth();
crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect();
crypto.enable_early_data = true;
crypto.enable_sni = !cfg.disable_sni;
let mut config = ClientConfig::new(Arc::new(crypto));
let mut tp_cfg = TransportConfig::default();
tp_cfg
.max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
.max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
.max_idle_timeout(None);
match cfg.congestion_control {
CongestionControl::Cubic => {
tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default()))
}
CongestionControl::NewReno => {
tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default()))
}
CongestionControl::Bbr => {
tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default()))
}
};
config.transport_config(Arc::new(tp_cfg));
let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0)))?;
let mut ep = QuinnEndpoint::new(EndpointConfig::default(), None, socket, TokioRuntime)?;
ep.set_default_client_config(config);
let ep = Self { let ep = Self {
ep, ep,
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip), server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
token: cfg.token.into_bytes(), token: Arc::from(cfg.token.into_bytes().into_boxed_slice()),
udp_relay_mode: cfg.udp_relay_mode, udp_relay_mode: cfg.udp_relay_mode,
zero_rtt_handshake: cfg.zero_rtt_handshake, zero_rtt_handshake: cfg.zero_rtt_handshake,
timeout: cfg.timeout,
heartbeat: cfg.heartbeat, heartbeat: cfg.heartbeat,
}; };
@ -61,19 +102,67 @@ impl Endpoint {
.map_err(|_| "endpoint already initialized") .map_err(|_| "endpoint already initialized")
.unwrap(); .unwrap();
TIMEOUT.store(cfg.timeout);
Ok(()) Ok(())
} }
async fn connect(&self) -> Result<Connection, Error> { async fn connect(&mut self) -> Result<Connection, Error> {
let conn = self async fn connect_to(
.ep ep: &mut QuinnEndpoint,
.connect(SocketAddr::from(([127, 0, 0, 1], 8080)), "localhost")? addr: SocketAddr,
server_name: &str,
udp_relay_mode: UdpRelayMode,
zero_rtt_handshake: bool,
) -> Result<Connection, Error> {
let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4());
let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6());
if !match_ipv4 && !match_ipv6 {
let bind_addr = if addr.is_ipv4() {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
};
ep.rebind(UdpSocket::bind(bind_addr)?)?;
}
let conn = ep.connect(addr, server_name)?;
let conn = if zero_rtt_handshake {
match conn.into_0rtt() {
Ok((conn, _)) => conn,
Err(conn) => conn.await?,
}
} else {
conn.await?
};
Ok(Connection::new(conn, udp_relay_mode))
}
let mut last_err = None;
for addr in self.server.resolve().await? {
match connect_to(
&mut self.ep,
addr,
self.server.server_name(),
self.udp_relay_mode,
self.zero_rtt_handshake,
)
.await .await
.map(Connection::new)?; {
Ok(conn) => {
tokio::spawn(conn.clone().init(self.token.clone(), self.heartbeat));
return Ok(conn);
}
Err(err) => last_err = Some(err),
}
}
tokio::spawn(conn.clone().init()); Err(last_err.unwrap_or(Error::DnsResolve))
Ok(conn)
} }
} }
@ -81,6 +170,7 @@ impl Endpoint {
pub struct Connection { pub struct Connection {
conn: QuinnConnection, conn: QuinnConnection,
model: Model<side::Client>, model: Model<side::Client>,
udp_relay_mode: UdpRelayMode,
remote_uni_stream_cnt: Counter, remote_uni_stream_cnt: Counter,
remote_bi_stream_cnt: Counter, remote_bi_stream_cnt: Counter,
max_concurrent_uni_streams: Arc<AtomicUsize>, max_concurrent_uni_streams: Arc<AtomicUsize>,
@ -88,10 +178,11 @@ pub struct Connection {
} }
impl Connection { impl Connection {
fn new(conn: QuinnConnection) -> Self { fn new(conn: QuinnConnection, udp_relay_mode: UdpRelayMode) -> Self {
Self { Self {
conn: conn.clone(), conn: conn.clone(),
model: Model::<side::Client>::new(conn), model: Model::<side::Client>::new(conn),
udp_relay_mode,
remote_uni_stream_cnt: Counter::new(), remote_uni_stream_cnt: Counter::new(),
remote_bi_stream_cnt: Counter::new(), remote_bi_stream_cnt: Counter::new(),
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
@ -125,7 +216,7 @@ impl Connection {
Ok::<_, Error>(conn.clone()) Ok::<_, Error>(conn.clone())
}; };
let conn = time::timeout(Duration::from_secs(5), try_get_conn) let conn = time::timeout(TIMEOUT.load(), try_get_conn)
.await .await
.map_err(|_| Error::Timeout)??; .map_err(|_| Error::Timeout)??;
@ -137,7 +228,11 @@ impl Connection {
} }
pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> { pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
self.model.packet_quic(pkt, addr, assoc_id).await?; // TODO match self.udp_relay_mode {
UdpRelayMode::Native => self.model.packet_native(pkt, addr, assoc_id)?,
UdpRelayMode::Quic => self.model.packet_quic(pkt, addr, assoc_id).await?,
}
Ok(()) Ok(())
} }
@ -250,16 +345,26 @@ impl Connection {
} }
} }
async fn authenticate(self) { async fn authenticate(self, token: Arc<[u8]>) {
match self.model.authenticate([0; 32]).await { let mut buf = [0; 32];
match self.conn.export_keying_material(&mut buf, &token, &token) {
Ok(()) => {}
Err(_) => {
eprintln!("token length too short");
return;
}
}
match self.model.authenticate(buf).await {
Ok(()) => {} Ok(()) => {}
Err(err) => eprintln!("{err}"), Err(err) => eprintln!("{err}"),
} }
} }
async fn heartbeat(self) { async fn heartbeat(self, heartbeat: Duration) {
loop { loop {
time::sleep(Duration::from_secs(5)).await; time::sleep(heartbeat).await;
if self.is_closed() { if self.is_closed() {
break; break;
@ -272,9 +377,9 @@ impl Connection {
} }
} }
async fn init(self) { async fn init(self, token: Arc<[u8]>, heartbeat: Duration) {
tokio::spawn(self.clone().authenticate()); tokio::spawn(self.clone().authenticate(token));
tokio::spawn(self.clone().heartbeat()); tokio::spawn(self.clone().heartbeat(heartbeat));
let err = loop { let err = loop {
tokio::select! { tokio::select! {

View File

@ -2,6 +2,7 @@ use quinn::{ConnectError, ConnectionError};
use std::io::Error as IoError; use std::io::Error as IoError;
use thiserror::Error; use thiserror::Error;
use tuic_quinn::Error as ModelError; use tuic_quinn::Error as ModelError;
use webpki::Error as WebpkiError;
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
@ -13,8 +14,12 @@ pub enum Error {
Connection(#[from] ConnectionError), Connection(#[from] ConnectionError),
#[error(transparent)] #[error(transparent)]
Model(#[from] ModelError), Model(#[from] ModelError),
#[error("timeout")] #[error(transparent)]
Webpki(#[from] WebpkiError),
#[error("timeout establishing connection")]
Timeout, Timeout,
#[error("invalid authentication")] #[error("cannot resolve the server name")]
InvalidAuth, DnsResolve,
#[error("invalid socks5 authentication")]
InvalidSocks5Auth,
} }

View File

@ -60,7 +60,7 @@ impl Server {
Arc::new(Password::new(username.into_bytes(), password.into_bytes())) Arc::new(Password::new(username.into_bytes(), password.into_bytes()))
} }
(None, None) => Arc::new(NoAuth), (None, None) => Arc::new(NoAuth),
_ => return Err(Error::InvalidAuth), _ => return Err(Error::InvalidSocks5Auth),
}; };
let server = Self { let server = Self {

View File

@ -1,14 +1,40 @@
use serde::{de::Error as DeError, Deserialize, Deserializer}; use crate::error::Error;
use std::{fmt::Display, net::IpAddr, str::FromStr}; use rustls::{Certificate, RootCertStore};
use rustls_pemfile::Item;
use std::{
fs::{self, File},
io::BufReader,
net::{IpAddr, SocketAddr},
str::FromStr,
};
use tokio::net;
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error> pub fn load_certs(paths: Vec<String>, disable_native: bool) -> Result<RootCertStore, Error> {
where let mut certs = RootCertStore::empty();
T: FromStr,
<T as FromStr>::Err: Display, for path in &paths {
D: Deserializer<'de>, let mut file = BufReader::new(File::open(path)?);
{
let s = String::deserialize(deserializer)?; while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
T::from_str(&s).map_err(DeError::custom) if let Item::X509Certificate(cert) = item {
certs.add(&Certificate(cert))?;
}
}
}
if certs.is_empty() {
for path in &paths {
certs.add(&Certificate(fs::read(path)?))?;
}
}
if !disable_native {
for cert in rustls_native_certs::load_native_certs()? {
certs.add(&Certificate(cert.0))?;
}
}
Ok(certs)
} }
pub struct ServerAddr { pub struct ServerAddr {
@ -21,8 +47,24 @@ impl ServerAddr {
pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self { pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self {
Self { domain, port, ip } Self { domain, port, ip }
} }
pub fn server_name(&self) -> &str {
&self.domain
} }
pub async fn resolve(&self) -> Result<impl Iterator<Item = SocketAddr>, Error> {
if let Some(ip) = self.ip {
Ok(vec![SocketAddr::from((ip, self.port))].into_iter())
} else {
Ok(net::lookup_host((self.domain.as_str(), self.port))
.await?
.collect::<Vec<_>>()
.into_iter())
}
}
}
#[derive(Clone, Copy)]
pub enum UdpRelayMode { pub enum UdpRelayMode {
Native, Native,
Quic, Quic,