diff --git a/tuic-server/README.md b/tuic-server/README.md index 14bb360..2e25ddf 100644 --- a/tuic-server/README.md +++ b/tuic-server/README.md @@ -67,6 +67,10 @@ tuic-server -c PATH/TO/CONFIG // Default: 3s "auth_timeout": "3s", + // Optional. Maximum duration server expects for task negotiation + // Default: 3s + "task_negotiation_timeout": "3s", + // Optional. How long the server should wait before closing an idle connection // Default: 10s "max_idle_time": "10s", diff --git a/tuic-server/src/config.rs b/tuic-server/src/config.rs index 2442290..4ad7d13 100644 --- a/tuic-server/src/config.rs +++ b/tuic-server/src/config.rs @@ -55,6 +55,12 @@ pub struct Config { )] pub auth_timeout: Duration, + #[serde( + default = "default::task_negotiation_timeout", + deserialize_with = "deserialize_duration" + )] + pub task_negotiation_timeout: Duration, + #[serde( default = "default::max_idle_time", deserialize_with = "deserialize_duration" @@ -136,6 +142,10 @@ mod default { Duration::from_secs(3) } + pub fn task_negotiation_timeout() -> Duration { + Duration::from_secs(3) + } + pub fn max_idle_time() -> Duration { Duration::from_secs(10) } diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index 51bf977..b4c4a52 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -66,6 +66,8 @@ pub enum Error { UnexpectedPacketSource, #[error("{0}: {1}")] Socket(&'static str, IoError), + #[error("task negotiation timed out")] + TaskNegotiationTimeout, #[error("{0} resolved to {1} but IPv6 UDP relaying is disabled")] UdpRelayIpv6Disabled(Address, SocketAddr), } diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 8eba5f8..686a76a 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -49,6 +49,7 @@ pub struct Server { udp_relay_ipv6: bool, zero_rtt_handshake: bool, auth_timeout: Duration, + task_negotiation_timeout: Duration, max_external_pkt_size: usize, gc_interval: Duration, gc_lifetime: Duration, @@ -136,6 +137,7 @@ impl Server { udp_relay_ipv6: cfg.udp_relay_ipv6, zero_rtt_handshake: cfg.zero_rtt_handshake, auth_timeout: cfg.auth_timeout, + task_negotiation_timeout: cfg.task_negotiation_timeout, max_external_pkt_size: cfg.max_external_packet_size, gc_interval: cfg.gc_interval, gc_lifetime: cfg.gc_lifetime, @@ -159,6 +161,7 @@ impl Server { self.udp_relay_ipv6, self.zero_rtt_handshake, self.auth_timeout, + self.task_negotiation_timeout, self.max_external_pkt_size, self.gc_interval, self.gc_lifetime, @@ -174,6 +177,7 @@ struct Connection { users: Arc>>, udp_relay_ipv6: bool, is_authed: IsAuthed, + task_negotiation_timeout: Duration, udp_sessions: Arc>>, udp_relay_mode: Arc>>, max_external_pkt_size: usize, @@ -191,22 +195,52 @@ impl Connection { udp_relay_ipv6: bool, zero_rtt_handshake: bool, auth_timeout: Duration, + task_negotiation_timeout: Duration, max_external_pkt_size: usize, gc_interval: Duration, gc_lifetime: Duration, ) { + async fn init( + conn: Connecting, + users: Arc>>, + udp_relay_ipv6: bool, + zero_rtt_handshake: bool, + task_negotiation_timeout: Duration, + max_external_pkt_size: usize, + ) -> Result { + let conn = if zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => { + log::info!("0-RTT handshake failed, fallback to 1-RTT handshake"); + conn.await? + } + } + } else { + conn.await? + }; + + Ok(Connection::new( + conn, + users, + udp_relay_ipv6, + task_negotiation_timeout, + max_external_pkt_size, + )) + } + let addr = conn.remote_address(); - let conn = Self::init( + match init( conn, users, udp_relay_ipv6, zero_rtt_handshake, + task_negotiation_timeout, max_external_pkt_size, ) - .await; - - match conn { + .await + { Ok(conn) => { log::info!("[{addr}] connection established"); @@ -234,31 +268,20 @@ impl Connection { } } - async fn init( - conn: Connecting, + fn new( + conn: QuinnConnection, users: Arc>>, udp_relay_ipv6: bool, - zero_rtt_handshake: bool, + task_negotiation_timeout: Duration, max_external_pkt_size: usize, - ) -> Result { - let conn = if zero_rtt_handshake { - match conn.into_0rtt() { - Ok((conn, _)) => conn, - Err(conn) => { - log::info!("0-RTT handshake failed, fallback to 1-RTT handshake"); - conn.await? - } - } - } else { - conn.await? - }; - - Ok(Self { + ) -> Self { + Self { inner: conn.clone(), model: Model::::new(conn), users, udp_relay_ipv6, is_authed: IsAuthed::new(), + task_negotiation_timeout, udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), max_external_pkt_size, @@ -266,7 +289,7 @@ impl Connection { remote_bi_stream_cnt: Counter::new(), max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), max_concurrent_bi_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), - }) + } } async fn accept(&self) -> Result<(), Error> { @@ -296,8 +319,14 @@ impl Connection { .set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32)); } - async fn pre_process(conn: &Connection, recv: RecvStream) -> Result { - let task = conn.model.accept_uni_stream(recv).await?; + async fn pre_process( + conn: &Connection, + recv: RecvStream, + task_negotiation_timeout: Duration, + ) -> Result { + let task = time::timeout(task_negotiation_timeout, conn.model.accept_uni_stream(recv)) + .await + .map_err(|_| Error::TaskNegotiationTimeout)??; if let Task::Authenticate(auth) = &task { if conn.is_authed() { @@ -315,7 +344,7 @@ impl Connection { tokio::select! { () = conn.authed() => {} - err = conn.inner.closed() => Err(err)?, + err = conn.inner.closed() => return Err(Error::Connection(err)), }; let same_pkt_src = matches!(task, Task::Packet(_)) @@ -327,7 +356,7 @@ impl Connection { Ok(task) } - match pre_process(&self, recv).await { + match pre_process(&self, recv, self.task_negotiation_timeout).await { Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()), Ok(Task::Packet(pkt)) => { let assoc_id = pkt.assoc_id(); @@ -380,18 +409,24 @@ impl Connection { conn: &Connection, send: SendStream, recv: RecvStream, + task_negotiation_timeout: Duration, ) -> Result { - let task = conn.model.accept_bi_stream(send, recv).await?; + let task = time::timeout( + task_negotiation_timeout, + conn.model.accept_bi_stream(send, recv), + ) + .await + .map_err(|_| Error::TaskNegotiationTimeout)??; tokio::select! { () = conn.authed() => {} - err = conn.inner.closed() => Err(err)?, + err = conn.inner.closed() => return Err(Error::Connection(err)), }; Ok(task) } - match pre_process(&self, send, recv).await { + match pre_process(&self, send, recv, self.task_negotiation_timeout).await { Ok(Task::Connect(conn)) => { let target_addr = conn.addr().to_string(); log::info!("[{addr}] [connect] [{target_addr}]"); @@ -418,7 +453,7 @@ impl Connection { tokio::select! { () = conn.authed() => {} - err = conn.inner.closed() => Err(err)?, + err = conn.inner.closed() => return Err(Error::Connection(err)), }; let same_pkt_src = matches!(task, Task::Packet(_))