reading client config to endpoint
This commit is contained in:
parent
0070112e23
commit
9806c62fe7
@ -5,11 +5,15 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytes = { version = "1.3.0", default-features = false, features = ["std"] }
|
bytes = { version = "1.3.0", default-features = false, features = ["std"] }
|
||||||
|
crossbeam-utils = { version = "0.8.14", default-features = false, features = ["std"] }
|
||||||
lexopt = { version = "0.3.0", default-features = false }
|
lexopt = { version = "0.3.0", default-features = false }
|
||||||
once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] }
|
once_cell = { version = "1.17.0", default-features = false, features = ["parking_lot", "std"] }
|
||||||
parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] }
|
parking_lot = { version = "0.12.1", default-features = false, features = ["send_guard"] }
|
||||||
quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] }
|
quinn = { version = "0.9.3", default-features = false, features = ["futures-io", "runtime-tokio", "tls-rustls"] }
|
||||||
register-count = { version = "0.1.0", default-features = false, features = ["std"] }
|
register-count = { version = "0.1.0", default-features = false, features = ["std"] }
|
||||||
|
rustls = { version = "0.20.8", default-features = false, features = ["quic"] }
|
||||||
|
rustls-native-certs = { version = "0.6.2", default-features = false }
|
||||||
|
rustls-pemfile = { version = "1.0.2", default-features = false }
|
||||||
serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] }
|
serde = { version = "1.0.152", default-features = false, features = ["derive", "std"] }
|
||||||
serde_json = { version = "1.0.91", default-features = false, features = ["std"] }
|
serde_json = { version = "1.0.91", default-features = false, features = ["std"] }
|
||||||
socket2 = { version = "0.4.7", default-features = false }
|
socket2 = { version = "0.4.7", default-features = false }
|
||||||
@ -20,3 +24,4 @@ tokio = { version = "1.24.2", default-features = false, features = ["macros", "n
|
|||||||
tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] }
|
tokio-util = { version = "0.7.4", default-features = false, features = ["compat"] }
|
||||||
tuic = { path = "../tuic", default-features = false }
|
tuic = { path = "../tuic", default-features = false }
|
||||||
tuic-quinn = { path = "../tuic-quinn", default-features = false }
|
tuic-quinn = { path = "../tuic-quinn", default-features = false }
|
||||||
|
webpki = { version = "0.22.0", default-features = false }
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
use crate::utils::{self, CongestionControl, UdpRelayMode};
|
use crate::utils::{CongestionControl, UdpRelayMode};
|
||||||
use lexopt::{Arg, Error as ArgumentError, Parser};
|
use lexopt::{Arg, Error as ArgumentError, Parser};
|
||||||
use serde::{de::Error as DeError, Deserialize, Deserializer};
|
use serde::{de::Error as DeError, Deserialize, Deserializer};
|
||||||
use serde_json::Error as SerdeError;
|
use serde_json::Error as SerdeError;
|
||||||
use std::{
|
use std::{
|
||||||
env::ArgsOs,
|
env::ArgsOs,
|
||||||
|
fmt::Display,
|
||||||
fs::File,
|
fs::File,
|
||||||
io::Error as IoError,
|
io::Error as IoError,
|
||||||
net::{IpAddr, SocketAddr},
|
net::{IpAddr, SocketAddr},
|
||||||
|
str::FromStr,
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -38,22 +40,26 @@ pub struct Relay {
|
|||||||
pub certificates: Vec<String>,
|
pub certificates: Vec<String>,
|
||||||
#[serde(
|
#[serde(
|
||||||
default = "default::relay::udp_relay_mode",
|
default = "default::relay::udp_relay_mode",
|
||||||
deserialize_with = "utils::deserialize_from_str"
|
deserialize_with = "deserialize_from_str"
|
||||||
)]
|
)]
|
||||||
pub udp_relay_mode: UdpRelayMode,
|
pub udp_relay_mode: UdpRelayMode,
|
||||||
#[serde(
|
#[serde(
|
||||||
default = "default::relay::congestion_control",
|
default = "default::relay::congestion_control",
|
||||||
deserialize_with = "utils::deserialize_from_str"
|
deserialize_with = "deserialize_from_str"
|
||||||
)]
|
)]
|
||||||
pub congestion_control: CongestionControl,
|
pub congestion_control: CongestionControl,
|
||||||
#[serde(default = "default::relay::alpn")]
|
#[serde(default = "default::relay::alpn")]
|
||||||
pub alpn: Vec<String>,
|
pub alpn: Vec<String>,
|
||||||
#[serde(default = "default::relay::zero_rtt_handshake")]
|
#[serde(default = "default::relay::zero_rtt_handshake")]
|
||||||
pub zero_rtt_handshake: bool,
|
pub zero_rtt_handshake: bool,
|
||||||
|
#[serde(default = "default::relay::disable_sni")]
|
||||||
|
pub disable_sni: bool,
|
||||||
#[serde(default = "default::relay::timeout")]
|
#[serde(default = "default::relay::timeout")]
|
||||||
pub timeout: Duration,
|
pub timeout: Duration,
|
||||||
#[serde(default = "default::relay::heartbeat")]
|
#[serde(default = "default::relay::heartbeat")]
|
||||||
pub heartbeat: Duration,
|
pub heartbeat: Duration,
|
||||||
|
#[serde(default = "default::relay::disable_native_certificates")]
|
||||||
|
pub disable_native_certificates: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -67,47 +73,6 @@ pub struct Local {
|
|||||||
pub max_packet_size: usize,
|
pub max_packet_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
mod default {
|
|
||||||
pub mod relay {
|
|
||||||
use crate::utils::{CongestionControl, UdpRelayMode};
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
pub const fn certificates() -> Vec<String> {
|
|
||||||
Vec::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn udp_relay_mode() -> UdpRelayMode {
|
|
||||||
UdpRelayMode::Native
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn congestion_control() -> CongestionControl {
|
|
||||||
CongestionControl::Cubic
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn alpn() -> Vec<String> {
|
|
||||||
Vec::new()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn zero_rtt_handshake() -> bool {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn timeout() -> Duration {
|
|
||||||
Duration::from_secs(8)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const fn heartbeat() -> Duration {
|
|
||||||
Duration::from_secs(3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub mod local {
|
|
||||||
pub const fn max_packet_size() -> usize {
|
|
||||||
1500
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
|
pub fn parse(args: ArgsOs) -> Result<Self, ConfigError> {
|
||||||
let mut parser = Parser::from_iter(args);
|
let mut parser = Parser::from_iter(args);
|
||||||
@ -135,11 +100,69 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let file = File::open(path.unwrap())?;
|
let file = File::open(path.unwrap())?;
|
||||||
|
|
||||||
Ok(serde_json::from_reader(file)?)
|
Ok(serde_json::from_reader(file)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod default {
|
||||||
|
pub mod relay {
|
||||||
|
use crate::utils::{CongestionControl, UdpRelayMode};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
pub fn certificates() -> Vec<String> {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn udp_relay_mode() -> UdpRelayMode {
|
||||||
|
UdpRelayMode::Native
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn congestion_control() -> CongestionControl {
|
||||||
|
CongestionControl::Cubic
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn alpn() -> Vec<String> {
|
||||||
|
Vec::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn zero_rtt_handshake() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn disable_sni() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn timeout() -> Duration {
|
||||||
|
Duration::from_secs(8)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn heartbeat() -> Duration {
|
||||||
|
Duration::from_secs(3)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn disable_native_certificates() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod local {
|
||||||
|
pub fn max_packet_size() -> usize {
|
||||||
|
1500
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
|
||||||
|
where
|
||||||
|
T: FromStr,
|
||||||
|
<T as FromStr>::Err: Display,
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let s = String::deserialize(deserializer)?;
|
||||||
|
T::from_str(&s).map_err(DeError::custom)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
|
pub fn deserialize_server<'de, D>(deserializer: D) -> Result<(String, u16), D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
|
@ -2,18 +2,22 @@ use crate::{
|
|||||||
config::Relay,
|
config::Relay,
|
||||||
error::Error,
|
error::Error,
|
||||||
socks5::Server as Socks5Server,
|
socks5::Server as Socks5Server,
|
||||||
utils::{ServerAddr, UdpRelayMode},
|
utils::{self, CongestionControl, ServerAddr, UdpRelayMode},
|
||||||
};
|
};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
|
use crossbeam_utils::atomic::AtomicCell;
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use quinn::{
|
use quinn::{
|
||||||
Connection as QuinnConnection, Endpoint as QuinnEndpoint, RecvStream, SendStream, VarInt,
|
congestion::{BbrConfig, CubicConfig, NewRenoConfig},
|
||||||
|
ClientConfig, Connection as QuinnConnection, Endpoint as QuinnEndpoint, EndpointConfig,
|
||||||
|
RecvStream, SendStream, TokioRuntime, TransportConfig, VarInt,
|
||||||
};
|
};
|
||||||
use register_count::{Counter, Register};
|
use register_count::{Counter, Register};
|
||||||
|
use rustls::{version, ClientConfig as RustlsClientConfig};
|
||||||
use socks5_proto::Address as Socks5Address;
|
use socks5_proto::Address as Socks5Address;
|
||||||
use std::{
|
use std::{
|
||||||
net::SocketAddr,
|
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
@ -29,30 +33,67 @@ use tuic_quinn::{side, Connect, Connection as Model, Task};
|
|||||||
|
|
||||||
static ENDPOINT: OnceCell<Mutex<Endpoint>> = OnceCell::new();
|
static ENDPOINT: OnceCell<Mutex<Endpoint>> = OnceCell::new();
|
||||||
static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_new();
|
static CONNECTION: AsyncOnceCell<AsyncMutex<Connection>> = AsyncOnceCell::const_new();
|
||||||
|
static TIMEOUT: AtomicCell<Duration> = AtomicCell::new(Duration::from_secs(0));
|
||||||
|
|
||||||
const DEFAULT_CONCURRENT_STREAMS: usize = 32;
|
const DEFAULT_CONCURRENT_STREAMS: usize = 32;
|
||||||
|
|
||||||
pub struct Endpoint {
|
pub struct Endpoint {
|
||||||
ep: QuinnEndpoint,
|
ep: QuinnEndpoint,
|
||||||
server: ServerAddr,
|
server: ServerAddr,
|
||||||
token: Vec<u8>,
|
token: Arc<[u8]>,
|
||||||
udp_relay_mode: UdpRelayMode,
|
udp_relay_mode: UdpRelayMode,
|
||||||
zero_rtt_handshake: bool,
|
zero_rtt_handshake: bool,
|
||||||
timeout: Duration,
|
|
||||||
heartbeat: Duration,
|
heartbeat: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Endpoint {
|
impl Endpoint {
|
||||||
pub fn set_config(cfg: Relay) -> Result<(), Error> {
|
pub fn set_config(cfg: Relay) -> Result<(), Error> {
|
||||||
let ep = todo!();
|
let certs = utils::load_certs(cfg.certificates, cfg.disable_native_certificates)?;
|
||||||
|
|
||||||
|
let mut crypto = RustlsClientConfig::builder()
|
||||||
|
.with_safe_default_cipher_suites()
|
||||||
|
.with_safe_default_kx_groups()
|
||||||
|
.with_protocol_versions(&[&version::TLS13])
|
||||||
|
.unwrap()
|
||||||
|
.with_root_certificates(certs)
|
||||||
|
.with_no_client_auth();
|
||||||
|
|
||||||
|
crypto.alpn_protocols = cfg.alpn.into_iter().map(|alpn| alpn.into_bytes()).collect();
|
||||||
|
crypto.enable_early_data = true;
|
||||||
|
crypto.enable_sni = !cfg.disable_sni;
|
||||||
|
|
||||||
|
let mut config = ClientConfig::new(Arc::new(crypto));
|
||||||
|
let mut tp_cfg = TransportConfig::default();
|
||||||
|
|
||||||
|
tp_cfg
|
||||||
|
.max_concurrent_bidi_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
|
||||||
|
.max_concurrent_uni_streams(VarInt::from(DEFAULT_CONCURRENT_STREAMS as u32))
|
||||||
|
.max_idle_timeout(None);
|
||||||
|
|
||||||
|
match cfg.congestion_control {
|
||||||
|
CongestionControl::Cubic => {
|
||||||
|
tp_cfg.congestion_controller_factory(Arc::new(CubicConfig::default()))
|
||||||
|
}
|
||||||
|
CongestionControl::NewReno => {
|
||||||
|
tp_cfg.congestion_controller_factory(Arc::new(NewRenoConfig::default()))
|
||||||
|
}
|
||||||
|
CongestionControl::Bbr => {
|
||||||
|
tp_cfg.congestion_controller_factory(Arc::new(BbrConfig::default()))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
config.transport_config(Arc::new(tp_cfg));
|
||||||
|
|
||||||
|
let socket = UdpSocket::bind(SocketAddr::from(([0, 0, 0, 0], 0)))?;
|
||||||
|
let mut ep = QuinnEndpoint::new(EndpointConfig::default(), None, socket, TokioRuntime)?;
|
||||||
|
ep.set_default_client_config(config);
|
||||||
|
|
||||||
let ep = Self {
|
let ep = Self {
|
||||||
ep,
|
ep,
|
||||||
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
|
server: ServerAddr::new(cfg.server.0, cfg.server.1, cfg.ip),
|
||||||
token: cfg.token.into_bytes(),
|
token: Arc::from(cfg.token.into_bytes().into_boxed_slice()),
|
||||||
udp_relay_mode: cfg.udp_relay_mode,
|
udp_relay_mode: cfg.udp_relay_mode,
|
||||||
zero_rtt_handshake: cfg.zero_rtt_handshake,
|
zero_rtt_handshake: cfg.zero_rtt_handshake,
|
||||||
timeout: cfg.timeout,
|
|
||||||
heartbeat: cfg.heartbeat,
|
heartbeat: cfg.heartbeat,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -61,19 +102,67 @@ impl Endpoint {
|
|||||||
.map_err(|_| "endpoint already initialized")
|
.map_err(|_| "endpoint already initialized")
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
TIMEOUT.store(cfg.timeout);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn connect(&self) -> Result<Connection, Error> {
|
async fn connect(&mut self) -> Result<Connection, Error> {
|
||||||
let conn = self
|
async fn connect_to(
|
||||||
.ep
|
ep: &mut QuinnEndpoint,
|
||||||
.connect(SocketAddr::from(([127, 0, 0, 1], 8080)), "localhost")?
|
addr: SocketAddr,
|
||||||
|
server_name: &str,
|
||||||
|
udp_relay_mode: UdpRelayMode,
|
||||||
|
zero_rtt_handshake: bool,
|
||||||
|
) -> Result<Connection, Error> {
|
||||||
|
let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4());
|
||||||
|
let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6());
|
||||||
|
|
||||||
|
if !match_ipv4 && !match_ipv6 {
|
||||||
|
let bind_addr = if addr.is_ipv4() {
|
||||||
|
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
|
||||||
|
} else {
|
||||||
|
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
|
||||||
|
};
|
||||||
|
|
||||||
|
ep.rebind(UdpSocket::bind(bind_addr)?)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let conn = ep.connect(addr, server_name)?;
|
||||||
|
|
||||||
|
let conn = if zero_rtt_handshake {
|
||||||
|
match conn.into_0rtt() {
|
||||||
|
Ok((conn, _)) => conn,
|
||||||
|
Err(conn) => conn.await?,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
conn.await?
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Connection::new(conn, udp_relay_mode))
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut last_err = None;
|
||||||
|
|
||||||
|
for addr in self.server.resolve().await? {
|
||||||
|
match connect_to(
|
||||||
|
&mut self.ep,
|
||||||
|
addr,
|
||||||
|
self.server.server_name(),
|
||||||
|
self.udp_relay_mode,
|
||||||
|
self.zero_rtt_handshake,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.map(Connection::new)?;
|
{
|
||||||
|
Ok(conn) => {
|
||||||
|
tokio::spawn(conn.clone().init(self.token.clone(), self.heartbeat));
|
||||||
|
return Ok(conn);
|
||||||
|
}
|
||||||
|
Err(err) => last_err = Some(err),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tokio::spawn(conn.clone().init());
|
Err(last_err.unwrap_or(Error::DnsResolve))
|
||||||
|
|
||||||
Ok(conn)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -81,6 +170,7 @@ impl Endpoint {
|
|||||||
pub struct Connection {
|
pub struct Connection {
|
||||||
conn: QuinnConnection,
|
conn: QuinnConnection,
|
||||||
model: Model<side::Client>,
|
model: Model<side::Client>,
|
||||||
|
udp_relay_mode: UdpRelayMode,
|
||||||
remote_uni_stream_cnt: Counter,
|
remote_uni_stream_cnt: Counter,
|
||||||
remote_bi_stream_cnt: Counter,
|
remote_bi_stream_cnt: Counter,
|
||||||
max_concurrent_uni_streams: Arc<AtomicUsize>,
|
max_concurrent_uni_streams: Arc<AtomicUsize>,
|
||||||
@ -88,10 +178,11 @@ pub struct Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Connection {
|
impl Connection {
|
||||||
fn new(conn: QuinnConnection) -> Self {
|
fn new(conn: QuinnConnection, udp_relay_mode: UdpRelayMode) -> Self {
|
||||||
Self {
|
Self {
|
||||||
conn: conn.clone(),
|
conn: conn.clone(),
|
||||||
model: Model::<side::Client>::new(conn),
|
model: Model::<side::Client>::new(conn),
|
||||||
|
udp_relay_mode,
|
||||||
remote_uni_stream_cnt: Counter::new(),
|
remote_uni_stream_cnt: Counter::new(),
|
||||||
remote_bi_stream_cnt: Counter::new(),
|
remote_bi_stream_cnt: Counter::new(),
|
||||||
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
|
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||||
@ -125,7 +216,7 @@ impl Connection {
|
|||||||
Ok::<_, Error>(conn.clone())
|
Ok::<_, Error>(conn.clone())
|
||||||
};
|
};
|
||||||
|
|
||||||
let conn = time::timeout(Duration::from_secs(5), try_get_conn)
|
let conn = time::timeout(TIMEOUT.load(), try_get_conn)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| Error::Timeout)??;
|
.map_err(|_| Error::Timeout)??;
|
||||||
|
|
||||||
@ -137,7 +228,11 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
|
pub async fn packet(&self, pkt: Bytes, addr: Address, assoc_id: u16) -> Result<(), Error> {
|
||||||
self.model.packet_quic(pkt, addr, assoc_id).await?; // TODO
|
match self.udp_relay_mode {
|
||||||
|
UdpRelayMode::Native => self.model.packet_native(pkt, addr, assoc_id)?,
|
||||||
|
UdpRelayMode::Quic => self.model.packet_quic(pkt, addr, assoc_id).await?,
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -250,16 +345,26 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn authenticate(self) {
|
async fn authenticate(self, token: Arc<[u8]>) {
|
||||||
match self.model.authenticate([0; 32]).await {
|
let mut buf = [0; 32];
|
||||||
|
|
||||||
|
match self.conn.export_keying_material(&mut buf, &token, &token) {
|
||||||
|
Ok(()) => {}
|
||||||
|
Err(_) => {
|
||||||
|
eprintln!("token length too short");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.model.authenticate(buf).await {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(err) => eprintln!("{err}"),
|
Err(err) => eprintln!("{err}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn heartbeat(self) {
|
async fn heartbeat(self, heartbeat: Duration) {
|
||||||
loop {
|
loop {
|
||||||
time::sleep(Duration::from_secs(5)).await;
|
time::sleep(heartbeat).await;
|
||||||
|
|
||||||
if self.is_closed() {
|
if self.is_closed() {
|
||||||
break;
|
break;
|
||||||
@ -272,9 +377,9 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn init(self) {
|
async fn init(self, token: Arc<[u8]>, heartbeat: Duration) {
|
||||||
tokio::spawn(self.clone().authenticate());
|
tokio::spawn(self.clone().authenticate(token));
|
||||||
tokio::spawn(self.clone().heartbeat());
|
tokio::spawn(self.clone().heartbeat(heartbeat));
|
||||||
|
|
||||||
let err = loop {
|
let err = loop {
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
|
@ -2,6 +2,7 @@ use quinn::{ConnectError, ConnectionError};
|
|||||||
use std::io::Error as IoError;
|
use std::io::Error as IoError;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use tuic_quinn::Error as ModelError;
|
use tuic_quinn::Error as ModelError;
|
||||||
|
use webpki::Error as WebpkiError;
|
||||||
|
|
||||||
#[derive(Debug, Error)]
|
#[derive(Debug, Error)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
@ -13,8 +14,12 @@ pub enum Error {
|
|||||||
Connection(#[from] ConnectionError),
|
Connection(#[from] ConnectionError),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
Model(#[from] ModelError),
|
Model(#[from] ModelError),
|
||||||
#[error("timeout")]
|
#[error(transparent)]
|
||||||
|
Webpki(#[from] WebpkiError),
|
||||||
|
#[error("timeout establishing connection")]
|
||||||
Timeout,
|
Timeout,
|
||||||
#[error("invalid authentication")]
|
#[error("cannot resolve the server name")]
|
||||||
InvalidAuth,
|
DnsResolve,
|
||||||
|
#[error("invalid socks5 authentication")]
|
||||||
|
InvalidSocks5Auth,
|
||||||
}
|
}
|
||||||
|
@ -60,7 +60,7 @@ impl Server {
|
|||||||
Arc::new(Password::new(username.into_bytes(), password.into_bytes()))
|
Arc::new(Password::new(username.into_bytes(), password.into_bytes()))
|
||||||
}
|
}
|
||||||
(None, None) => Arc::new(NoAuth),
|
(None, None) => Arc::new(NoAuth),
|
||||||
_ => return Err(Error::InvalidAuth),
|
_ => return Err(Error::InvalidSocks5Auth),
|
||||||
};
|
};
|
||||||
|
|
||||||
let server = Self {
|
let server = Self {
|
||||||
|
@ -1,14 +1,40 @@
|
|||||||
use serde::{de::Error as DeError, Deserialize, Deserializer};
|
use crate::error::Error;
|
||||||
use std::{fmt::Display, net::IpAddr, str::FromStr};
|
use rustls::{Certificate, RootCertStore};
|
||||||
|
use rustls_pemfile::Item;
|
||||||
|
use std::{
|
||||||
|
fs::{self, File},
|
||||||
|
io::BufReader,
|
||||||
|
net::{IpAddr, SocketAddr},
|
||||||
|
str::FromStr,
|
||||||
|
};
|
||||||
|
use tokio::net;
|
||||||
|
|
||||||
pub fn deserialize_from_str<'de, T, D>(deserializer: D) -> Result<T, D::Error>
|
pub fn load_certs(paths: Vec<String>, disable_native: bool) -> Result<RootCertStore, Error> {
|
||||||
where
|
let mut certs = RootCertStore::empty();
|
||||||
T: FromStr,
|
|
||||||
<T as FromStr>::Err: Display,
|
for path in &paths {
|
||||||
D: Deserializer<'de>,
|
let mut file = BufReader::new(File::open(path)?);
|
||||||
{
|
|
||||||
let s = String::deserialize(deserializer)?;
|
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut file) {
|
||||||
T::from_str(&s).map_err(DeError::custom)
|
if let Item::X509Certificate(cert) = item {
|
||||||
|
certs.add(&Certificate(cert))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if certs.is_empty() {
|
||||||
|
for path in &paths {
|
||||||
|
certs.add(&Certificate(fs::read(path)?))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !disable_native {
|
||||||
|
for cert in rustls_native_certs::load_native_certs()? {
|
||||||
|
certs.add(&Certificate(cert.0))?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(certs)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ServerAddr {
|
pub struct ServerAddr {
|
||||||
@ -21,8 +47,24 @@ impl ServerAddr {
|
|||||||
pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self {
|
pub fn new(domain: String, port: u16, ip: Option<IpAddr>) -> Self {
|
||||||
Self { domain, port, ip }
|
Self { domain, port, ip }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn server_name(&self) -> &str {
|
||||||
|
&self.domain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn resolve(&self) -> Result<impl Iterator<Item = SocketAddr>, Error> {
|
||||||
|
if let Some(ip) = self.ip {
|
||||||
|
Ok(vec![SocketAddr::from((ip, self.port))].into_iter())
|
||||||
|
} else {
|
||||||
|
Ok(net::lookup_host((self.domain.as_str(), self.port))
|
||||||
|
.await?
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.into_iter())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
pub enum UdpRelayMode {
|
pub enum UdpRelayMode {
|
||||||
Native,
|
Native,
|
||||||
Quic,
|
Quic,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user