From eb228e554e59957a4e9650c31ac4a53092a766f0 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 29 May 2023 22:57:14 +0900 Subject: [PATCH] refactor server UDP session handling --- tuic-server/src/connection/authenticated.rs | 51 ++ tuic-server/src/connection/handle_stream.rs | 138 ++++ tuic-server/src/connection/handle_task.rs | 191 ++++++ tuic-server/src/connection/mod.rs | 183 ++++++ tuic-server/src/connection/udp_session.rs | 161 +++++ tuic-server/src/error.rs | 44 ++ tuic-server/src/lib.rs | 14 + tuic-server/src/main.rs | 45 +- tuic-server/src/server.rs | 673 +------------------- tuic-server/src/utils.rs | 4 +- 10 files changed, 791 insertions(+), 713 deletions(-) create mode 100644 tuic-server/src/connection/authenticated.rs create mode 100644 tuic-server/src/connection/handle_stream.rs create mode 100644 tuic-server/src/connection/handle_task.rs create mode 100644 tuic-server/src/connection/mod.rs create mode 100644 tuic-server/src/connection/udp_session.rs create mode 100644 tuic-server/src/error.rs create mode 100644 tuic-server/src/lib.rs diff --git a/tuic-server/src/connection/authenticated.rs b/tuic-server/src/connection/authenticated.rs new file mode 100644 index 0000000..030d7d7 --- /dev/null +++ b/tuic-server/src/connection/authenticated.rs @@ -0,0 +1,51 @@ +use crossbeam_utils::atomic::AtomicCell; +use parking_lot::Mutex; +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, +}; +use uuid::Uuid; + +#[derive(Clone)] +pub(super) struct Authenticated(Arc); + +struct AuthenticatedInner { + uuid: AtomicCell>, + broadcast: Mutex>, +} + +impl Authenticated { + pub(super) fn new() -> Self { + Self(Arc::new(AuthenticatedInner { + uuid: AtomicCell::new(None), + broadcast: Mutex::new(Vec::new()), + })) + } + + pub(super) fn set(&self, uuid: Uuid) { + self.0.uuid.store(Some(uuid)); + + for waker in self.0.broadcast.lock().drain(..) { + waker.wake(); + } + } + + pub(super) fn get(&self) -> Option { + self.0.uuid.load() + } +} + +impl Future for Authenticated { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.get().is_some() { + Poll::Ready(()) + } else { + self.0.broadcast.lock().push(cx.waker().clone()); + Poll::Pending + } + } +} diff --git a/tuic-server/src/connection/handle_stream.rs b/tuic-server/src/connection/handle_stream.rs new file mode 100644 index 0000000..514661f --- /dev/null +++ b/tuic-server/src/connection/handle_stream.rs @@ -0,0 +1,138 @@ +use super::Connection; +use crate::{Error, UdpRelayMode}; +use bytes::Bytes; +use quinn::{RecvStream, SendStream, VarInt}; +use register_count::Register; +use std::sync::atomic::Ordering; +use tokio::time; +use tuic_quinn::Task; + +impl Connection { + pub(crate) async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { + let addr = self.inner.remote_address(); + log::debug!("[{addr}] incoming unidirectional stream"); + + let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed); + + if self.remote_uni_stream_cnt.count() as u32 == max { + self.max_concurrent_uni_streams + .store(max * 2, Ordering::Relaxed); + + self.inner + .set_max_concurrent_uni_streams(VarInt::from(max * 2)); + } + + let pre_process = async { + let task = time::timeout( + self.task_negotiation_timeout, + self.model.accept_uni_stream(recv), + ) + .await + .map_err(|_| Error::TaskNegotiationTimeout)??; + + if let Task::Authenticate(auth) = &task { + self.authenticate(auth)?; + } + + tokio::select! { + () = self.auth.clone() => {} + err = self.inner.closed() => return Err(Error::Connection(err)), + }; + + let same_pkt_src = matches!(task, Task::Packet(_)) + && matches!(self.udp_relay_mode.load(), Some(UdpRelayMode::Native)); + if same_pkt_src { + return Err(Error::UnexpectedPacketSource); + } + + Ok(task) + }; + + match pre_process.await { + Ok(Task::Authenticate(auth)) => self.handle_authenticate(auth).await, + Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Quic).await, + Ok(Task::Dissociate(assoc_id)) => self.handle_dissociate(assoc_id).await, + Ok(_) => unreachable!(), // already filtered in `tuic_quinn` + Err(err) => { + log::warn!("[{addr}] handle unidirection stream error: {err}"); + self.close(); + } + } + } + + pub(crate) async fn handle_bi_stream( + self, + (send, recv): (SendStream, RecvStream), + _reg: Register, + ) { + let addr = self.inner.remote_address(); + log::debug!("[{addr}] incoming bidirectional stream"); + + let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed); + + if self.remote_bi_stream_cnt.count() as u32 == max { + self.max_concurrent_bi_streams + .store(max * 2, Ordering::Relaxed); + + self.inner + .set_max_concurrent_bi_streams(VarInt::from(max * 2)); + } + + let pre_process = async { + let task = time::timeout( + self.task_negotiation_timeout, + self.model.accept_bi_stream(send, recv), + ) + .await + .map_err(|_| Error::TaskNegotiationTimeout)??; + + tokio::select! { + () = self.auth.clone() => {} + err = self.inner.closed() => return Err(Error::Connection(err)), + }; + + Ok(task) + }; + + match pre_process.await { + Ok(Task::Connect(conn)) => self.handle_connect(conn).await, + Ok(_) => unreachable!(), // already filtered in `tuic_quinn` + Err(err) => { + log::warn!("[{addr}] handle bidirection stream error: {err}"); + self.close(); + } + } + } + + pub(crate) async fn handle_datagram(self, dg: Bytes) { + let addr = self.inner.remote_address(); + log::debug!("[{addr}] incoming datagram"); + + let pre_process = async { + let task = self.model.accept_datagram(dg)?; + + tokio::select! { + () = self.auth.clone() => {} + err = self.inner.closed() => return Err(Error::Connection(err)), + }; + + let same_pkt_src = matches!(task, Task::Packet(_)) + && matches!(self.udp_relay_mode.load(), Some(UdpRelayMode::Quic)); + if same_pkt_src { + return Err(Error::UnexpectedPacketSource); + } + + Ok(task) + }; + + match pre_process.await { + Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Native).await, + Ok(Task::Heartbeat) => self.handle_heartbeat().await, + Ok(_) => unreachable!(), + Err(err) => { + log::warn!("[{addr}] handle datagram error: {err}"); + self.close(); + } + } + } +} diff --git a/tuic-server/src/connection/handle_task.rs b/tuic-server/src/connection/handle_task.rs new file mode 100644 index 0000000..2341c4c --- /dev/null +++ b/tuic-server/src/connection/handle_task.rs @@ -0,0 +1,191 @@ +use super::{UdpSession, ERROR_CODE}; +use crate::{Connection, Error, UdpRelayMode}; +use bytes::Bytes; +use std::{ + collections::hash_map::Entry, + io::{Error as IoError, ErrorKind}, + net::SocketAddr, +}; +use tokio::{ + io::{self, AsyncWriteExt}, + net::{self, TcpStream}, +}; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tuic::Address; +use tuic_quinn::{Authenticate, Connect, Packet}; + +impl Connection { + pub(super) async fn handle_authenticate(&self, auth: Authenticate) { + log::info!( + "[{addr}] [{uuid}] [authenticate] authenticated as {auth_uuid}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + auth_uuid = auth.uuid(), + ); + } + + pub(super) async fn handle_connect(&self, conn: Connect) { + let target_addr = conn.addr().to_string(); + + log::info!( + "[{addr}] [{uuid}] [connect] {target_addr}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); + + let process = async { + let mut stream = None; + let mut last_err = None; + + match resolve_dns(conn.addr()).await { + Ok(addrs) => { + for addr in addrs { + match TcpStream::connect(addr).await { + Ok(s) => { + stream = Some(s); + break; + } + Err(err) => last_err = Some(err), + } + } + } + Err(err) => last_err = Some(err), + } + + if let Some(mut stream) = stream { + let mut conn = conn.compat(); + let res = io::copy_bidirectional(&mut conn, &mut stream).await; + let _ = conn.get_mut().reset(ERROR_CODE); + let _ = stream.shutdown().await; + res?; + Ok::<_, Error>(()) + } else { + let _ = conn.compat().shutdown().await; + Err(last_err + .unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))? + } + }; + + match process.await { + Ok(()) => {} + Err(err) => log::warn!( + "[{addr}] [{uuid}] [connect] relaying connection to {target_addr} error: {err}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ), + } + } + + pub(super) async fn handle_packet(&self, pkt: Packet, mode: UdpRelayMode) { + let assoc_id = pkt.assoc_id(); + let pkt_id = pkt.pkt_id(); + let frag_id = pkt.frag_id(); + let frag_total = pkt.frag_total(); + + log::info!( + "[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] {frag_id}/{frag_total}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); + + self.udp_relay_mode.store(Some(mode)); + + let process = async { + let Some((pkt, addr, assoc_id)) = pkt.accept().await? else { + return Ok(()); + }; + + let session = match self.udp_sessions.lock().entry(assoc_id) { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + let session = UdpSession::new( + self.clone(), + assoc_id, + self.udp_relay_ipv6, + self.max_external_pkt_size, + )?; + entry.insert(session.clone()); + session + } + }; + + let Some(socket_addr) = resolve_dns(&addr).await?.next() else { + return Err(Error::from(IoError::new(ErrorKind::NotFound, "no address resolved"))); + }; + + session.send(pkt, socket_addr).await + }; + + match process.await { + Ok(()) => {} + Err(err) => log::warn!( + "[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] error handling fragment {frag_id}/{frag_total}: {err}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ), + } + } + + pub(super) async fn handle_dissociate(&self, assoc_id: u16) { + log::info!( + "[{addr}] [{uuid}] [dissociate] [{assoc_id:#06x}]", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); + + if let Some(session) = self.udp_sessions.lock().remove(&assoc_id) { + session.close(); + } + } + + pub(super) async fn handle_heartbeat(&self) { + log::info!( + "[{addr}] [{uuid}] [heartbeat]", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); + } + + pub(super) async fn send_packet(self, pkt: Bytes, addr: Address, assoc_id: u16) { + let addr_display = addr.to_string(); + + let res = match self.udp_relay_mode.load() { + Some(UdpRelayMode::Native) => { + log::info!( + "[{addr}] [packet-to-native] [{assoc_id}] [{target_addr}]", + addr = self.inner.remote_address(), + target_addr = addr_display, + ); + self.model.packet_native(pkt, addr, assoc_id) + } + Some(UdpRelayMode::Quic) => { + log::info!( + "[{addr}] [packet-to-quic] [{assoc_id}] [{target_addr}]", + addr = self.inner.remote_address(), + target_addr = addr_display, + ); + self.model.packet_quic(pkt, addr, assoc_id).await + } + None => unreachable!(), + }; + + if let Err(err) = res { + log::warn!( + "[{addr}] [packet-to-native] [{assoc_id}] [{target_addr}] {err}", + addr = self.inner.remote_address(), + target_addr = addr_display, + ); + } + } +} + +async fn resolve_dns(addr: &Address) -> Result, IoError> { + match addr { + Address::None => Err(IoError::new(ErrorKind::InvalidInput, "empty address")), + Address::DomainAddress(domain, port) => Ok(net::lookup_host((domain.as_str(), *port)) + .await? + .collect::>() + .into_iter()), + Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()), + } +} diff --git a/tuic-server/src/connection/mod.rs b/tuic-server/src/connection/mod.rs new file mode 100644 index 0000000..a1123b9 --- /dev/null +++ b/tuic-server/src/connection/mod.rs @@ -0,0 +1,183 @@ +use self::{authenticated::Authenticated, udp_session::UdpSession}; +use crate::{Error, UdpRelayMode}; +use crossbeam_utils::atomic::AtomicCell; +use parking_lot::Mutex; +use quinn::{Connecting, Connection as QuinnConnection, VarInt}; +use register_count::Counter; +use std::{ + collections::HashMap, + sync::{atomic::AtomicU32, Arc}, + time::Duration, +}; +use tokio::time; +use tuic_quinn::{side, Authenticate, Connection as Model}; +use uuid::Uuid; + +mod authenticated; +mod handle_stream; +mod handle_task; +mod udp_session; + +pub(crate) const ERROR_CODE: VarInt = VarInt::from_u32(0); +pub(crate) const DEFAULT_CONCURRENT_STREAMS: u32 = 32; + +#[derive(Clone)] +pub struct Connection { + inner: QuinnConnection, + model: Model, + users: Arc>>, + udp_relay_ipv6: bool, + auth: Authenticated, + task_negotiation_timeout: Duration, + udp_sessions: Arc>>, + udp_relay_mode: Arc>>, + max_external_pkt_size: usize, + remote_uni_stream_cnt: Counter, + remote_bi_stream_cnt: Counter, + max_concurrent_uni_streams: Arc, + max_concurrent_bi_streams: Arc, +} + +#[allow(clippy::too_many_arguments)] +impl Connection { + pub async fn handle( + conn: Connecting, + users: Arc>>, + udp_relay_ipv6: bool, + zero_rtt_handshake: bool, + auth_timeout: Duration, + task_negotiation_timeout: Duration, + max_external_pkt_size: usize, + gc_interval: Duration, + gc_lifetime: Duration, + ) { + let addr = conn.remote_address(); + + let init = async { + let conn = if zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => conn.await?, + } + } else { + conn.await? + }; + + Ok::<_, Error>(Self::new( + conn, + users, + udp_relay_ipv6, + task_negotiation_timeout, + max_external_pkt_size, + )) + }; + + match init.await { + Ok(conn) => { + log::info!("[{addr}] connection established"); + + tokio::spawn(conn.clone().timeout_authenticate(auth_timeout)); + tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); + + loop { + if conn.is_closed() { + break; + } + + let handle_incoming = async { + tokio::select! { + res = conn.inner.accept_uni() => + tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())), + res = conn.inner.accept_bi() => + tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())), + res = conn.inner.read_datagram() => + tokio::spawn(conn.clone().handle_datagram(res?)), + }; + + Ok::<_, Error>(()) + }; + + match handle_incoming.await { + Ok(()) => {} + Err(err) if err.is_locally_closed() => {} + Err(err) if err.is_timeout_closed() => { + log::debug!("[{addr}] connection timeout") + } + Err(err) => log::warn!("[{addr}] connection error: {err}"), + } + } + } + Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(), + Err(err) => log::warn!("[{addr}] connection establishing error: {err}"), + } + } + + fn new( + conn: QuinnConnection, + users: Arc>>, + udp_relay_ipv6: bool, + task_negotiation_timeout: Duration, + max_external_pkt_size: usize, + ) -> Self { + Self { + inner: conn.clone(), + model: Model::::new(conn), + users, + udp_relay_ipv6, + auth: Authenticated::new(), + task_negotiation_timeout, + udp_sessions: Arc::new(Mutex::new(HashMap::new())), + udp_relay_mode: Arc::new(AtomicCell::new(None)), + max_external_pkt_size, + remote_uni_stream_cnt: Counter::new(), + remote_bi_stream_cnt: Counter::new(), + max_concurrent_uni_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)), + max_concurrent_bi_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)), + } + } + + fn authenticate(&self, auth: &Authenticate) -> Result<(), Error> { + if self.auth.get().is_some() { + Err(Error::DuplicatedAuth) + } else if self + .users + .get(&auth.uuid()) + .map_or(false, |password| auth.validate(password)) + { + self.auth.set(auth.uuid()); + Ok(()) + } else { + Err(Error::AuthFailed(auth.uuid())) + } + } + + async fn timeout_authenticate(self, timeout: Duration) { + time::sleep(timeout).await; + + if self.auth.get().is_none() { + let addr = self.inner.remote_address(); + log::warn!("[{addr}] [authenticate] timeout"); + self.close(); + } + } + + async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) { + loop { + time::sleep(gc_interval).await; + + if self.is_closed() { + break; + } + + self.model.collect_garbage(gc_lifetime); + } + } + + fn is_closed(&self) -> bool { + self.inner.close_reason().is_some() + } + + fn close(&self) { + self.inner.close(ERROR_CODE, &[]); + } +} diff --git a/tuic-server/src/connection/udp_session.rs b/tuic-server/src/connection/udp_session.rs new file mode 100644 index 0000000..fb52c6c --- /dev/null +++ b/tuic-server/src/connection/udp_session.rs @@ -0,0 +1,161 @@ +use crate::{Connection, Error}; +use bytes::Bytes; +use parking_lot::Mutex; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use std::{ + io::Error as IoError, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, + sync::Arc, +}; +use tokio::{ + net::UdpSocket, + sync::oneshot::{self, Sender}, +}; +use tuic::Address; + +#[derive(Clone)] +pub(super) struct UdpSession(Arc); + +struct UdpSessionInner { + assoc_id: u16, + conn: Connection, + socket_v4: UdpSocket, + socket_v6: Option, + max_pkt_size: usize, + close: Mutex>>, +} + +impl UdpSession { + pub(super) fn new( + conn: Connection, + assoc_id: u16, + udp_relay_ipv6: bool, + max_pkt_size: usize, + ) -> Result { + let socket_v4 = { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .map_err(|err| Error::Socket("failed to create UDP associate IPv4 socket", err))?; + + socket.set_nonblocking(true).map_err(|err| { + Error::Socket( + "failed setting UDP associate IPv4 socket as non-blocking", + err, + ) + })?; + + socket + .bind(&SockAddr::from(SocketAddr::from(( + Ipv4Addr::UNSPECIFIED, + 0, + )))) + .map_err(|err| Error::Socket("failed to bind UDP associate IPv4 socket", err))?; + + UdpSocket::from_std(StdUdpSocket::from(socket))? + }; + + let socket_v6 = if udp_relay_ipv6 { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) + .map_err(|err| Error::Socket("failed to create UDP associate IPv6 socket", err))?; + + socket.set_nonblocking(true).map_err(|err| { + Error::Socket( + "failed setting UDP associate IPv6 socket as non-blocking", + err, + ) + })?; + + socket.set_only_v6(true).map_err(|err| { + Error::Socket("failed setting UDP associate IPv6 socket as IPv6-only", err) + })?; + + socket + .bind(&SockAddr::from(SocketAddr::from(( + Ipv6Addr::UNSPECIFIED, + 0, + )))) + .map_err(|err| Error::Socket("failed to bind UDP associate IPv6 socket", err))?; + + Some(UdpSocket::from_std(StdUdpSocket::from(socket))?) + } else { + None + }; + + let (tx, rx) = oneshot::channel(); + + let session = Self(Arc::new(UdpSessionInner { + conn, + assoc_id, + socket_v4, + socket_v6, + max_pkt_size, + close: Mutex::new(Some(tx)), + })); + + let session_listening = session.clone(); + let listen = async move { + loop { + let (pkt, addr) = match session_listening.recv().await { + Ok(res) => res, + Err(err) => { + log::warn!("{err}"); // TODO + continue; + } + }; + + tokio::spawn(session_listening.0.conn.clone().send_packet( + pkt, + Address::SocketAddress(addr), + session_listening.0.assoc_id, + )); + } + }; + + tokio::spawn(async move { + tokio::select! { + _ = listen => unreachable!(), + _ = rx => {}, + } + }); + + Ok(session) + } + + pub(super) async fn send(&self, pkt: Bytes, addr: SocketAddr) -> Result<(), Error> { + let socket = match addr { + SocketAddr::V4(_) => &self.0.socket_v4, + SocketAddr::V6(_) => self + .0 + .socket_v6 + .as_ref() + .ok_or_else(|| Error::UdpRelayIpv6Disabled(addr))?, + }; + + socket.send_to(&pkt, addr).await?; + Ok(()) + } + + async fn recv(&self) -> Result<(Bytes, SocketAddr), IoError> { + async fn recv( + socket: &UdpSocket, + max_pkt_size: usize, + ) -> Result<(Bytes, SocketAddr), IoError> { + let mut buf = vec![0u8; max_pkt_size]; + let (n, addr) = socket.recv_from(&mut buf).await?; + buf.truncate(n); + Ok((Bytes::from(buf), addr)) + } + + if let Some(socket_v6) = &self.0.socket_v6 { + tokio::select! { + res = recv(&self.0.socket_v4, self.0.max_pkt_size) => res, + res = recv(socket_v6, self.0.max_pkt_size) => res, + } + } else { + recv(&self.0.socket_v4, self.0.max_pkt_size).await + } + } + + pub(super) fn close(&self) { + let _ = self.0.close.lock().take().unwrap().send(()); + } +} diff --git a/tuic-server/src/error.rs b/tuic-server/src/error.rs new file mode 100644 index 0000000..663d8ef --- /dev/null +++ b/tuic-server/src/error.rs @@ -0,0 +1,44 @@ +use quinn::ConnectionError; +use rustls::Error as RustlsError; +use std::{io::Error as IoError, net::SocketAddr}; +use thiserror::Error; +use tuic_quinn::Error as ModelError; +use uuid::Uuid; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Rustls(#[from] RustlsError), + #[error("invalid max idle time")] + InvalidMaxIdleTime, + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + Model(#[from] ModelError), + #[error("duplicated authentication")] + DuplicatedAuth, + #[error("token length too short")] + ExportKeyingMaterial, + #[error("authentication failed: {0}")] + AuthFailed(Uuid), + #[error("received packet from unexpected source")] + UnexpectedPacketSource, + #[error("{0}: {1}")] + Socket(&'static str, IoError), + #[error("task negotiation timed out")] + TaskNegotiationTimeout, + #[error("failed sending packet to {0}: relaying IPv6 UDP packet is disabled")] + UdpRelayIpv6Disabled(SocketAddr), +} + +impl Error { + pub fn is_locally_closed(&self) -> bool { + matches!(self, Self::Connection(ConnectionError::LocallyClosed)) + } + + pub fn is_timeout_closed(&self) -> bool { + matches!(self, Self::Connection(ConnectionError::TimedOut)) + } +} diff --git a/tuic-server/src/lib.rs b/tuic-server/src/lib.rs new file mode 100644 index 0000000..da40121 --- /dev/null +++ b/tuic-server/src/lib.rs @@ -0,0 +1,14 @@ +pub(crate) mod config; +pub(crate) mod error; +pub(crate) mod server; +pub(crate) mod utils; + +pub mod connection; + +pub use crate::{ + config::{Config, ConfigError}, + connection::Connection, + error::Error, + server::Server, + utils::{CongestionControl, UdpRelayMode}, +}; diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index b4c4a52..346855c 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -1,19 +1,6 @@ -use self::{ - config::{Config, ConfigError}, - server::Server, -}; use env_logger::Builder as LoggerBuilder; -use quinn::ConnectionError; -use rustls::Error as RustlsError; -use std::{env, io::Error as IoError, net::SocketAddr, process}; -use thiserror::Error; -use tuic::Address; -use tuic_quinn::Error as ModelError; -use uuid::Uuid; - -mod config; -mod server; -mod utils; +use std::{env, process}; +use tuic_server::{Config, ConfigError, Server}; #[tokio::main] async fn main() { @@ -43,31 +30,3 @@ async fn main() { } } } - -#[derive(Debug, Error)] -pub enum Error { - #[error(transparent)] - Io(#[from] IoError), - #[error(transparent)] - Rustls(#[from] RustlsError), - #[error("invalid max idle time")] - InvalidMaxIdleTime, - #[error(transparent)] - Connection(#[from] ConnectionError), - #[error(transparent)] - Model(#[from] ModelError), - #[error("duplicated authentication")] - DuplicatedAuth, - #[error("token length too short")] - ExportKeyingMaterial, - #[error("authentication failed: {0}")] - AuthFailed(Uuid), - #[error("received packet from unexpected source")] - UnexpectedPacketSource, - #[error("{0}: {1}")] - Socket(&'static str, IoError), - #[error("task negotiation timed out")] - TaskNegotiationTimeout, - #[error("{0} resolved to {1} but IPv6 UDP relaying is disabled")] - UdpRelayIpv6Disabled(Address, SocketAddr), -} diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 33bec66..e9fee7a 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -1,49 +1,21 @@ use crate::{ - config::Config, - utils::{self, CongestionControl, UdpRelayMode}, + config::Config, connection::DEFAULT_CONCURRENT_STREAMS, utils, CongestionControl, Connection, Error, }; -use bytes::Bytes; -use crossbeam_utils::atomic::AtomicCell; -use parking_lot::Mutex; use quinn::{ congestion::{BbrConfig, CubicConfig, NewRenoConfig}, - Connecting, Connection as QuinnConnection, ConnectionError, Endpoint, EndpointConfig, - IdleTimeout, RecvStream, SendStream, ServerConfig, TokioRuntime, TransportConfig, VarInt, + Endpoint, EndpointConfig, IdleTimeout, ServerConfig, TokioRuntime, TransportConfig, VarInt, }; -use register_count::{Counter, Register}; use rustls::{version, ServerConfig as RustlsServerConfig}; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use std::{ - collections::{hash_map::Entry, HashMap}, - future::Future, - io::{Error as IoError, ErrorKind}, - net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, - pin::Pin, - sync::{ - atomic::{AtomicU32, Ordering}, - Arc, - }, - task::{Context, Poll, Waker}, + collections::HashMap, + net::{SocketAddr, UdpSocket as StdUdpSocket}, + sync::Arc, time::Duration, }; -use tokio::{ - io::{self, AsyncWriteExt}, - net::{self, TcpStream, UdpSocket}, - sync::{ - oneshot::{self, Receiver, Sender}, - Mutex as AsyncMutex, - }, - time, -}; -use tokio_util::compat::FuturesAsyncReadCompatExt; -use tuic::Address; -use tuic_quinn::{side, Authenticate, Connect, Connection as Model, Packet, Task}; use uuid::Uuid; -const ERROR_CODE: VarInt = VarInt::from_u32(0); -const DEFAULT_CONCURRENT_STREAMS: u32 = 32; - pub struct Server { ep: Endpoint, users: Arc>>, @@ -172,638 +144,3 @@ impl Server { } } } - -#[derive(Clone)] -struct Connection { - inner: QuinnConnection, - model: Model, - users: Arc>>, - udp_relay_ipv6: bool, - auth: Authenticated, - task_negotiation_timeout: Duration, - udp_sessions: Arc>>, - udp_relay_mode: Arc>>, - max_external_pkt_size: usize, - remote_uni_stream_cnt: Counter, - remote_bi_stream_cnt: Counter, - max_concurrent_uni_streams: Arc, - max_concurrent_bi_streams: Arc, -} - -#[allow(clippy::too_many_arguments)] -impl Connection { - async fn handle( - conn: Connecting, - users: Arc>>, - udp_relay_ipv6: bool, - zero_rtt_handshake: bool, - auth_timeout: Duration, - task_negotiation_timeout: Duration, - max_external_pkt_size: usize, - gc_interval: Duration, - gc_lifetime: Duration, - ) { - let addr = conn.remote_address(); - - let init = async { - let conn = if zero_rtt_handshake { - match conn.into_0rtt() { - Ok((conn, _)) => conn, - Err(conn) => conn.await?, - } - } else { - conn.await? - }; - - Ok::<_, Error>(Self::new( - conn, - users, - udp_relay_ipv6, - task_negotiation_timeout, - max_external_pkt_size, - )) - }; - - match init.await { - Ok(conn) => { - log::info!("[{addr}] connection established"); - - tokio::spawn(conn.clone().timeout_authenticate(auth_timeout)); - tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); - - loop { - if conn.is_closed() { - break; - } - - let handle_incoming = async { - tokio::select! { - res = conn.inner.accept_uni() => - tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())), - res = conn.inner.accept_bi() => - tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())), - res = conn.inner.read_datagram() => - tokio::spawn(conn.clone().handle_datagram(res?)), - }; - - Ok::<_, Error>(()) - }; - - match handle_incoming.await { - Ok(()) => {} - Err(err) if err.is_locally_closed() => {} - Err(err) if err.is_timeout_closed() => { - log::debug!("[{addr}] connection timeout") - } - Err(err) => log::warn!("[{addr}] connection error: {err}"), - } - } - } - Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(), - Err(err) => log::warn!("[{addr}] connection establishing error: {err}"), - } - } - - fn new( - conn: QuinnConnection, - users: Arc>>, - udp_relay_ipv6: bool, - task_negotiation_timeout: Duration, - max_external_pkt_size: usize, - ) -> Self { - Self { - inner: conn.clone(), - model: Model::::new(conn), - users, - udp_relay_ipv6, - auth: Authenticated::new(), - task_negotiation_timeout, - udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), - udp_relay_mode: Arc::new(AtomicCell::new(None)), - max_external_pkt_size, - remote_uni_stream_cnt: Counter::new(), - remote_bi_stream_cnt: Counter::new(), - max_concurrent_uni_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)), - max_concurrent_bi_streams: Arc::new(AtomicU32::new(DEFAULT_CONCURRENT_STREAMS)), - } - } - - fn authenticate(&self, auth: &Authenticate) -> Result<(), Error> { - if self.auth.get().is_some() { - Err(Error::DuplicatedAuth) - } else if self - .users - .get(&auth.uuid()) - .map_or(false, |password| auth.validate(password)) - { - self.auth.set(auth.uuid()); - Ok(()) - } else { - Err(Error::AuthFailed(auth.uuid())) - } - } - - async fn timeout_authenticate(self, timeout: Duration) { - time::sleep(timeout).await; - - if self.auth.get().is_none() { - let addr = self.inner.remote_address(); - log::warn!("[{addr}] [authenticate] timeout"); - self.close(); - } - } - - async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { - let addr = self.inner.remote_address(); - log::debug!("[{addr}] incoming unidirectional stream"); - - let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed); - - if self.remote_uni_stream_cnt.count() as u32 == max { - self.max_concurrent_uni_streams - .store(max * 2, Ordering::Relaxed); - - self.inner - .set_max_concurrent_uni_streams(VarInt::from(max * 2)); - } - - let pre_process = async { - let task = time::timeout( - self.task_negotiation_timeout, - self.model.accept_uni_stream(recv), - ) - .await - .map_err(|_| Error::TaskNegotiationTimeout)??; - - if let Task::Authenticate(auth) = &task { - self.authenticate(auth)?; - } - - tokio::select! { - () = self.auth.clone() => {} - err = self.inner.closed() => return Err(Error::Connection(err)), - }; - - let same_pkt_src = matches!(task, Task::Packet(_)) - && matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Native)); - if same_pkt_src { - return Err(Error::UnexpectedPacketSource); - } - - Ok(task) - }; - - match pre_process.await { - Ok(Task::Authenticate(auth)) => self.handle_authenticate(auth).await, - Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Quic).await, - Ok(Task::Dissociate(assoc_id)) => self.handle_dissociate(assoc_id).await, - Ok(_) => unreachable!(), // already filtered in `tuic_quinn` - Err(err) => { - log::warn!("[{addr}] handle unidirection stream error: {err}"); - self.close(); - } - } - } - - async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream), _reg: Register) { - let addr = self.inner.remote_address(); - log::debug!("[{addr}] incoming bidirectional stream"); - - let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed); - - if self.remote_bi_stream_cnt.count() as u32 == max { - self.max_concurrent_bi_streams - .store(max * 2, Ordering::Relaxed); - - self.inner - .set_max_concurrent_bi_streams(VarInt::from(max * 2)); - } - - let pre_process = async { - let task = time::timeout( - self.task_negotiation_timeout, - self.model.accept_bi_stream(send, recv), - ) - .await - .map_err(|_| Error::TaskNegotiationTimeout)??; - - tokio::select! { - () = self.auth.clone() => {} - err = self.inner.closed() => return Err(Error::Connection(err)), - }; - - Ok(task) - }; - - match pre_process.await { - Ok(Task::Connect(conn)) => self.handle_connect(conn).await, - Ok(_) => unreachable!(), // already filtered in `tuic_quinn` - Err(err) => { - log::warn!("[{addr}] handle bidirection stream error: {err}"); - self.close(); - } - } - } - - async fn handle_datagram(self, dg: Bytes) { - let addr = self.inner.remote_address(); - log::debug!("[{addr}] incoming datagram"); - - let pre_process = async { - let task = self.model.accept_datagram(dg)?; - - tokio::select! { - () = self.auth.clone() => {} - err = self.inner.closed() => return Err(Error::Connection(err)), - }; - - let same_pkt_src = matches!(task, Task::Packet(_)) - && matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Quic)); - if same_pkt_src { - return Err(Error::UnexpectedPacketSource); - } - - Ok(task) - }; - - match pre_process.await { - Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Native).await, - Ok(Task::Heartbeat) => self.handle_heartbeat().await, - Ok(_) => unreachable!(), - Err(err) => { - log::warn!("[{addr}] handle datagram error: {err}"); - self.close(); - } - } - } - - async fn handle_authenticate(&self, auth: Authenticate) { - log::info!( - "[{addr}] [{uuid}] [authenticate] authenticated as {auth_uuid}", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - auth_uuid = auth.uuid(), - ); - } - - async fn handle_connect(&self, conn: Connect) { - let target_addr = conn.addr().to_string(); - - log::info!( - "[{addr}] [{uuid}] [connect] {target_addr}", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - ); - - let process = async { - let mut stream = None; - let mut last_err = None; - - match resolve_dns(conn.addr()).await { - Ok(addrs) => { - for addr in addrs { - match TcpStream::connect(addr).await { - Ok(s) => { - stream = Some(s); - break; - } - Err(err) => last_err = Some(err), - } - } - } - Err(err) => last_err = Some(err), - } - - if let Some(mut stream) = stream { - let mut conn = conn.compat(); - let res = io::copy_bidirectional(&mut conn, &mut stream).await; - let _ = conn.get_mut().reset(ERROR_CODE); - let _ = stream.shutdown().await; - res?; - Ok::<_, Error>(()) - } else { - let _ = conn.compat().shutdown().await; - Err(last_err - .unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))? - } - }; - - match process.await { - Ok(()) => {} - Err(err) => log::warn!( - "[{addr}] [{uuid}] [connect] relaying connection to {target_addr} error: {err}", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - ), - } - } - - async fn handle_packet(&self, pkt: Packet, mode: UdpRelayMode) { - let assoc_id = pkt.assoc_id(); - let pkt_id = pkt.pkt_id(); - let frag_id = pkt.frag_id(); - let frag_total = pkt.frag_total(); - - log::info!( - "[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] {frag_id}/{frag_total}", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - ); - - self.set_udp_relay_mode(mode); - - let process = async { - let Some((pkt, addr, assoc_id)) = pkt.accept().await? else { - return Ok(()); - }; - - let (socket_v4, socket_v6) = match self.udp_sessions.lock().await.entry(assoc_id) { - Entry::Occupied(mut entry) => { - let session = entry.get_mut(); - (session.socket_v4.clone(), session.socket_v6.clone()) - } - Entry::Vacant(entry) => { - let session = entry.insert( - UdpSession::new(assoc_id, self.clone(), self.udp_relay_ipv6).await?, - ); - - (session.socket_v4.clone(), session.socket_v6.clone()) - } - }; - - let Some(socket_addr) = resolve_dns(&addr).await?.next() else { - return Err(Error::from(IoError::new(ErrorKind::NotFound, "no address resolved"))); - }; - - let socket = match socket_addr { - SocketAddr::V4(_) => socket_v4, - SocketAddr::V6(_) => { - socket_v6.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr, socket_addr))? - } - }; - - socket.send_to(&pkt, socket_addr).await?; - - Ok(()) - }; - - match process.await { - Ok(()) => {} - Err(err) => log::warn!( - "[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] error handling fragment {frag_id}/{frag_total}: {err}", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - ), - } - } - - async fn handle_dissociate(&self, assoc_id: u16) { - log::info!( - "[{addr}] [{uuid}] [dissociate] [{assoc_id:#06x}]", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - ); - - self.udp_sessions.lock().await.remove(&assoc_id); - } - - async fn handle_heartbeat(&self) { - log::info!( - "[{addr}] [{uuid}] [heartbeat]", - addr = self.inner.remote_address(), - uuid = self.auth.get().unwrap(), - ); - } - - async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) { - loop { - time::sleep(gc_interval).await; - - if self.is_closed() { - break; - } - - self.model.collect_garbage(gc_lifetime); - } - } - - fn set_udp_relay_mode(&self, mode: UdpRelayMode) { - self.udp_relay_mode.store(Some(mode)); - } - - fn get_udp_relay_mode(&self) -> Option { - self.udp_relay_mode.load() - } - - fn is_closed(&self) -> bool { - self.inner.close_reason().is_some() - } - - fn close(&self) { - self.inner.close(ERROR_CODE, &[]); - } -} - -async fn resolve_dns(addr: &Address) -> Result, IoError> { - match addr { - Address::None => Err(IoError::new(ErrorKind::InvalidInput, "empty address")), - Address::DomainAddress(domain, port) => Ok(net::lookup_host((domain.as_str(), *port)) - .await? - .collect::>() - .into_iter()), - Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()), - } -} - -struct UdpSession { - socket_v4: Arc, - socket_v6: Option>, - cancel: Option>, -} - -impl UdpSession { - async fn new(assoc_id: u16, conn: Connection, udp_relay_ipv6: bool) -> Result { - let socket_v4 = Arc::new( - UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))) - .await - .map_err(|err| Error::Socket("failed to create UDP associate IPv4 socket", err))?, - ); - let socket_v6 = if udp_relay_ipv6 { - let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) - .map_err(|err| Error::Socket("failed to create UDP associate IPv6 socket", err))?; - - socket.set_nonblocking(true).map_err(|err| { - Error::Socket( - "failed setting UDP associate IPv6 socket as non-blocking", - err, - ) - })?; - - socket.set_only_v6(true).map_err(|err| { - Error::Socket("failed setting UDP associate IPv6 socket as IPv6-only", err) - })?; - - socket - .bind(&SockAddr::from(SocketAddr::from(( - Ipv6Addr::UNSPECIFIED, - 0, - )))) - .map_err(|err| Error::Socket("failed to bind UDP associate IPv6 socket", err))?; - - Some(Arc::new(UdpSocket::from_std(StdUdpSocket::from(socket))?)) - } else { - None - }; - - let (tx, rx) = oneshot::channel(); - - tokio::spawn(Self::listen_incoming( - assoc_id, - conn, - socket_v4.clone(), - socket_v6.clone(), - rx, - )); - - Ok(Self { - socket_v4, - socket_v6, - cancel: Some(tx), - }) - } - - async fn listen_incoming( - assoc_id: u16, - conn: Connection, - socket_v4: Arc, - socket_v6: Option>, - cancel: Receiver<()>, - ) { - async fn send_pkt(conn: Connection, pkt: Bytes, target_addr: SocketAddr, assoc_id: u16) { - let addr = conn.inner.remote_address(); - let target_addr_tuic = Address::SocketAddress(target_addr); - - let res = match conn.get_udp_relay_mode() { - Some(UdpRelayMode::Native) => { - log::info!("[{addr}] [packet-to-native] [{assoc_id}] [{target_addr_tuic}]"); - conn.model.packet_native(pkt, target_addr_tuic, assoc_id) - } - Some(UdpRelayMode::Quic) => { - log::info!("[{addr}] [packet-to-quic] [{assoc_id}] [{target_addr_tuic}]"); - conn.model - .packet_quic(pkt, target_addr_tuic, assoc_id) - .await - } - None => unreachable!(), - }; - - if let Err(err) = res { - let target_addr_tuic = Address::SocketAddress(target_addr); - log::warn!("[{addr}] [packet-to-quic] [{assoc_id}] [{target_addr_tuic}] {err}"); - } - } - - let addr = conn.inner.remote_address(); - - tokio::select! { - _ = cancel => {} - () = async { - loop { - match Self::accept( - &socket_v4, - socket_v6.as_deref(), - conn.max_external_pkt_size, - ).await { - Ok((pkt, target_addr)) => { - tokio::spawn(send_pkt(conn.clone(), pkt, target_addr, assoc_id)); - } - Err(err) => log::warn!("[{addr}] [packet-to-*] [{assoc_id}] {err}"), - } - } - } => unreachable!(), - } - } - - async fn accept( - socket_v4: &UdpSocket, - socket_v6: Option<&UdpSocket>, - max_pkt_size: usize, - ) -> Result<(Bytes, SocketAddr), IoError> { - async fn read_pkt( - socket: &UdpSocket, - max_pkt_size: usize, - ) -> Result<(Bytes, SocketAddr), IoError> { - let mut buf = vec![0u8; max_pkt_size]; - let (n, addr) = socket.recv_from(&mut buf).await?; - buf.truncate(n); - Ok((Bytes::from(buf), addr)) - } - - if let Some(socket_v6) = socket_v6 { - tokio::select! { - res = read_pkt(socket_v4, max_pkt_size) => res, - res = read_pkt(socket_v6, max_pkt_size) => res, - } - } else { - read_pkt(socket_v4, max_pkt_size).await - } - } -} - -impl Drop for UdpSession { - fn drop(&mut self) { - let _ = self.cancel.take().unwrap().send(()); - } -} - -#[derive(Clone)] -struct Authenticated(Arc); - -struct AuthenticatedInner { - uuid: AtomicCell>, - broadcast: Mutex>, -} - -impl Authenticated { - fn new() -> Self { - Self(Arc::new(AuthenticatedInner { - uuid: AtomicCell::new(None), - broadcast: Mutex::new(Vec::new()), - })) - } - - fn set(&self, uuid: Uuid) { - self.0.uuid.store(Some(uuid)); - - for waker in self.0.broadcast.lock().drain(..) { - waker.wake(); - } - } - - fn get(&self) -> Option { - self.0.uuid.load() - } -} - -impl Future for Authenticated { - type Output = (); - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.get().is_some() { - Poll::Ready(()) - } else { - self.0.broadcast.lock().push(cx.waker().clone()); - Poll::Pending - } - } -} - -impl Error { - fn is_locally_closed(&self) -> bool { - matches!(self, Self::Connection(ConnectionError::LocallyClosed)) - } - - fn is_timeout_closed(&self) -> bool { - matches!(self, Self::Connection(ConnectionError::TimedOut)) - } -} diff --git a/tuic-server/src/utils.rs b/tuic-server/src/utils.rs index 2574217..f32924f 100644 --- a/tuic-server/src/utils.rs +++ b/tuic-server/src/utils.rs @@ -8,7 +8,7 @@ use std::{ str::FromStr, }; -pub fn load_certs(path: PathBuf) -> Result, IoError> { +pub(crate) fn load_certs(path: PathBuf) -> Result, IoError> { let mut file = BufReader::new(File::open(&path)?); let mut certs = Vec::new(); @@ -25,7 +25,7 @@ pub fn load_certs(path: PathBuf) -> Result, IoError> { Ok(certs) } -pub fn load_priv_key(path: PathBuf) -> Result { +pub(crate) fn load_priv_key(path: PathBuf) -> Result { let mut file = BufReader::new(File::open(&path)?); let mut priv_key = None;