diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index a30a890..497f7af 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -21,7 +21,7 @@ use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, pin::Pin, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, + atomic::{AtomicUsize, Ordering}, Arc, }, task::{Context, Poll, Waker}, @@ -178,7 +178,7 @@ struct Connection { model: Model, users: Arc>>, udp_relay_ipv6: bool, - is_authed: IsAuthed, + auth: Authenticated, task_negotiation_timeout: Duration, udp_sessions: Arc>>, udp_relay_mode: Arc>>, @@ -235,19 +235,31 @@ impl Connection { break; } - match conn.accept().await { + let handle_incoming = async { + tokio::select! { + res = conn.inner.accept_uni() => + tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())), + res = conn.inner.accept_bi() => + tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())), + res = conn.inner.read_datagram() => + tokio::spawn(conn.clone().handle_datagram(res?)), + }; + + Ok::<_, Error>(()) + }; + + match handle_incoming.await { Ok(()) => {} Err(err) if err.is_locally_closed() => {} Err(err) if err.is_timeout_closed() => { log::debug!("[{addr}] connection timeout") } - Err(err) => log::warn!("[{addr}] {err}"), + Err(err) => log::warn!("[{addr}] connection error: {err}"), } } } - Err(err) if err.is_locally_closed() => unreachable!(), - Err(err) if err.is_timeout_closed() => log::debug!("[{addr}] connection timeout"), - Err(err) => log::warn!("[{addr}] {err}"), + Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(), + Err(err) => log::warn!("[{addr}] connection establishing error: {err}"), } } @@ -263,7 +275,7 @@ impl Connection { model: Model::::new(conn), users, udp_relay_ipv6, - is_authed: IsAuthed::new(), + auth: Authenticated::new(), task_negotiation_timeout, udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), @@ -275,19 +287,6 @@ impl Connection { } } - async fn accept(&self) -> Result<(), Error> { - tokio::select! { - res = self.inner.accept_uni() => - tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())), - res = self.inner.accept_bi() => - tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())), - res = self.inner.read_datagram() => - tokio::spawn(self.clone().handle_datagram(res?)), - }; - - Ok(()) - } - async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { let addr = self.inner.remote_address(); log::debug!("[{addr}] incoming unidirectional stream"); @@ -311,21 +310,21 @@ impl Connection { .map_err(|_| Error::TaskNegotiationTimeout)??; if let Task::Authenticate(auth) = &task { - if self.is_authed() { + if self.auth.get().is_some() { return Err(Error::DuplicatedAuth); } else if self .users .get(&auth.uuid()) .map_or(false, |password| auth.validate(password)) { - self.set_authed(); + self.auth.set(auth.uuid()); } else { return Err(Error::AuthFailed(auth.uuid())); } } tokio::select! { - () = self.authed() => {} + () = self.auth.clone() => {} err = self.inner.closed() => return Err(Error::Connection(err)), }; @@ -339,14 +338,16 @@ impl Connection { }; match pre_process.await { - Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()), + Ok(Task::Authenticate(auth)) => { + log::info!("[{addr}] [{uuid}] [authenticate]", uuid = auth.uuid()) + } Ok(Task::Packet(pkt)) => { let assoc_id = pkt.assoc_id(); let pkt_id = pkt.pkt_id(); let frag_id = pkt.frag_id(); let frag_total = pkt.frag_total(); log::info!( - "[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}]" + "[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}/{frag_total}]" ); self.set_udp_relay_mode(UdpRelayMode::Quic); @@ -396,7 +397,7 @@ impl Connection { .map_err(|_| Error::TaskNegotiationTimeout)??; tokio::select! { - () = self.authed() => {} + () = self.auth.clone() => {} err = self.inner.closed() => return Err(Error::Connection(err)), }; @@ -429,7 +430,7 @@ impl Connection { let task = self.model.accept_datagram(dg)?; tokio::select! { - () = self.authed() => {} + () = self.auth.clone() => {} err = self.inner.closed() => return Err(Error::Connection(err)), }; @@ -544,7 +545,7 @@ impl Connection { async fn handle_auth_timeout(self, timeout: Duration) { time::sleep(timeout).await; - if !self.is_authed() { + if self.auth.get().is_none() { let addr = self.inner.remote_address(); log::warn!("[{addr}] authentication timeout"); self.close(); @@ -563,18 +564,6 @@ impl Connection { } } - fn set_authed(&self) { - self.is_authed.set_authed(); - } - - fn is_authed(&self) -> bool { - self.is_authed.is_authed() - } - - fn authed(&self) -> IsAuthed { - self.is_authed.clone() - } - fn set_udp_relay_mode(&self, mode: UdpRelayMode) { self.udp_relay_mode.store(Some(mode)); } @@ -745,40 +734,42 @@ impl Drop for UdpSession { } #[derive(Clone)] -struct IsAuthed { - is_authed: Arc, - broadcast: Arc>>, +struct Authenticated(Arc); + +struct AuthenticatedInner { + uuid: AtomicCell>, + broadcast: Mutex>, } -impl IsAuthed { +impl Authenticated { fn new() -> Self { - Self { - is_authed: Arc::new(AtomicBool::new(false)), - broadcast: Arc::new(Mutex::new(Vec::new())), - } + Self(Arc::new(AuthenticatedInner { + uuid: AtomicCell::new(None), + broadcast: Mutex::new(Vec::new()), + })) } - fn set_authed(&self) { - self.is_authed.store(true, Ordering::Release); + fn set(&self, uuid: Uuid) { + self.0.uuid.store(Some(uuid)); - for waker in self.broadcast.lock().drain(..) { + for waker in self.0.broadcast.lock().drain(..) { waker.wake(); } } - fn is_authed(&self) -> bool { - self.is_authed.load(Ordering::Relaxed) + fn get(&self) -> Option { + self.0.uuid.load() } } -impl Future for IsAuthed { +impl Future for Authenticated { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.is_authed.load(Ordering::Relaxed) { + if self.get().is_some() { Poll::Ready(()) } else { - self.broadcast.lock().push(cx.waker().clone()); + self.0.broadcast.lock().push(cx.waker().clone()); Poll::Pending } }