reading client config to endpoint
This commit is contained in:
parent
0070112e23
commit
9806c62fe7
@ -5,11 +5,15 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bytes = { version = "1.3.0", default-features = false, features = ["std"] }
|
||||
crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] }
|
||||
lexopt = { version = "0.3.0", default-features = false }
|
||||
once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] }
|
||||
parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] }
|
||||
quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] }
|
||||
register-count = { version = "0.1.0", default-features = false, features = ["std"] }
|
||||
rustls = { version = "0.20.8", default-features = false, features = ["quic"] }
|
||||
rustls-native-certs = { version = "0.6.2", default-features = false }
|
||||
rustls-pemfile = { version = "1.0.2", default-features = false }
|
||||
serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] }
|
||||
serde_json = { version = "1.0.91", default-features = false, features = ["std"] }
|
||||
socket2 = { version = "0.4.7", default-features = false }
|
||||
@ -20,3 +24,4 @@ tokio = { version = "1.24.2", default-features = false, features = ["macros", "n
|
||||
tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] }
|
||||
tuic = { path = "../tuic", default-features = false }
|
||||
tuic-quinn = { path = "../tuic-quinn", default-features = false }
|
||||
webpki = { version = "0.22.0", default-features = false }
|
||||
|
@ -1,12 +1,14 @@
|
||||
use crate::utils::{self, CongestionControl, UdpRelayMode};
|
||||
use crate::utils::{CongestionControl, UdpRelayMode};
|
||||
use lexopt::{Arg, Error as ArgumentError, Parser};
|
||||
use serde::{de::Error as DeError, Deserialize, Deserializer};
|
||||
use serde_json::Error as SerdeError;
|
||||
use std::{
|
||||
env::ArgsOs,
|
||||
fmt::Display,
|
||||
fs::File,
|
||||
io::Error as IoError,
|
||||
net::{IpAddr, SocketAddr},
|
||||
str::FromStr,
|
||||
time::Duration,
|
||||
};
|
||||
use thiserror::Error;
|
||||
@ -38,22 +40,26 @@ pub struct Relay {
|
||||
pub certificates: Vec<String>,
|
||||
#[serde(
|
||||
default = "default::relay::udp_relay_mode",
|
||||
deserialize_with = "utils::deserialize_from_str"
|
||||
deserialize_with = "deserialize_from_str"
|
||||
)]
|
||||
pub udp_relay_mode: UdpRelayMode,
|
||||
#[serde(
|
||||
default = "default::relay::congestion_control",
|
||||
deserialize_with = "utils::deserialize_from_str"
|
||||
deserialize_with = "deserialize_from_str"
|
||||
)]
|
||||
pub congestion_control: CongestionControl,
|
||||
#[serde(default = "default::relay::alpn")]
|
||||
pub alpn: Vec<String>,
|
||||
#[serde(default = "default::relay::zero_rtt_handshake")]
|
||||
pub zero_rtt_handshake: bool,
|
||||
#[serde(default = "default::relay::disable_sni")]
|
||||
pub disable_sni: bool,
|
||||
#[serde(default = "default::relay::timeout")]
|
||||
pub timeout: Duration,
|
||||
#[serde(default = "default::relay::heartbeat")]
|
||||
pub heartbeat: Duration,
|
||||
#[serde(default = "default::relay::disable_native_certificates")]
|
||||
pub disable_native_certificates: bool,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@ -67,47 +73,6 @@ pub struct Local {
|
||||
pub max_packet_size: usize,
|
||||
}
|
||||
|
||||
mod default {
|
||||
pub mod relay {
|
||||
use crate::utils::{CongestionControl, UdpRelayMode};
|
||||
use std::time::Duration;
|
||||
|
||||
pub const fn certificates() -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
pub const fn udp_relay_mode() -> UdpRelayMode {
|
||||
UdpRelayMode::Native
|
||||
}
|
||||
|
||||
pub const fn congestion_control() -> CongestionControl {
|
||||
CongestionControl::Cubic
|
||||
}
|
||||
|
||||
pub const fn alpn() -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
pub const fn zero_rtt_handshake() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub const fn timeout() -> Duration {
|
||||
Duration::from_secs(8)
|
||||
}
|
||||
|
||||
pub const fn heartbeat() -> Duration {
|
||||
Duration::from_secs(3)
|
||||
}
|
||||
}
|
||||
|
||||
pub mod local {
|
||||
pub const fn max_packet_size() -> usize {
|
||||
1500
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
|
||||
let mut parser = Parser::from_iter(args);
|
||||
@ -135,11 +100,69 @@ impl Config {
|
||||
}
|
||||
|
||||
let file = File::open(path.unwrap())?;
|
||||
|
||||
Ok(serde_json::from_reader(file)?)
|
||||
}
|
||||
}
|
||||
|
||||
mod default {
|
||||
pub mod relay {
|
||||
use crate::utils::{CongestionControl, UdpRelayMode};
|
||||
use std::time::Duration;
|
||||
|
||||
pub fn certificates() -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
pub fn udp_relay_mode() -> UdpRelayMode {
|
||||
UdpRelayMode::Native
|
||||
}
|
||||
|
||||
pub fn congestion_control() -> CongestionControl {
|
||||
CongestionControl::Cubic
|
||||
}
|
||||
|
||||
pub fn alpn() -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
pub fn zero_rtt_handshake() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn disable_sni() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
pub fn timeout() -> Duration {
|
||||
Duration::from_secs(8)
|
||||
}
|
||||
|
||||
pub fn heartbeat() -> Duration {
|
||||
Duration::from_secs(3)
|
||||
}
|
||||
|
||||
pub fn disable_native_certificates() -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub mod local {
|
||||
pub fn max_packet_size() -> usize {
|
||||
1500
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: FromStr,
|
||||
<T as FromStr>::Err: Display,
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
T::from_str(&s).map_err(DeError::custom)
|
||||
}
|
||||
|
||||
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
|
@ -2,18 +2,22 @@ use crate::{
|
||||
config::Relay,
|
||||
error::Error,
|
||||
socks5::Server as Socks5Server,
|
||||
utils::{ServerAddr, UdpRelayMode},
|
||||
utils::{self, CongestionControl, ServerAddr, UdpRelayMode},
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use once_cell::sync::OnceCell;
|
||||
use parking_lot::Mutex;
|
||||
use quinn::{
|
||||
Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt,
|
||||
congestion::{BbrConfig, CubicConfig, NewRenoConfig},
|
||||
ClientConfig, Connection as QuinnConnection, Endpoint as QuinnEndpoint, EndpointConfig,
|
||||
RecvStream, SendStream, TokioRuntime, TransportConfig, VarInt,
|
||||
};
|
||||
use register_count::{Counter, Register};
|
||||
use rustls::{version, ClientConfig as RustlsClientConfig};
|
||||
use socks5_proto::Address as Socks5Address;
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
@ -29,30 +33,67 @@ use tuic_quinn::{side, Connect, Connection as Model, Task};
|
||||
|
||||
static ENDPOINT: OnceCell<Mutex<Endpoint>> = OnceCell::new();
|
||||
static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_new();
|
||||
static TIMEOUT: AtomicCell<Duration> = AtomicCell::new(Duration::from_secs(0));
|
||||
|
||||
const DEFAULT_CONCURRENT_STREAMS: usize = 32;
|
||||
|
||||
pub struct Endpoint {
|
||||
ep: QuinnEndpoint,
|
||||
server: ServerAddr,
|
||||
token: Vec<u8>,
|
||||
token: Arc<[u8]>,
|
||||
udp_relay_mode: UdpRelayMode,
|
||||
zero_rtt_handshake: bool,
|
||||
timeout: Duration,
|
||||
heartbeat: Duration,
|
||||
}
|
||||
|
||||
impl Endpoint {
|
||||
pub fn set_config(cfg: Relay) -> Result<(), Error> {
|
||||
let ep = todo!();
|
||||
let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certificates)?;
|
||||
|
||||
let mut crypto = RustlsClientConfig::builder()
|
||||
.with_safe_default_cipher_suites()
|
||||
.with_safe_default_kx_groups()
|
||||
.with_protocol_versions(&[&version::TLS13])
|
||||
.unwrap()
|
||||
.with_root_certificates(certs)
|
||||
.with_no_client_auth();
|
||||
|
||||
crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect();
|
||||
crypto.enable_early_data = true;
|
||||
crypto.enable_sni = !cfg.disable_sni;
|
||||
|
||||
let mut config = ClientConfig::new(Arc::new(crypto));
|
||||
let mut tp_cfg = TransportConfig::default();
|
||||
|
||||
tp_cfg
|
||||
.max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
|
||||
.max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
|
||||
.max_idle_timeout(None);
|
||||
|
||||
match cfg.congestion_control {
|
||||
CongestionControl::Cubic => {
|
||||
tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default()))
|
||||
}
|
||||
CongestionControl::NewReno => {
|
||||
tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default()))
|
||||
}
|
||||
CongestionControl::Bbr => {
|
||||
tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default()))
|
||||
}
|
||||
};
|
||||
|
||||
config.transport_config(Arc::new(tp_cfg));
|
||||
|
||||
let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0)))?;
|
||||
let mut ep = QuinnEndpoint::new(EndpointConfig::default(), None, socket, TokioRuntime)?;
|
||||
ep.set_default_client_config(config);
|
||||
|
||||
let ep = Self {
|
||||
ep,
|
||||
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
|
||||
token: cfg.token.into_bytes(),
|
||||
token: Arc::from(cfg.token.into_bytes().into_boxed_slice()),
|
||||
udp_relay_mode: cfg.udp_relay_mode,
|
||||
zero_rtt_handshake: cfg.zero_rtt_handshake,
|
||||
timeout: cfg.timeout,
|
||||
heartbeat: cfg.heartbeat,
|
||||
};
|
||||
|
||||
@ -61,19 +102,67 @@ impl Endpoint {
|
||||
.map_err(|_| "endpoint already initialized")
|
||||
.unwrap();
|
||||
|
||||
TIMEOUT.store(cfg.timeout);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn connect(&self) -> Result<Connection, Error> {
|
||||
let conn = self
|
||||
.ep
|
||||
.connect(SocketAddr::from(([127, 0, 0, 1], 8080)), "localhost")?
|
||||
async fn connect(&mut self) -> Result<Connection, Error> {
|
||||
async fn connect_to(
|
||||
ep: &mut QuinnEndpoint,
|
||||
addr: SocketAddr,
|
||||
server_name: &str,
|
||||
udp_relay_mode: UdpRelayMode,
|
||||
zero_rtt_handshake: bool,
|
||||
) -> Result<Connection, Error> {
|
||||
let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4());
|
||||
let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6());
|
||||
|
||||
if !match_ipv4 && !match_ipv6 {
|
||||
let bind_addr = if addr.is_ipv4() {
|
||||
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
|
||||
} else {
|
||||
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
|
||||
};
|
||||
|
||||
ep.rebind(UdpSocket::bind(bind_addr)?)?;
|
||||
}
|
||||
|
||||
let conn = ep.connect(addr, server_name)?;
|
||||
|
||||
let conn = if zero_rtt_handshake {
|
||||
match conn.into_0rtt() {
|
||||
Ok((conn, _)) => conn,
|
||||
Err(conn) => conn.await?,
|
||||
}
|
||||
} else {
|
||||
conn.await?
|
||||
};
|
||||
|
||||
Ok(Connection::new(conn, udp_relay_mode))
|
||||
}
|
||||
|
||||
let mut last_err = None;
|
||||
|
||||
for addr in self.server.resolve().await? {
|
||||
match connect_to(
|
||||
&mut self.ep,
|
||||
addr,
|
||||
self.server.server_name(),
|
||||
self.udp_relay_mode,
|
||||
self.zero_rtt_handshake,
|
||||
)
|
||||
.await
|
||||
.map(Connection::new)?;
|
||||
{
|
||||
Ok(conn) => {
|
||||
tokio::spawn(conn.clone().init(self.token.clone(), self.heartbeat));
|
||||
return Ok(conn);
|
||||
}
|
||||
Err(err) => last_err = Some(err),
|
||||
}
|
||||
}
|
||||
|
||||
tokio::spawn(conn.clone().init());
|
||||
|
||||
Ok(conn)
|
||||
Err(last_err.unwrap_or(Error::DnsResolve))
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,6 +170,7 @@ impl Endpoint {
|
||||
pub struct Connection {
|
||||
conn: QuinnConnection,
|
||||
model: Model<side::Client>,
|
||||
udp_relay_mode: UdpRelayMode,
|
||||
remote_uni_stream_cnt: Counter,
|
||||
remote_bi_stream_cnt: Counter,
|
||||
max_concurrent_uni_streams: Arc<AtomicUsize>,
|
||||
@ -88,10 +178,11 @@ pub struct Connection {
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
fn new(conn: QuinnConnection) -> Self {
|
||||
fn new(conn: QuinnConnection, udp_relay_mode: UdpRelayMode) -> Self {
|
||||
Self {
|
||||
conn: conn.clone(),
|
||||
model: Model::<side::Client>::new(conn),
|
||||
udp_relay_mode,
|
||||
remote_uni_stream_cnt: Counter::new(),
|
||||
remote_bi_stream_cnt: Counter::new(),
|
||||
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||
@ -125,7 +216,7 @@ impl Connection {
|
||||
Ok::<_, Error>(conn.clone())
|
||||
};
|
||||
|
||||
let conn = time::timeout(Duration::from_secs(5), try_get_conn)
|
||||
let conn = time::timeout(TIMEOUT.load(), try_get_conn)
|
||||
.await
|
||||
.map_err(|_| Error::Timeout)??;
|
||||
|
||||
@ -137,7 +228,11 @@ impl Connection {
|
||||
}
|
||||
|
||||
pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
|
||||
self.model.packet_quic(pkt, addr, assoc_id).await?; // TODO
|
||||
match self.udp_relay_mode {
|
||||
UdpRelayMode::Native => self.model.packet_native(pkt, addr, assoc_id)?,
|
||||
UdpRelayMode::Quic => self.model.packet_quic(pkt, addr, assoc_id).await?,
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -250,16 +345,26 @@ impl Connection {
|
||||
}
|
||||
}
|
||||
|
||||
async fn authenticate(self) {
|
||||
match self.model.authenticate([0; 32]).await {
|
||||
async fn authenticate(self, token: Arc<[u8]>) {
|
||||
let mut buf = [0; 32];
|
||||
|
||||
match self.conn.export_keying_material(&mut buf, &token, &token) {
|
||||
Ok(()) => {}
|
||||
Err(_) => {
|
||||
eprintln!("token length too short");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
match self.model.authenticate(buf).await {
|
||||
Ok(()) => {}
|
||||
Err(err) => eprintln!("{err}"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn heartbeat(self) {
|
||||
async fn heartbeat(self, heartbeat: Duration) {
|
||||
loop {
|
||||
time::sleep(Duration::from_secs(5)).await;
|
||||
time::sleep(heartbeat).await;
|
||||
|
||||
if self.is_closed() {
|
||||
break;
|
||||
@ -272,9 +377,9 @@ impl Connection {
|
||||
}
|
||||
}
|
||||
|
||||
async fn init(self) {
|
||||
tokio::spawn(self.clone().authenticate());
|
||||
tokio::spawn(self.clone().heartbeat());
|
||||
async fn init(self, token: Arc<[u8]>, heartbeat: Duration) {
|
||||
tokio::spawn(self.clone().authenticate(token));
|
||||
tokio::spawn(self.clone().heartbeat(heartbeat));
|
||||
|
||||
let err = loop {
|
||||
tokio::select! {
|
||||
|
@ -2,6 +2,7 @@ use quinn::{ConnectError, ConnectionError};
|
||||
use std::io::Error as IoError;
|
||||
use thiserror::Error;
|
||||
use tuic_quinn::Error as ModelError;
|
||||
use webpki::Error as WebpkiError;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
@ -13,8 +14,12 @@ pub enum Error {
|
||||
Connection(#[from] ConnectionError),
|
||||
#[error(transparent)]
|
||||
Model(#[from] ModelError),
|
||||
#[error("timeout")]
|
||||
#[error(transparent)]
|
||||
Webpki(#[from] WebpkiError),
|
||||
#[error("timeout establishing connection")]
|
||||
Timeout,
|
||||
#[error("invalid authentication")]
|
||||
InvalidAuth,
|
||||
#[error("cannot resolve the server name")]
|
||||
DnsResolve,
|
||||
#[error("invalid socks5 authentication")]
|
||||
InvalidSocks5Auth,
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ impl Server {
|
||||
Arc::new(Password::new(username.into_bytes(), password.into_bytes()))
|
||||
}
|
||||
(None, None) => Arc::new(NoAuth),
|
||||
_ => return Err(Error::InvalidAuth),
|
||||
_ => return Err(Error::InvalidSocks5Auth),
|
||||
};
|
||||
|
||||
let server = Self {
|
||||
|
@ -1,14 +1,40 @@
|
||||
use serde::{de::Error as DeError, Deserialize, Deserializer};
|
||||
use std::{fmt::Display, net::IpAddr, str::FromStr};
|
||||
use crate::error::Error;
|
||||
use rustls::{Certificate, RootCertStore};
|
||||
use rustls_pemfile::Item;
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::BufReader,
|
||||
net::{IpAddr, SocketAddr},
|
||||
str::FromStr,
|
||||
};
|
||||
use tokio::net;
|
||||
|
||||
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
T: FromStr,
|
||||
<T as FromStr>::Err: Display,
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
T::from_str(&s).map_err(DeError::custom)
|
||||
pub fn load_certs(paths: Vec<String>, disable_native: bool) -> Result<RootCertStore, Error> {
|
||||
let mut certs = RootCertStore::empty();
|
||||
|
||||
for path in &paths {
|
||||
let mut file = BufReader::new(File::open(path)?);
|
||||
|
||||
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
|
||||
if let Item::X509Certificate(cert) = item {
|
||||
certs.add(&Certificate(cert))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if certs.is_empty() {
|
||||
for path in &paths {
|
||||
certs.add(&Certificate(fs::read(path)?))?;
|
||||
}
|
||||
}
|
||||
|
||||
if !disable_native {
|
||||
for cert in rustls_native_certs::load_native_certs()? {
|
||||
certs.add(&Certificate(cert.0))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
pub struct ServerAddr {
|
||||
@ -21,8 +47,24 @@ impl ServerAddr {
|
||||
pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self {
|
||||
Self { domain, port, ip }
|
||||
}
|
||||
|
||||
pub fn server_name(&self) -> &str {
|
||||
&self.domain
|
||||
}
|
||||
|
||||
pub async fn resolve(&self) -> Result<impl Iterator<Item = SocketAddr>, Error> {
|
||||
if let Some(ip) = self.ip {
|
||||
Ok(vec![SocketAddr::from((ip, self.port))].into_iter())
|
||||
} else {
|
||||
Ok(net::lookup_host((self.domain.as_str(), self.port))
|
||||
.await?
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum UdpRelayMode {
|
||||
Native,
|
||||
Quic,
|
||||
|
Loading…
x
Reference in New Issue
Block a user