From 6c61c70c06de0c3846be251ce5c0565ab19bdb90 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Mon, 29 May 2023 20:46:12 +0900 Subject: [PATCH] better error handling mechanism on server --- tuic-server/src/server.rs | 297 ++++++++++++++++++++------------------ tuic-server/src/utils.rs | 10 ++ 2 files changed, 170 insertions(+), 137 deletions(-) diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 497f7af..046e727 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -38,9 +38,10 @@ use tokio::{ }; use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address; -use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; +use tuic_quinn::{side, Authenticate, Connect, Connection as Model, Packet, Task}; use uuid::Uuid; +const ERROR_CODE: VarInt = VarInt::from_u32(0); const DEFAULT_CONCURRENT_STREAMS: usize = 32; pub struct Server { @@ -227,7 +228,7 @@ impl Connection { Ok(conn) => { log::info!("[{addr}] connection established"); - tokio::spawn(conn.clone().handle_auth_timeout(auth_timeout)); + tokio::spawn(conn.clone().timeout_authenticate(auth_timeout)); tokio::spawn(conn.clone().collect_garbage(gc_interval, gc_lifetime)); loop { @@ -287,6 +288,31 @@ impl Connection { } } + fn authenticate(&self, auth: &Authenticate) -> Result<(), Error> { + if self.auth.get().is_some() { + Err(Error::DuplicatedAuth) + } else if self + .users + .get(&auth.uuid()) + .map_or(false, |password| auth.validate(password)) + { + self.auth.set(auth.uuid()); + Ok(()) + } else { + Err(Error::AuthFailed(auth.uuid())) + } + } + + async fn timeout_authenticate(self, timeout: Duration) { + time::sleep(timeout).await; + + if self.auth.get().is_none() { + let addr = self.inner.remote_address(); + log::warn!("[{addr}] [authenticate] timeout"); + self.close(); + } + } + async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { let addr = self.inner.remote_address(); log::debug!("[{addr}] incoming unidirectional stream"); @@ -310,17 +336,7 @@ impl Connection { .map_err(|_| Error::TaskNegotiationTimeout)??; if let Task::Authenticate(auth) = &task { - if self.auth.get().is_some() { - return Err(Error::DuplicatedAuth); - } else if self - .users - .get(&auth.uuid()) - .map_or(false, |password| auth.validate(password)) - { - self.auth.set(auth.uuid()); - } else { - return Err(Error::AuthFailed(auth.uuid())); - } + self.authenticate(auth)?; } tokio::select! { @@ -338,35 +354,10 @@ impl Connection { }; match pre_process.await { - Ok(Task::Authenticate(auth)) => { - log::info!("[{addr}] [{uuid}] [authenticate]", uuid = auth.uuid()) - } - Ok(Task::Packet(pkt)) => { - let assoc_id = pkt.assoc_id(); - let pkt_id = pkt.pkt_id(); - let frag_id = pkt.frag_id(); - let frag_total = pkt.frag_total(); - log::info!( - "[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}/{frag_total}]" - ); - - self.set_udp_relay_mode(UdpRelayMode::Quic); - match self.handle_packet(pkt).await { - Ok(()) => {} - Err(err) => log::warn!( - "[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}] {err}" - ), - } - } - Ok(Task::Dissociate(assoc_id)) => { - log::info!("[{addr}] [dissociate] [{assoc_id}]"); - - match self.handle_dissociate(assoc_id).await { - Ok(()) => {} - Err(err) => log::warn!("[{addr}] [dissociate] [{assoc_id}] {err}"), - } - } - Ok(_) => unreachable!(), + Ok(Task::Authenticate(auth)) => self.handle_authenticate(auth).await, + Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Quic).await, + Ok(Task::Dissociate(assoc_id)) => self.handle_dissociate(assoc_id).await, + Ok(_) => unreachable!(), // already filtered in `tuic_quinn` Err(err) => { log::warn!("[{addr}] handle unidirection stream error: {err}"); self.close(); @@ -405,16 +396,8 @@ impl Connection { }; match pre_process.await { - Ok(Task::Connect(conn)) => { - let target_addr = conn.addr().to_string(); - log::info!("[{addr}] [connect] [{target_addr}]"); - - match self.handle_connect(conn).await { - Ok(()) => {} - Err(err) => log::warn!("[{addr}] [connect] [{target_addr}] {err}"), - } - } - Ok(_) => unreachable!(), + Ok(Task::Connect(conn)) => self.handle_connect(conn).await, + Ok(_) => unreachable!(), // already filtered in `tuic_quinn` Err(err) => { log::warn!("[{addr}] handle bidirection stream error: {err}"); self.close(); @@ -444,24 +427,8 @@ impl Connection { }; match pre_process.await { - Ok(Task::Packet(pkt)) => { - let assoc_id = pkt.assoc_id(); - let pkt_id = pkt.pkt_id(); - let frag_id = pkt.frag_id(); - let frag_total = pkt.frag_total(); - log::info!( - "[{addr}] [packet-from-native] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}]" - ); - - self.set_udp_relay_mode(UdpRelayMode::Native); - match self.handle_packet(pkt).await { - Ok(()) => {} - Err(err) => log::warn!( - "[{addr}] [packet-from-native] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}] {err}" - ), - } - } - Ok(Task::Heartbeat) => log::info!("[{addr}] [heartbeat]"), + Ok(Task::Packet(pkt)) => self.handle_packet(pkt, UdpRelayMode::Native).await, + Ok(Task::Heartbeat) => self.handle_heartbeat().await, Ok(_) => unreachable!(), Err(err) => { log::warn!("[{addr}] handle datagram error: {err}"); @@ -470,86 +437,142 @@ impl Connection { } } - async fn handle_connect(&self, conn: Connect) -> Result<(), Error> { - let mut stream = None; - let mut last_err = None; + async fn handle_authenticate(&self, auth: Authenticate) { + log::info!( + "[{addr}] [{uuid}] [authenticate] authenticated as {auth_uuid}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + auth_uuid = auth.uuid(), + ); + } - match resolve_dns(conn.addr()).await { - Ok(addrs) => { - for addr in addrs { - match TcpStream::connect(addr).await { - Ok(s) => { - stream = Some(s); - break; + async fn handle_connect(&self, conn: Connect) { + let target_addr = conn.addr().to_string(); + + log::info!( + "[{addr}] [{uuid}] [connect] {target_addr}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); + + let process = async { + let mut stream = None; + let mut last_err = None; + + match resolve_dns(conn.addr()).await { + Ok(addrs) => { + for addr in addrs { + match TcpStream::connect(addr).await { + Ok(s) => { + stream = Some(s); + break; + } + Err(err) => last_err = Some(err), } - Err(err) => last_err = Some(err), } } + Err(err) => last_err = Some(err), } - Err(err) => last_err = Some(err), - } - if let Some(mut stream) = stream { - let mut conn = conn.compat(); - let res = io::copy_bidirectional(&mut conn, &mut stream).await; - let _ = conn.get_mut().reset(VarInt::from_u32(0)); - let _ = stream.shutdown().await; - res?; + if let Some(mut stream) = stream { + let mut conn = conn.compat(); + let res = io::copy_bidirectional(&mut conn, &mut stream).await; + let _ = conn.get_mut().reset(ERROR_CODE); + let _ = stream.shutdown().await; + res?; + Ok::<_, Error>(()) + } else { + let _ = conn.compat().shutdown().await; + Err(last_err + .unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))? + } + }; + + match process.await { + Ok(()) => {} + Err(err) => log::warn!( + "[{addr}] [{uuid}] [connect] relaying connection to {target_addr} error: {err}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ), + } + } + + async fn handle_packet(&self, pkt: Packet, mode: UdpRelayMode) { + let assoc_id = pkt.assoc_id(); + let pkt_id = pkt.pkt_id(); + let frag_id = pkt.frag_id(); + let frag_total = pkt.frag_total(); + + log::info!( + "[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] {frag_id}/{frag_total}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); + + self.set_udp_relay_mode(mode); + + let process = async { + let Some((pkt, addr, assoc_id)) = pkt.accept().await? else { + return Ok(()); + }; + + let (socket_v4, socket_v6) = match self.udp_sessions.lock().await.entry(assoc_id) { + Entry::Occupied(mut entry) => { + let session = entry.get_mut(); + (session.socket_v4.clone(), session.socket_v6.clone()) + } + Entry::Vacant(entry) => { + let session = entry.insert( + UdpSession::new(assoc_id, self.clone(), self.udp_relay_ipv6).await?, + ); + + (session.socket_v4.clone(), session.socket_v6.clone()) + } + }; + + let Some(socket_addr) = resolve_dns(&addr).await?.next() else { + return Err(Error::from(IoError::new(ErrorKind::NotFound, "no address resolved"))); + }; + + let socket = match socket_addr { + SocketAddr::V4(_) => socket_v4, + SocketAddr::V6(_) => { + socket_v6.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr, socket_addr))? + } + }; + + socket.send_to(&pkt, socket_addr).await?; + Ok(()) - } else { - let _ = conn.compat().shutdown().await; - Err(last_err - .unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))? + }; + + match process.await { + Ok(()) => {} + Err(err) => log::warn!( + "[{addr}] [{uuid}] [packet] [{assoc_id:#06x}] [from-{mode}] [{pkt_id:#06x}] error handling fragment {frag_id}/{frag_total}: {err}", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ), } } - async fn handle_packet(&self, pkt: Packet) -> Result<(), Error> { - let Some((pkt, addr, assoc_id)) = pkt.accept().await? else { - return Ok(()); - }; + async fn handle_dissociate(&self, assoc_id: u16) { + log::info!( + "[{addr}] [{uuid}] [dissociate] [{assoc_id:#06x}]", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); - let (socket_v4, socket_v6) = match self.udp_sessions.lock().await.entry(assoc_id) { - Entry::Occupied(mut entry) => { - let session = entry.get_mut(); - (session.socket_v4.clone(), session.socket_v6.clone()) - } - Entry::Vacant(entry) => { - let session = entry - .insert(UdpSession::new(assoc_id, self.clone(), self.udp_relay_ipv6).await?); - - (session.socket_v4.clone(), session.socket_v6.clone()) - } - }; - - let Some(socket_addr) = resolve_dns(&addr).await?.next() else { - return Err(Error::from(IoError::new(ErrorKind::NotFound, "no address resolved"))); - }; - - let socket = match socket_addr { - SocketAddr::V4(_) => socket_v4, - SocketAddr::V6(_) => { - socket_v6.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr, socket_addr))? - } - }; - - socket.send_to(&pkt, socket_addr).await?; - - Ok(()) - } - - async fn handle_dissociate(&self, assoc_id: u16) -> Result<(), Error> { self.udp_sessions.lock().await.remove(&assoc_id); - Ok(()) } - async fn handle_auth_timeout(self, timeout: Duration) { - time::sleep(timeout).await; - - if self.auth.get().is_none() { - let addr = self.inner.remote_address(); - log::warn!("[{addr}] authentication timeout"); - self.close(); - } + async fn handle_heartbeat(&self) { + log::info!( + "[{addr}] [{uuid}] [heartbeat]", + addr = self.inner.remote_address(), + uuid = self.auth.get().unwrap(), + ); } async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) { @@ -577,7 +600,7 @@ impl Connection { } fn close(&self) { - self.inner.close(VarInt::from_u32(0), b""); + self.inner.close(ERROR_CODE, &[]); } } diff --git a/tuic-server/src/utils.rs b/tuic-server/src/utils.rs index 31061c7..2574217 100644 --- a/tuic-server/src/utils.rs +++ b/tuic-server/src/utils.rs @@ -1,6 +1,7 @@ use rustls::{Certificate, PrivateKey}; use rustls_pemfile::Item; use std::{ + fmt::{Display, Formatter, Result as FmtResult}, fs::{self, File}, io::{BufReader, Error as IoError}, path::PathBuf, @@ -46,6 +47,15 @@ pub enum UdpRelayMode { Quic, } +impl Display for UdpRelayMode { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::Native => write!(f, "native"), + Self::Quic => write!(f, "quic"), + } + } +} + pub enum CongestionControl { Cubic, NewReno,