refactor server UDP session handling
This commit is contained in:
parent
02ba5056ee
commit
eb228e554e
51
tuic-server/src/connection/authenticated.rs
Normal file
51
tuic-server/src/connection/authenticated.rs
Normal file
@ -0,0 +1,51 @@
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use parking_lot::Mutex;
|
||||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct Authenticated(Arc<AuthenticatedInner>);
|
||||
|
||||
struct AuthenticatedInner {
|
||||
uuid: AtomicCell<Option<Uuid>>,
|
||||
broadcast: Mutex<Vec<Waker>>,
|
||||
}
|
||||
|
||||
impl Authenticated {
|
||||
pub(super) fn new() -> Self {
|
||||
Self(Arc::new(AuthenticatedInner {
|
||||
uuid: AtomicCell::new(None),
|
||||
broadcast: Mutex::new(Vec::new()),
|
||||
}))
|
||||
}
|
||||
|
||||
pub(super) fn set(&self, uuid: Uuid) {
|
||||
self.0.uuid.store(Some(uuid));
|
||||
|
||||
for waker in self.0.broadcast.lock().drain(..) {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn get(&self) -> Option<Uuid> {
|
||||
self.0.uuid.load()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Authenticated {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.get().is_some() {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.0.broadcast.lock().push(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
138
tuic-server/src/connection/handle_stream.rs
Normal file
138
tuic-server/src/connection/handle_stream.rs
Normal file
@ -0,0 +1,138 @@
|
||||
use super::Connection;
|
||||
use crate::{Error, UdpRelayMode};
|
||||
use bytes::Bytes;
|
||||
use quinn::{RecvStream, SendStream, VarInt};
|
||||
use register_count::Register;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::time;
|
||||
use tuic_quinn::Task;
|
||||
|
||||
impl Connection {
|
||||
pub(crate) async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming unidirectional stream");
|
||||
|
||||
let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed);
|
||||
|
||||
if self.remote_uni_stream_cnt.count() as u32 == max {
|
||||
self.max_concurrent_uni_streams
|
||||
.store(max * 2, Ordering::Relaxed);
|
||||
|
||||
self.inner
|
||||
.set_max_concurrent_uni_streams(VarInt::from(max * 2));
|
||||
}
|
||||
|
||||
let pre_process = async {
|
||||
let task = time::timeout(
|
||||
self.task_negotiation_timeout,
|
||||
self.model.accept_uni_stream(recv),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||
|
||||
if let Task::Authenticate(auth) = &task {
|
||||
self.authenticate(auth)?;
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
let same_pkt_src = matches!(task, Task::Packet(_))
|
||||
&& matches!(self.udp_relay_mode.load(), Some(UdpRelayMode::Native));
|
||||
if same_pkt_src {
|
||||
return Err(Error::UnexpectedPacketSource);
|
||||
}
|
||||
|
||||
Ok(task)
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Authenticate(auth)) => self.handle_authenticate(auth).await,
|
||||
Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Quic).await,
|
||||
Ok(Task::Dissociate(assoc_id)) => self.handle_dissociate(assoc_id).await,
|
||||
Ok(_) => unreachable!(), // already filtered in `tuic_quinn`
|
||||
Err(err) => {
|
||||
log::warn!("[{addr}] handle unidirection stream error: {err}");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_bi_stream(
|
||||
self,
|
||||
(send, recv): (SendStream, RecvStream),
|
||||
_reg: Register,
|
||||
) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming bidirectional stream");
|
||||
|
||||
let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed);
|
||||
|
||||
if self.remote_bi_stream_cnt.count() as u32 == max {
|
||||
self.max_concurrent_bi_streams
|
||||
.store(max * 2, Ordering::Relaxed);
|
||||
|
||||
self.inner
|
||||
.set_max_concurrent_bi_streams(VarInt::from(max * 2));
|
||||
}
|
||||
|
||||
let pre_process = async {
|
||||
let task = time::timeout(
|
||||
self.task_negotiation_timeout,
|
||||
self.model.accept_bi_stream(send, recv),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||
|
||||
tokio::select! {
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
Ok(task)
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Connect(conn)) => self.handle_connect(conn).await,
|
||||
Ok(_) => unreachable!(), // already filtered in `tuic_quinn`
|
||||
Err(err) => {
|
||||
log::warn!("[{addr}] handle bidirection stream error: {err}");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_datagram(self, dg: Bytes) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming datagram");
|
||||
|
||||
let pre_process = async {
|
||||
let task = self.model.accept_datagram(dg)?;
|
||||
|
||||
tokio::select! {
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
let same_pkt_src = matches!(task, Task::Packet(_))
|
||||
&& matches!(self.udp_relay_mode.load(), Some(UdpRelayMode::Quic));
|
||||
if same_pkt_src {
|
||||
return Err(Error::UnexpectedPacketSource);
|
||||
}
|
||||
|
||||
Ok(task)
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Native).await,
|
||||
Ok(Task::Heartbeat) => self.handle_heartbeat().await,
|
||||
Ok(_) => unreachable!(),
|
||||
Err(err) => {
|
||||
log::warn!("[{addr}] handle datagram error: {err}");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
191
tuic-server/src/connection/handle_task.rs
Normal file
191
tuic-server/src/connection/handle_task.rs
Normal file
@ -0,0 +1,191 @@
|
||||
use super::{UdpSession, ERROR_CODE};
|
||||
use crate::{Connection, Error, UdpRelayMode};
|
||||
use bytes::Bytes;
|
||||
use std::{
|
||||
collections::hash_map::Entry,
|
||||
io::{Error as IoError, ErrorKind},
|
||||
net::SocketAddr,
|
||||
};
|
||||
use tokio::{
|
||||
io::{self, AsyncWriteExt},
|
||||
net::{self, TcpStream},
|
||||
};
|
||||
use tokio_util::compat::FuturesAsyncReadCompatExt;
|
||||
use tuic::Address;
|
||||
use tuic_quinn::{Authenticate, Connect, Packet};
|
||||
|
||||
impl Connection {
|
||||
pub(super) async fn handle_authenticate(&self, auth: Authenticate) {
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [authenticate] authenticated as {auth_uuid}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
auth_uuid = auth.uuid(),
|
||||
);
|
||||
}
|
||||
|
||||
pub(super) async fn handle_connect(&self, conn: Connect) {
|
||||
let target_addr = conn.addr().to_string();
|
||||
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [connect] {target_addr}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
|
||||
let process = async {
|
||||
let mut stream = None;
|
||||
let mut last_err = None;
|
||||
|
||||
match resolve_dns(conn.addr()).await {
|
||||
Ok(addrs) => {
|
||||
for addr in addrs {
|
||||
match TcpStream::connect(addr).await {
|
||||
Ok(s) => {
|
||||
stream = Some(s);
|
||||
break;
|
||||
}
|
||||
Err(err) => last_err = Some(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => last_err = Some(err),
|
||||
}
|
||||
|
||||
if let Some(mut stream) = stream {
|
||||
let mut conn = conn.compat();
|
||||
let res = io::copy_bidirectional(&mut conn, &mut stream).await;
|
||||
let _ = conn.get_mut().reset(ERROR_CODE);
|
||||
let _ = stream.shutdown().await;
|
||||
res?;
|
||||
Ok::<_, Error>(())
|
||||
} else {
|
||||
let _ = conn.compat().shutdown().await;
|
||||
Err(last_err
|
||||
.unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))?
|
||||
}
|
||||
};
|
||||
|
||||
match process.await {
|
||||
Ok(()) => {}
|
||||
Err(err) => log::warn!(
|
||||
"[{addr}] [{uuid}] [connect] relaying connection to {target_addr} error: {err}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_packet(&self, pkt: Packet, mode: UdpRelayMode) {
|
||||
let assoc_id = pkt.assoc_id();
|
||||
let pkt_id = pkt.pkt_id();
|
||||
let frag_id = pkt.frag_id();
|
||||
let frag_total = pkt.frag_total();
|
||||
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] {frag_id}/{frag_total}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
|
||||
self.udp_relay_mode.store(Some(mode));
|
||||
|
||||
let process = async {
|
||||
let Some((pkt, addr, assoc_id)) = pkt.accept().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let session = match self.udp_sessions.lock().entry(assoc_id) {
|
||||
Entry::Occupied(entry) => entry.get().clone(),
|
||||
Entry::Vacant(entry) => {
|
||||
let session = UdpSession::new(
|
||||
self.clone(),
|
||||
assoc_id,
|
||||
self.udp_relay_ipv6,
|
||||
self.max_external_pkt_size,
|
||||
)?;
|
||||
entry.insert(session.clone());
|
||||
session
|
||||
}
|
||||
};
|
||||
|
||||
let Some(socket_addr) = resolve_dns(&addr).await?.next() else {
|
||||
return Err(Error::from(IoError::new(ErrorKind::NotFound, "no address resolved")));
|
||||
};
|
||||
|
||||
session.send(pkt, socket_addr).await
|
||||
};
|
||||
|
||||
match process.await {
|
||||
Ok(()) => {}
|
||||
Err(err) => log::warn!(
|
||||
"[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] error handling fragment {frag_id}/{frag_total}: {err}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_dissociate(&self, assoc_id: u16) {
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [dissociate] [{assoc_id:#06x}]",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
|
||||
if let Some(session) = self.udp_sessions.lock().remove(&assoc_id) {
|
||||
session.close();
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn handle_heartbeat(&self) {
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [heartbeat]",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
pub(super) async fn send_packet(self, pkt: Bytes, addr: Address, assoc_id: u16) {
|
||||
let addr_display = addr.to_string();
|
||||
|
||||
let res = match self.udp_relay_mode.load() {
|
||||
Some(UdpRelayMode::Native) => {
|
||||
log::info!(
|
||||
"[{addr}] [packet-to-native] [{assoc_id}] [{target_addr}]",
|
||||
addr = self.inner.remote_address(),
|
||||
target_addr = addr_display,
|
||||
);
|
||||
self.model.packet_native(pkt, addr, assoc_id)
|
||||
}
|
||||
Some(UdpRelayMode::Quic) => {
|
||||
log::info!(
|
||||
"[{addr}] [packet-to-quic] [{assoc_id}] [{target_addr}]",
|
||||
addr = self.inner.remote_address(),
|
||||
target_addr = addr_display,
|
||||
);
|
||||
self.model.packet_quic(pkt, addr, assoc_id).await
|
||||
}
|
||||
None => unreachable!(),
|
||||
};
|
||||
|
||||
if let Err(err) = res {
|
||||
log::warn!(
|
||||
"[{addr}] [packet-to-native] [{assoc_id}] [{target_addr}] {err}",
|
||||
addr = self.inner.remote_address(),
|
||||
target_addr = addr_display,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_dns(addr: &Address) -> Result<impl Iterator<Item = SocketAddr>, IoError> {
|
||||
match addr {
|
||||
Address::None => Err(IoError::new(ErrorKind::InvalidInput, "empty address")),
|
||||
Address::DomainAddress(domain, port) => Ok(net::lookup_host((domain.as_str(), *port))
|
||||
.await?
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()),
|
||||
Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()),
|
||||
}
|
||||
}
|
183
tuic-server/src/connection/mod.rs
Normal file
183
tuic-server/src/connection/mod.rs
Normal file
@ -0,0 +1,183 @@
|
||||
use self::{authenticated::Authenticated, udp_session::UdpSession};
|
||||
use crate::{Error, UdpRelayMode};
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use parking_lot::Mutex;
|
||||
use quinn::{Connecting, Connection as QuinnConnection, VarInt};
|
||||
use register_count::Counter;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{atomic::AtomicU32, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::time;
|
||||
use tuic_quinn::{side, Authenticate, Connection as Model};
|
||||
use uuid::Uuid;
|
||||
|
||||
mod authenticated;
|
||||
mod handle_stream;
|
||||
mod handle_task;
|
||||
mod udp_session;
|
||||
|
||||
pub(crate) const ERROR_CODE: VarInt = VarInt::from_u32(0);
|
||||
pub(crate) const DEFAULT_CONCURRENT_STREAMS: u32 = 32;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Connection {
|
||||
inner: QuinnConnection,
|
||||
model: Model<side::Server>,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
auth: Authenticated,
|
||||
task_negotiation_timeout: Duration,
|
||||
udp_sessions: Arc<Mutex<HashMap<u16, UdpSession>>>,
|
||||
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
|
||||
max_external_pkt_size: usize,
|
||||
remote_uni_stream_cnt: Counter,
|
||||
remote_bi_stream_cnt: Counter,
|
||||
max_concurrent_uni_streams: Arc<AtomicU32>,
|
||||
max_concurrent_bi_streams: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl Connection {
|
||||
pub async fn handle(
|
||||
conn: Connecting,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
zero_rtt_handshake: bool,
|
||||
auth_timeout: Duration,
|
||||
task_negotiation_timeout: Duration,
|
||||
max_external_pkt_size: usize,
|
||||
gc_interval: Duration,
|
||||
gc_lifetime: Duration,
|
||||
) {
|
||||
let addr = conn.remote_address();
|
||||
|
||||
let init = async {
|
||||
let conn = if zero_rtt_handshake {
|
||||
match conn.into_0rtt() {
|
||||
Ok((conn, _)) => conn,
|
||||
Err(conn) => conn.await?,
|
||||
}
|
||||
} else {
|
||||
conn.await?
|
||||
};
|
||||
|
||||
Ok::<_, Error>(Self::new(
|
||||
conn,
|
||||
users,
|
||||
udp_relay_ipv6,
|
||||
task_negotiation_timeout,
|
||||
max_external_pkt_size,
|
||||
))
|
||||
};
|
||||
|
||||
match init.await {
|
||||
Ok(conn) => {
|
||||
log::info!("[{addr}] connection established");
|
||||
|
||||
tokio::spawn(conn.clone().timeout_authenticate(auth_timeout));
|
||||
tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime));
|
||||
|
||||
loop {
|
||||
if conn.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
let handle_incoming = async {
|
||||
tokio::select! {
|
||||
res = conn.inner.accept_uni() =>
|
||||
tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())),
|
||||
res = conn.inner.accept_bi() =>
|
||||
tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())),
|
||||
res = conn.inner.read_datagram() =>
|
||||
tokio::spawn(conn.clone().handle_datagram(res?)),
|
||||
};
|
||||
|
||||
Ok::<_, Error>(())
|
||||
};
|
||||
|
||||
match handle_incoming.await {
|
||||
Ok(()) => {}
|
||||
Err(err) if err.is_locally_closed() => {}
|
||||
Err(err) if err.is_timeout_closed() => {
|
||||
log::debug!("[{addr}] connection timeout")
|
||||
}
|
||||
Err(err) => log::warn!("[{addr}] connection error: {err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(),
|
||||
Err(err) => log::warn!("[{addr}] connection establishing error: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn new(
|
||||
conn: QuinnConnection,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
task_negotiation_timeout: Duration,
|
||||
max_external_pkt_size: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: conn.clone(),
|
||||
model: Model::<side::Server>::new(conn),
|
||||
users,
|
||||
udp_relay_ipv6,
|
||||
auth: Authenticated::new(),
|
||||
task_negotiation_timeout,
|
||||
udp_sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
udp_relay_mode: Arc::new(AtomicCell::new(None)),
|
||||
max_external_pkt_size,
|
||||
remote_uni_stream_cnt: Counter::new(),
|
||||
remote_bi_stream_cnt: Counter::new(),
|
||||
max_concurrent_uni_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||
max_concurrent_bi_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate(&self, auth: &Authenticate) -> Result<(), Error> {
|
||||
if self.auth.get().is_some() {
|
||||
Err(Error::DuplicatedAuth)
|
||||
} else if self
|
||||
.users
|
||||
.get(&auth.uuid())
|
||||
.map_or(false, |password| auth.validate(password))
|
||||
{
|
||||
self.auth.set(auth.uuid());
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::AuthFailed(auth.uuid()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn timeout_authenticate(self, timeout: Duration) {
|
||||
time::sleep(timeout).await;
|
||||
|
||||
if self.auth.get().is_none() {
|
||||
let addr = self.inner.remote_address();
|
||||
log::warn!("[{addr}] [authenticate] timeout");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
|
||||
async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) {
|
||||
loop {
|
||||
time::sleep(gc_interval).await;
|
||||
|
||||
if self.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.model.collect_garbage(gc_lifetime);
|
||||
}
|
||||
}
|
||||
|
||||
fn is_closed(&self) -> bool {
|
||||
self.inner.close_reason().is_some()
|
||||
}
|
||||
|
||||
fn close(&self) {
|
||||
self.inner.close(ERROR_CODE, &[]);
|
||||
}
|
||||
}
|
161
tuic-server/src/connection/udp_session.rs
Normal file
161
tuic-server/src/connection/udp_session.rs
Normal file
@ -0,0 +1,161 @@
|
||||
use crate::{Connection, Error};
|
||||
use bytes::Bytes;
|
||||
use parking_lot::Mutex;
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
use std::{
|
||||
io::Error as IoError,
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
||||
sync::Arc,
|
||||
};
|
||||
use tokio::{
|
||||
net::UdpSocket,
|
||||
sync::oneshot::{self, Sender},
|
||||
};
|
||||
use tuic::Address;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) struct UdpSession(Arc<UdpSessionInner>);
|
||||
|
||||
struct UdpSessionInner {
|
||||
assoc_id: u16,
|
||||
conn: Connection,
|
||||
socket_v4: UdpSocket,
|
||||
socket_v6: Option<UdpSocket>,
|
||||
max_pkt_size: usize,
|
||||
close: Mutex<Option<Sender<()>>>,
|
||||
}
|
||||
|
||||
impl UdpSession {
|
||||
pub(super) fn new(
|
||||
conn: Connection,
|
||||
assoc_id: u16,
|
||||
udp_relay_ipv6: bool,
|
||||
max_pkt_size: usize,
|
||||
) -> Result<Self, Error> {
|
||||
let socket_v4 = {
|
||||
let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|err| Error::Socket("failed to create UDP associate IPv4 socket", err))?;
|
||||
|
||||
socket.set_nonblocking(true).map_err(|err| {
|
||||
Error::Socket(
|
||||
"failed setting UDP associate IPv4 socket as non-blocking",
|
||||
err,
|
||||
)
|
||||
})?;
|
||||
|
||||
socket
|
||||
.bind(&SockAddr::from(SocketAddr::from((
|
||||
Ipv4Addr::UNSPECIFIED,
|
||||
0,
|
||||
))))
|
||||
.map_err(|err| Error::Socket("failed to bind UDP associate IPv4 socket", err))?;
|
||||
|
||||
UdpSocket::from_std(StdUdpSocket::from(socket))?
|
||||
};
|
||||
|
||||
let socket_v6 = if udp_relay_ipv6 {
|
||||
let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|err| Error::Socket("failed to create UDP associate IPv6 socket", err))?;
|
||||
|
||||
socket.set_nonblocking(true).map_err(|err| {
|
||||
Error::Socket(
|
||||
"failed setting UDP associate IPv6 socket as non-blocking",
|
||||
err,
|
||||
)
|
||||
})?;
|
||||
|
||||
socket.set_only_v6(true).map_err(|err| {
|
||||
Error::Socket("failed setting UDP associate IPv6 socket as IPv6-only", err)
|
||||
})?;
|
||||
|
||||
socket
|
||||
.bind(&SockAddr::from(SocketAddr::from((
|
||||
Ipv6Addr::UNSPECIFIED,
|
||||
0,
|
||||
))))
|
||||
.map_err(|err| Error::Socket("failed to bind UDP associate IPv6 socket", err))?;
|
||||
|
||||
Some(UdpSocket::from_std(StdUdpSocket::from(socket))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let session = Self(Arc::new(UdpSessionInner {
|
||||
conn,
|
||||
assoc_id,
|
||||
socket_v4,
|
||||
socket_v6,
|
||||
max_pkt_size,
|
||||
close: Mutex::new(Some(tx)),
|
||||
}));
|
||||
|
||||
let session_listening = session.clone();
|
||||
let listen = async move {
|
||||
loop {
|
||||
let (pkt, addr) = match session_listening.recv().await {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
log::warn!("{err}"); // TODO
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(session_listening.0.conn.clone().send_packet(
|
||||
pkt,
|
||||
Address::SocketAddress(addr),
|
||||
session_listening.0.assoc_id,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = listen => unreachable!(),
|
||||
_ = rx => {},
|
||||
}
|
||||
});
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
pub(super) async fn send(&self, pkt: Bytes, addr: SocketAddr) -> Result<(), Error> {
|
||||
let socket = match addr {
|
||||
SocketAddr::V4(_) => &self.0.socket_v4,
|
||||
SocketAddr::V6(_) => self
|
||||
.0
|
||||
.socket_v6
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr))?,
|
||||
};
|
||||
|
||||
socket.send_to(&pkt, addr).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
async fn recv(
|
||||
socket: &UdpSocket,
|
||||
max_pkt_size: usize,
|
||||
) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
let mut buf = vec![0u8; max_pkt_size];
|
||||
let (n, addr) = socket.recv_from(&mut buf).await?;
|
||||
buf.truncate(n);
|
||||
Ok((Bytes::from(buf), addr))
|
||||
}
|
||||
|
||||
if let Some(socket_v6) = &self.0.socket_v6 {
|
||||
tokio::select! {
|
||||
res = recv(&self.0.socket_v4, self.0.max_pkt_size) => res,
|
||||
res = recv(socket_v6, self.0.max_pkt_size) => res,
|
||||
}
|
||||
} else {
|
||||
recv(&self.0.socket_v4, self.0.max_pkt_size).await
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn close(&self) {
|
||||
let _ = self.0.close.lock().take().unwrap().send(());
|
||||
}
|
||||
}
|
44
tuic-server/src/error.rs
Normal file
44
tuic-server/src/error.rs
Normal file
@ -0,0 +1,44 @@
|
||||
use quinn::ConnectionError;
|
||||
use rustls::Error as RustlsError;
|
||||
use std::{io::Error as IoError, net::SocketAddr};
|
||||
use thiserror::Error;
|
||||
use tuic_quinn::Error as ModelError;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error(transparent)]
|
||||
Io(#[from] IoError),
|
||||
#[error(transparent)]
|
||||
Rustls(#[from] RustlsError),
|
||||
#[error("invalid max idle time")]
|
||||
InvalidMaxIdleTime,
|
||||
#[error(transparent)]
|
||||
Connection(#[from] ConnectionError),
|
||||
#[error(transparent)]
|
||||
Model(#[from] ModelError),
|
||||
#[error("duplicated authentication")]
|
||||
DuplicatedAuth,
|
||||
#[error("token length too short")]
|
||||
ExportKeyingMaterial,
|
||||
#[error("authentication failed: {0}")]
|
||||
AuthFailed(Uuid),
|
||||
#[error("received packet from unexpected source")]
|
||||
UnexpectedPacketSource,
|
||||
#[error("{0}: {1}")]
|
||||
Socket(&'static str, IoError),
|
||||
#[error("task negotiation timed out")]
|
||||
TaskNegotiationTimeout,
|
||||
#[error("failed sending packet to {0}: relaying IPv6 UDP packet is disabled")]
|
||||
UdpRelayIpv6Disabled(SocketAddr),
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn is_locally_closed(&self) -> bool {
|
||||
matches!(self, Self::Connection(ConnectionError::LocallyClosed))
|
||||
}
|
||||
|
||||
pub fn is_timeout_closed(&self) -> bool {
|
||||
matches!(self, Self::Connection(ConnectionError::TimedOut))
|
||||
}
|
||||
}
|
14
tuic-server/src/lib.rs
Normal file
14
tuic-server/src/lib.rs
Normal file
@ -0,0 +1,14 @@
|
||||
pub(crate) mod config;
|
||||
pub(crate) mod error;
|
||||
pub(crate) mod server;
|
||||
pub(crate) mod utils;
|
||||
|
||||
pub mod connection;
|
||||
|
||||
pub use crate::{
|
||||
config::{Config, ConfigError},
|
||||
connection::Connection,
|
||||
error::Error,
|
||||
server::Server,
|
||||
utils::{CongestionControl, UdpRelayMode},
|
||||
};
|
@ -1,19 +1,6 @@
|
||||
use self::{
|
||||
config::{Config, ConfigError},
|
||||
server::Server,
|
||||
};
|
||||
use env_logger::Builder as LoggerBuilder;
|
||||
use quinn::ConnectionError;
|
||||
use rustls::Error as RustlsError;
|
||||
use std::{env, io::Error as IoError, net::SocketAddr, process};
|
||||
use thiserror::Error;
|
||||
use tuic::Address;
|
||||
use tuic_quinn::Error as ModelError;
|
||||
use uuid::Uuid;
|
||||
|
||||
mod config;
|
||||
mod server;
|
||||
mod utils;
|
||||
use std::{env, process};
|
||||
use tuic_server::{Config, ConfigError, Server};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
@ -43,31 +30,3 @@ async fn main() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum Error {
|
||||
#[error(transparent)]
|
||||
Io(#[from] IoError),
|
||||
#[error(transparent)]
|
||||
Rustls(#[from] RustlsError),
|
||||
#[error("invalid max idle time")]
|
||||
InvalidMaxIdleTime,
|
||||
#[error(transparent)]
|
||||
Connection(#[from] ConnectionError),
|
||||
#[error(transparent)]
|
||||
Model(#[from] ModelError),
|
||||
#[error("duplicated authentication")]
|
||||
DuplicatedAuth,
|
||||
#[error("token length too short")]
|
||||
ExportKeyingMaterial,
|
||||
#[error("authentication failed: {0}")]
|
||||
AuthFailed(Uuid),
|
||||
#[error("received packet from unexpected source")]
|
||||
UnexpectedPacketSource,
|
||||
#[error("{0}: {1}")]
|
||||
Socket(&'static str, IoError),
|
||||
#[error("task negotiation timed out")]
|
||||
TaskNegotiationTimeout,
|
||||
#[error("{0} resolved to {1} but IPv6 UDP relaying is disabled")]
|
||||
UdpRelayIpv6Disabled(Address, SocketAddr),
|
||||
}
|
||||
|
@ -1,49 +1,21 @@
|
||||
use crate::{
|
||||
config::Config,
|
||||
utils::{self, CongestionControl, UdpRelayMode},
|
||||
config::Config, connection::DEFAULT_CONCURRENT_STREAMS, utils, CongestionControl, Connection,
|
||||
Error,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use crossbeam_utils::atomic::AtomicCell;
|
||||
use parking_lot::Mutex;
|
||||
use quinn::{
|
||||
congestion::{BbrConfig, CubicConfig, NewRenoConfig},
|
||||
Connecting, Connection as QuinnConnection, ConnectionError, Endpoint, EndpointConfig,
|
||||
IdleTimeout, RecvStream, SendStream, ServerConfig, TokioRuntime, TransportConfig, VarInt,
|
||||
Endpoint, EndpointConfig, IdleTimeout, ServerConfig, TokioRuntime, TransportConfig, VarInt,
|
||||
};
|
||||
use register_count::{Counter, Register};
|
||||
use rustls::{version, ServerConfig as RustlsServerConfig};
|
||||
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
future::Future,
|
||||
io::{Error as IoError, ErrorKind},
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicU32, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::{Context, Poll, Waker},
|
||||
collections::HashMap,
|
||||
net::{SocketAddr, UdpSocket as StdUdpSocket},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
io::{self, AsyncWriteExt},
|
||||
net::{self, TcpStream, UdpSocket},
|
||||
sync::{
|
||||
oneshot::{self, Receiver, Sender},
|
||||
Mutex as AsyncMutex,
|
||||
},
|
||||
time,
|
||||
};
|
||||
use tokio_util::compat::FuturesAsyncReadCompatExt;
|
||||
use tuic::Address;
|
||||
use tuic_quinn::{side, Authenticate, Connect, Connection as Model, Packet, Task};
|
||||
use uuid::Uuid;
|
||||
|
||||
const ERROR_CODE: VarInt = VarInt::from_u32(0);
|
||||
const DEFAULT_CONCURRENT_STREAMS: u32 = 32;
|
||||
|
||||
pub struct Server {
|
||||
ep: Endpoint,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
@ -172,638 +144,3 @@ impl Server {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Connection {
|
||||
inner: QuinnConnection,
|
||||
model: Model<side::Server>,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
auth: Authenticated,
|
||||
task_negotiation_timeout: Duration,
|
||||
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
|
||||
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
|
||||
max_external_pkt_size: usize,
|
||||
remote_uni_stream_cnt: Counter,
|
||||
remote_bi_stream_cnt: Counter,
|
||||
max_concurrent_uni_streams: Arc<AtomicU32>,
|
||||
max_concurrent_bi_streams: Arc<AtomicU32>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
impl Connection {
|
||||
async fn handle(
|
||||
conn: Connecting,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
zero_rtt_handshake: bool,
|
||||
auth_timeout: Duration,
|
||||
task_negotiation_timeout: Duration,
|
||||
max_external_pkt_size: usize,
|
||||
gc_interval: Duration,
|
||||
gc_lifetime: Duration,
|
||||
) {
|
||||
let addr = conn.remote_address();
|
||||
|
||||
let init = async {
|
||||
let conn = if zero_rtt_handshake {
|
||||
match conn.into_0rtt() {
|
||||
Ok((conn, _)) => conn,
|
||||
Err(conn) => conn.await?,
|
||||
}
|
||||
} else {
|
||||
conn.await?
|
||||
};
|
||||
|
||||
Ok::<_, Error>(Self::new(
|
||||
conn,
|
||||
users,
|
||||
udp_relay_ipv6,
|
||||
task_negotiation_timeout,
|
||||
max_external_pkt_size,
|
||||
))
|
||||
};
|
||||
|
||||
match init.await {
|
||||
Ok(conn) => {
|
||||
log::info!("[{addr}] connection established");
|
||||
|
||||
tokio::spawn(conn.clone().timeout_authenticate(auth_timeout));
|
||||
tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime));
|
||||
|
||||
loop {
|
||||
if conn.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
let handle_incoming = async {
|
||||
tokio::select! {
|
||||
res = conn.inner.accept_uni() =>
|
||||
tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())),
|
||||
res = conn.inner.accept_bi() =>
|
||||
tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())),
|
||||
res = conn.inner.read_datagram() =>
|
||||
tokio::spawn(conn.clone().handle_datagram(res?)),
|
||||
};
|
||||
|
||||
Ok::<_, Error>(())
|
||||
};
|
||||
|
||||
match handle_incoming.await {
|
||||
Ok(()) => {}
|
||||
Err(err) if err.is_locally_closed() => {}
|
||||
Err(err) if err.is_timeout_closed() => {
|
||||
log::debug!("[{addr}] connection timeout")
|
||||
}
|
||||
Err(err) => log::warn!("[{addr}] connection error: {err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(),
|
||||
Err(err) => log::warn!("[{addr}] connection establishing error: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn new(
|
||||
conn: QuinnConnection,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
task_negotiation_timeout: Duration,
|
||||
max_external_pkt_size: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: conn.clone(),
|
||||
model: Model::<side::Server>::new(conn),
|
||||
users,
|
||||
udp_relay_ipv6,
|
||||
auth: Authenticated::new(),
|
||||
task_negotiation_timeout,
|
||||
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
udp_relay_mode: Arc::new(AtomicCell::new(None)),
|
||||
max_external_pkt_size,
|
||||
remote_uni_stream_cnt: Counter::new(),
|
||||
remote_bi_stream_cnt: Counter::new(),
|
||||
max_concurrent_uni_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||
max_concurrent_bi_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)),
|
||||
}
|
||||
}
|
||||
|
||||
fn authenticate(&self, auth: &Authenticate) -> Result<(), Error> {
|
||||
if self.auth.get().is_some() {
|
||||
Err(Error::DuplicatedAuth)
|
||||
} else if self
|
||||
.users
|
||||
.get(&auth.uuid())
|
||||
.map_or(false, |password| auth.validate(password))
|
||||
{
|
||||
self.auth.set(auth.uuid());
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::AuthFailed(auth.uuid()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn timeout_authenticate(self, timeout: Duration) {
|
||||
time::sleep(timeout).await;
|
||||
|
||||
if self.auth.get().is_none() {
|
||||
let addr = self.inner.remote_address();
|
||||
log::warn!("[{addr}] [authenticate] timeout");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming unidirectional stream");
|
||||
|
||||
let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed);
|
||||
|
||||
if self.remote_uni_stream_cnt.count() as u32 == max {
|
||||
self.max_concurrent_uni_streams
|
||||
.store(max * 2, Ordering::Relaxed);
|
||||
|
||||
self.inner
|
||||
.set_max_concurrent_uni_streams(VarInt::from(max * 2));
|
||||
}
|
||||
|
||||
let pre_process = async {
|
||||
let task = time::timeout(
|
||||
self.task_negotiation_timeout,
|
||||
self.model.accept_uni_stream(recv),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||
|
||||
if let Task::Authenticate(auth) = &task {
|
||||
self.authenticate(auth)?;
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
let same_pkt_src = matches!(task, Task::Packet(_))
|
||||
&& matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Native));
|
||||
if same_pkt_src {
|
||||
return Err(Error::UnexpectedPacketSource);
|
||||
}
|
||||
|
||||
Ok(task)
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Authenticate(auth)) => self.handle_authenticate(auth).await,
|
||||
Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Quic).await,
|
||||
Ok(Task::Dissociate(assoc_id)) => self.handle_dissociate(assoc_id).await,
|
||||
Ok(_) => unreachable!(), // already filtered in `tuic_quinn`
|
||||
Err(err) => {
|
||||
log::warn!("[{addr}] handle unidirection stream error: {err}");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream), _reg: Register) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming bidirectional stream");
|
||||
|
||||
let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed);
|
||||
|
||||
if self.remote_bi_stream_cnt.count() as u32 == max {
|
||||
self.max_concurrent_bi_streams
|
||||
.store(max * 2, Ordering::Relaxed);
|
||||
|
||||
self.inner
|
||||
.set_max_concurrent_bi_streams(VarInt::from(max * 2));
|
||||
}
|
||||
|
||||
let pre_process = async {
|
||||
let task = time::timeout(
|
||||
self.task_negotiation_timeout,
|
||||
self.model.accept_bi_stream(send, recv),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||
|
||||
tokio::select! {
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
Ok(task)
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Connect(conn)) => self.handle_connect(conn).await,
|
||||
Ok(_) => unreachable!(), // already filtered in `tuic_quinn`
|
||||
Err(err) => {
|
||||
log::warn!("[{addr}] handle bidirection stream error: {err}");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_datagram(self, dg: Bytes) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming datagram");
|
||||
|
||||
let pre_process = async {
|
||||
let task = self.model.accept_datagram(dg)?;
|
||||
|
||||
tokio::select! {
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
let same_pkt_src = matches!(task, Task::Packet(_))
|
||||
&& matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Quic));
|
||||
if same_pkt_src {
|
||||
return Err(Error::UnexpectedPacketSource);
|
||||
}
|
||||
|
||||
Ok(task)
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Native).await,
|
||||
Ok(Task::Heartbeat) => self.handle_heartbeat().await,
|
||||
Ok(_) => unreachable!(),
|
||||
Err(err) => {
|
||||
log::warn!("[{addr}] handle datagram error: {err}");
|
||||
self.close();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_authenticate(&self, auth: Authenticate) {
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [authenticate] authenticated as {auth_uuid}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
auth_uuid = auth.uuid(),
|
||||
);
|
||||
}
|
||||
|
||||
async fn handle_connect(&self, conn: Connect) {
|
||||
let target_addr = conn.addr().to_string();
|
||||
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [connect] {target_addr}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
|
||||
let process = async {
|
||||
let mut stream = None;
|
||||
let mut last_err = None;
|
||||
|
||||
match resolve_dns(conn.addr()).await {
|
||||
Ok(addrs) => {
|
||||
for addr in addrs {
|
||||
match TcpStream::connect(addr).await {
|
||||
Ok(s) => {
|
||||
stream = Some(s);
|
||||
break;
|
||||
}
|
||||
Err(err) => last_err = Some(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => last_err = Some(err),
|
||||
}
|
||||
|
||||
if let Some(mut stream) = stream {
|
||||
let mut conn = conn.compat();
|
||||
let res = io::copy_bidirectional(&mut conn, &mut stream).await;
|
||||
let _ = conn.get_mut().reset(ERROR_CODE);
|
||||
let _ = stream.shutdown().await;
|
||||
res?;
|
||||
Ok::<_, Error>(())
|
||||
} else {
|
||||
let _ = conn.compat().shutdown().await;
|
||||
Err(last_err
|
||||
.unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))?
|
||||
}
|
||||
};
|
||||
|
||||
match process.await {
|
||||
Ok(()) => {}
|
||||
Err(err) => log::warn!(
|
||||
"[{addr}] [{uuid}] [connect] relaying connection to {target_addr} error: {err}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_packet(&self, pkt: Packet, mode: UdpRelayMode) {
|
||||
let assoc_id = pkt.assoc_id();
|
||||
let pkt_id = pkt.pkt_id();
|
||||
let frag_id = pkt.frag_id();
|
||||
let frag_total = pkt.frag_total();
|
||||
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] {frag_id}/{frag_total}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
|
||||
self.set_udp_relay_mode(mode);
|
||||
|
||||
let process = async {
|
||||
let Some((pkt, addr, assoc_id)) = pkt.accept().await? else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let (socket_v4, socket_v6) = match self.udp_sessions.lock().await.entry(assoc_id) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
let session = entry.get_mut();
|
||||
(session.socket_v4.clone(), session.socket_v6.clone())
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
let session = entry.insert(
|
||||
UdpSession::new(assoc_id, self.clone(), self.udp_relay_ipv6).await?,
|
||||
);
|
||||
|
||||
(session.socket_v4.clone(), session.socket_v6.clone())
|
||||
}
|
||||
};
|
||||
|
||||
let Some(socket_addr) = resolve_dns(&addr).await?.next() else {
|
||||
return Err(Error::from(IoError::new(ErrorKind::NotFound, "no address resolved")));
|
||||
};
|
||||
|
||||
let socket = match socket_addr {
|
||||
SocketAddr::V4(_) => socket_v4,
|
||||
SocketAddr::V6(_) => {
|
||||
socket_v6.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr, socket_addr))?
|
||||
}
|
||||
};
|
||||
|
||||
socket.send_to(&pkt, socket_addr).await?;
|
||||
|
||||
Ok(())
|
||||
};
|
||||
|
||||
match process.await {
|
||||
Ok(()) => {}
|
||||
Err(err) => log::warn!(
|
||||
"[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] error handling fragment {frag_id}/{frag_total}: {err}",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_dissociate(&self, assoc_id: u16) {
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [dissociate] [{assoc_id:#06x}]",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
|
||||
self.udp_sessions.lock().await.remove(&assoc_id);
|
||||
}
|
||||
|
||||
async fn handle_heartbeat(&self) {
|
||||
log::info!(
|
||||
"[{addr}] [{uuid}] [heartbeat]",
|
||||
addr = self.inner.remote_address(),
|
||||
uuid = self.auth.get().unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) {
|
||||
loop {
|
||||
time::sleep(gc_interval).await;
|
||||
|
||||
if self.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
self.model.collect_garbage(gc_lifetime);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_udp_relay_mode(&self, mode: UdpRelayMode) {
|
||||
self.udp_relay_mode.store(Some(mode));
|
||||
}
|
||||
|
||||
fn get_udp_relay_mode(&self) -> Option<UdpRelayMode> {
|
||||
self.udp_relay_mode.load()
|
||||
}
|
||||
|
||||
fn is_closed(&self) -> bool {
|
||||
self.inner.close_reason().is_some()
|
||||
}
|
||||
|
||||
fn close(&self) {
|
||||
self.inner.close(ERROR_CODE, &[]);
|
||||
}
|
||||
}
|
||||
|
||||
async fn resolve_dns(addr: &Address) -> Result<impl Iterator<Item = SocketAddr>, IoError> {
|
||||
match addr {
|
||||
Address::None => Err(IoError::new(ErrorKind::InvalidInput, "empty address")),
|
||||
Address::DomainAddress(domain, port) => Ok(net::lookup_host((domain.as_str(), *port))
|
||||
.await?
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()),
|
||||
Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()),
|
||||
}
|
||||
}
|
||||
|
||||
struct UdpSession {
|
||||
socket_v4: Arc<UdpSocket>,
|
||||
socket_v6: Option<Arc<UdpSocket>>,
|
||||
cancel: Option<Sender<()>>,
|
||||
}
|
||||
|
||||
impl UdpSession {
|
||||
async fn new(assoc_id: u16, conn: Connection, udp_relay_ipv6: bool) -> Result<Self, Error> {
|
||||
let socket_v4 = Arc::new(
|
||||
UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
|
||||
.await
|
||||
.map_err(|err| Error::Socket("failed to create UDP associate IPv4 socket", err))?,
|
||||
);
|
||||
let socket_v6 = if udp_relay_ipv6 {
|
||||
let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
|
||||
.map_err(|err| Error::Socket("failed to create UDP associate IPv6 socket", err))?;
|
||||
|
||||
socket.set_nonblocking(true).map_err(|err| {
|
||||
Error::Socket(
|
||||
"failed setting UDP associate IPv6 socket as non-blocking",
|
||||
err,
|
||||
)
|
||||
})?;
|
||||
|
||||
socket.set_only_v6(true).map_err(|err| {
|
||||
Error::Socket("failed setting UDP associate IPv6 socket as IPv6-only", err)
|
||||
})?;
|
||||
|
||||
socket
|
||||
.bind(&SockAddr::from(SocketAddr::from((
|
||||
Ipv6Addr::UNSPECIFIED,
|
||||
0,
|
||||
))))
|
||||
.map_err(|err| Error::Socket("failed to bind UDP associate IPv6 socket", err))?;
|
||||
|
||||
Some(Arc::new(UdpSocket::from_std(StdUdpSocket::from(socket))?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(Self::listen_incoming(
|
||||
assoc_id,
|
||||
conn,
|
||||
socket_v4.clone(),
|
||||
socket_v6.clone(),
|
||||
rx,
|
||||
));
|
||||
|
||||
Ok(Self {
|
||||
socket_v4,
|
||||
socket_v6,
|
||||
cancel: Some(tx),
|
||||
})
|
||||
}
|
||||
|
||||
async fn listen_incoming(
|
||||
assoc_id: u16,
|
||||
conn: Connection,
|
||||
socket_v4: Arc<UdpSocket>,
|
||||
socket_v6: Option<Arc<UdpSocket>>,
|
||||
cancel: Receiver<()>,
|
||||
) {
|
||||
async fn send_pkt(conn: Connection, pkt: Bytes, target_addr: SocketAddr, assoc_id: u16) {
|
||||
let addr = conn.inner.remote_address();
|
||||
let target_addr_tuic = Address::SocketAddress(target_addr);
|
||||
|
||||
let res = match conn.get_udp_relay_mode() {
|
||||
Some(UdpRelayMode::Native) => {
|
||||
log::info!("[{addr}] [packet-to-native] [{assoc_id}] [{target_addr_tuic}]");
|
||||
conn.model.packet_native(pkt, target_addr_tuic, assoc_id)
|
||||
}
|
||||
Some(UdpRelayMode::Quic) => {
|
||||
log::info!("[{addr}] [packet-to-quic] [{assoc_id}] [{target_addr_tuic}]");
|
||||
conn.model
|
||||
.packet_quic(pkt, target_addr_tuic, assoc_id)
|
||||
.await
|
||||
}
|
||||
None => unreachable!(),
|
||||
};
|
||||
|
||||
if let Err(err) = res {
|
||||
let target_addr_tuic = Address::SocketAddress(target_addr);
|
||||
log::warn!("[{addr}] [packet-to-quic] [{assoc_id}] [{target_addr_tuic}] {err}");
|
||||
}
|
||||
}
|
||||
|
||||
let addr = conn.inner.remote_address();
|
||||
|
||||
tokio::select! {
|
||||
_ = cancel => {}
|
||||
() = async {
|
||||
loop {
|
||||
match Self::accept(
|
||||
&socket_v4,
|
||||
socket_v6.as_deref(),
|
||||
conn.max_external_pkt_size,
|
||||
).await {
|
||||
Ok((pkt, target_addr)) => {
|
||||
tokio::spawn(send_pkt(conn.clone(), pkt, target_addr, assoc_id));
|
||||
}
|
||||
Err(err) => log::warn!("[{addr}] [packet-to-*] [{assoc_id}] {err}"),
|
||||
}
|
||||
}
|
||||
} => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept(
|
||||
socket_v4: &UdpSocket,
|
||||
socket_v6: Option<&UdpSocket>,
|
||||
max_pkt_size: usize,
|
||||
) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
async fn read_pkt(
|
||||
socket: &UdpSocket,
|
||||
max_pkt_size: usize,
|
||||
) -> Result<(Bytes, SocketAddr), IoError> {
|
||||
let mut buf = vec![0u8; max_pkt_size];
|
||||
let (n, addr) = socket.recv_from(&mut buf).await?;
|
||||
buf.truncate(n);
|
||||
Ok((Bytes::from(buf), addr))
|
||||
}
|
||||
|
||||
if let Some(socket_v6) = socket_v6 {
|
||||
tokio::select! {
|
||||
res = read_pkt(socket_v4, max_pkt_size) => res,
|
||||
res = read_pkt(socket_v6, max_pkt_size) => res,
|
||||
}
|
||||
} else {
|
||||
read_pkt(socket_v4, max_pkt_size).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for UdpSession {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.cancel.take().unwrap().send(());
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Authenticated(Arc<AuthenticatedInner>);
|
||||
|
||||
struct AuthenticatedInner {
|
||||
uuid: AtomicCell<Option<Uuid>>,
|
||||
broadcast: Mutex<Vec<Waker>>,
|
||||
}
|
||||
|
||||
impl Authenticated {
|
||||
fn new() -> Self {
|
||||
Self(Arc::new(AuthenticatedInner {
|
||||
uuid: AtomicCell::new(None),
|
||||
broadcast: Mutex::new(Vec::new()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn set(&self, uuid: Uuid) {
|
||||
self.0.uuid.store(Some(uuid));
|
||||
|
||||
for waker in self.0.broadcast.lock().drain(..) {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn get(&self) -> Option<Uuid> {
|
||||
self.0.uuid.load()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Authenticated {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.get().is_some() {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.0.broadcast.lock().push(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
fn is_locally_closed(&self) -> bool {
|
||||
matches!(self, Self::Connection(ConnectionError::LocallyClosed))
|
||||
}
|
||||
|
||||
fn is_timeout_closed(&self) -> bool {
|
||||
matches!(self, Self::Connection(ConnectionError::TimedOut))
|
||||
}
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ use std::{
|
||||
str::FromStr,
|
||||
};
|
||||
|
||||
pub fn load_certs(path: PathBuf) -> Result<Vec<Certificate>, IoError> {
|
||||
pub(crate) fn load_certs(path: PathBuf) -> Result<Vec<Certificate>, IoError> {
|
||||
let mut file = BufReader::new(File::open(&path)?);
|
||||
let mut certs = Vec::new();
|
||||
|
||||
@ -25,7 +25,7 @@ pub fn load_certs(path: PathBuf) -> Result<Vec<Certificate>, IoError> {
|
||||
Ok(certs)
|
||||
}
|
||||
|
||||
pub fn load_priv_key(path: PathBuf) -> Result<PrivateKey, IoError> {
|
||||
pub(crate) fn load_priv_key(path: PathBuf) -> Result<PrivateKey, IoError> {
|
||||
let mut file = BufReader::new(File::open(&path)?);
|
||||
let mut priv_key = None;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user