diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index 5bdd4f2..8d4365e 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -130,56 +130,45 @@ impl Endpoint { } async fn connect(&mut self) -> Result { - async fn connect_to( - ep: &mut QuinnEndpoint, - addr: SocketAddr, - server_name: &str, - zero_rtt_handshake: bool, - ) -> Result { - let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4()); - let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6()); - - if !match_ipv4 && !match_ipv6 { - let bind_addr = if addr.is_ipv4() { - SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) - } else { - SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) - }; - - ep.rebind( - UdpSocket::bind(bind_addr).map_err(|err| { - Error::Socket("failed to create endpoint UDP socket", err) - })?, - ) - .map_err(|err| Error::Socket("failed to rebind endpoint UDP socket", err))?; - } - - let conn = ep.connect(addr, server_name)?; - - let conn = if zero_rtt_handshake { - match conn.into_0rtt() { - Ok((conn, _)) => conn, - Err(conn) => conn.await?, - } - } else { - conn.await? - }; - - Ok(conn) - } - let mut last_err = None; for addr in self.server.resolve().await? { - let res = connect_to( - &mut self.ep, - addr, - self.server.server_name(), - self.zero_rtt_handshake, - ) - .await; + let connect_to = async { + let match_ipv4 = + addr.is_ipv4() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv4()); + let match_ipv6 = + addr.is_ipv6() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv6()); - match res { + if !match_ipv4 && !match_ipv6 { + let bind_addr = if addr.is_ipv4() { + SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) + } else { + SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0)) + }; + + self.ep + .rebind(UdpSocket::bind(bind_addr).map_err(|err| { + Error::Socket("failed to create endpoint UDP socket", err) + })?) + .map_err(|err| { + Error::Socket("failed to rebind endpoint UDP socket", err) + })?; + } + + let conn = self.ep.connect(addr, self.server.server_name())?; + let conn = if self.zero_rtt_handshake { + match conn.into_0rtt() { + Ok((conn, _)) => conn, + Err(conn) => conn.await?, + } + } else { + conn.await? + }; + + Ok(conn) + }; + + match connect_to.await { Ok(conn) => { return Ok(Connection::new( conn, diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 486359a..a30a890 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -202,47 +202,28 @@ impl Connection { gc_interval: Duration, gc_lifetime: Duration, ) { - async fn init( - conn: Connecting, - users: Arc>>, - udp_relay_ipv6: bool, - zero_rtt_handshake: bool, - task_negotiation_timeout: Duration, - max_external_pkt_size: usize, - ) -> Result { + let addr = conn.remote_address(); + + let init = async { let conn = if zero_rtt_handshake { match conn.into_0rtt() { Ok((conn, _)) => conn, - Err(conn) => { - log::info!("0-RTT handshake failed, fallback to 1-RTT handshake"); - conn.await? - } + Err(conn) => conn.await?, } } else { conn.await? }; - Ok(Connection::new( + Ok::<_, Error>(Self::new( conn, users, udp_relay_ipv6, task_negotiation_timeout, max_external_pkt_size, )) - } + }; - let addr = conn.remote_address(); - - match init( - conn, - users, - udp_relay_ipv6, - zero_rtt_handshake, - task_negotiation_timeout, - max_external_pkt_size, - ) - .await - { + match init.await { Ok(conn) => { log::info!("[{addr}] connection established"); @@ -321,44 +302,43 @@ impl Connection { .set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32)); } - async fn pre_process( - conn: &Connection, - recv: RecvStream, - task_negotiation_timeout: Duration, - ) -> Result { - let task = time::timeout(task_negotiation_timeout, conn.model.accept_uni_stream(recv)) - .await - .map_err(|_| Error::TaskNegotiationTimeout)??; + let pre_process = async { + let task = time::timeout( + self.task_negotiation_timeout, + self.model.accept_uni_stream(recv), + ) + .await + .map_err(|_| Error::TaskNegotiationTimeout)??; if let Task::Authenticate(auth) = &task { - if conn.is_authed() { + if self.is_authed() { return Err(Error::DuplicatedAuth); - } else if conn + } else if self .users .get(&auth.uuid()) .map_or(false, |password| auth.validate(password)) { - conn.set_authed(); + self.set_authed(); } else { return Err(Error::AuthFailed(auth.uuid())); } } tokio::select! { - () = conn.authed() => {} - err = conn.inner.closed() => return Err(Error::Connection(err)), + () = self.authed() => {} + err = self.inner.closed() => return Err(Error::Connection(err)), }; let same_pkt_src = matches!(task, Task::Packet(_)) - && matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Native)); + && matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Native)); if same_pkt_src { return Err(Error::UnexpectedPacketSource); } Ok(task) - } + }; - match pre_process(&self, recv, self.task_negotiation_timeout).await { + match pre_process.await { Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()), Ok(Task::Packet(pkt)) => { let assoc_id = pkt.assoc_id(); @@ -407,28 +387,23 @@ impl Connection { .set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32)); } - async fn pre_process( - conn: &Connection, - send: SendStream, - recv: RecvStream, - task_negotiation_timeout: Duration, - ) -> Result { + let pre_process = async { let task = time::timeout( - task_negotiation_timeout, - conn.model.accept_bi_stream(send, recv), + self.task_negotiation_timeout, + self.model.accept_bi_stream(send, recv), ) .await .map_err(|_| Error::TaskNegotiationTimeout)??; tokio::select! { - () = conn.authed() => {} - err = conn.inner.closed() => return Err(Error::Connection(err)), + () = self.authed() => {} + err = self.inner.closed() => return Err(Error::Connection(err)), }; Ok(task) - } + }; - match pre_process(&self, send, recv, self.task_negotiation_timeout).await { + match pre_process.await { Ok(Task::Connect(conn)) => { let target_addr = conn.addr().to_string(); log::info!("[{addr}] [connect] [{target_addr}]"); @@ -450,24 +425,24 @@ impl Connection { let addr = self.inner.remote_address(); log::debug!("[{addr}] incoming datagram"); - async fn pre_process(conn: &Connection, dg: Bytes) -> Result { - let task = conn.model.accept_datagram(dg)?; + let pre_process = async { + let task = self.model.accept_datagram(dg)?; tokio::select! { - () = conn.authed() => {} - err = conn.inner.closed() => return Err(Error::Connection(err)), + () = self.authed() => {} + err = self.inner.closed() => return Err(Error::Connection(err)), }; let same_pkt_src = matches!(task, Task::Packet(_)) - && matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Quic)); + && matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Quic)); if same_pkt_src { return Err(Error::UnexpectedPacketSource); } Ok(task) - } + }; - match pre_process(&self, dg).await { + match pre_process.await { Ok(Task::Packet(pkt)) => { let assoc_id = pkt.assoc_id(); let pkt_id = pkt.pkt_id();