1
0

replace nested async functions with async blocks

This commit is contained in:
EAimTY 2023-05-29 19:24:26 +09:00
parent 191dffd03e
commit 3bf1ffe137
2 changed files with 71 additions and 107 deletions

View File

@ -130,56 +130,45 @@ impl Endpoint {
} }
async fn connect(&mut self) -> Result<Connection, Error> { async fn connect(&mut self) -> Result<Connection, Error> {
async fn connect_to(
ep: &mut QuinnEndpoint,
addr: SocketAddr,
server_name: &str,
zero_rtt_handshake: bool,
) -> Result<QuinnConnection, Error> {
let match_ipv4 = addr.is_ipv4() && ep.local_addr().map_or(false, |addr| addr.is_ipv4());
let match_ipv6 = addr.is_ipv6() && ep.local_addr().map_or(false, |addr| addr.is_ipv6());
if !match_ipv4 && !match_ipv6 {
let bind_addr = if addr.is_ipv4() {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
};
ep.rebind(
UdpSocket::bind(bind_addr).map_err(|err| {
Error::Socket("failed to create endpoint UDP socket", err)
})?,
)
.map_err(|err| Error::Socket("failed to rebind endpoint UDP socket", err))?;
}
let conn = ep.connect(addr, server_name)?;
let conn = if zero_rtt_handshake {
match conn.into_0rtt() {
Ok((conn, _)) => conn,
Err(conn) => conn.await?,
}
} else {
conn.await?
};
Ok(conn)
}
let mut last_err = None; let mut last_err = None;
for addr in self.server.resolve().await? { for addr in self.server.resolve().await? {
let res = connect_to( let connect_to = async {
&mut self.ep, let match_ipv4 =
addr, addr.is_ipv4() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv4());
self.server.server_name(), let match_ipv6 =
self.zero_rtt_handshake, addr.is_ipv6() && self.ep.local_addr().map_or(false, |addr| addr.is_ipv6());
)
.await;
match res { if !match_ipv4 && !match_ipv6 {
let bind_addr = if addr.is_ipv4() {
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
} else {
SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))
};
self.ep
.rebind(UdpSocket::bind(bind_addr).map_err(|err| {
Error::Socket("failed to create endpoint UDP socket", err)
})?)
.map_err(|err| {
Error::Socket("failed to rebind endpoint UDP socket", err)
})?;
}
let conn = self.ep.connect(addr, self.server.server_name())?;
let conn = if self.zero_rtt_handshake {
match conn.into_0rtt() {
Ok((conn, _)) => conn,
Err(conn) => conn.await?,
}
} else {
conn.await?
};
Ok(conn)
};
match connect_to.await {
Ok(conn) => { Ok(conn) => {
return Ok(Connection::new( return Ok(Connection::new(
conn, conn,

View File

@ -202,47 +202,28 @@ impl Connection {
gc_interval: Duration, gc_interval: Duration,
gc_lifetime: Duration, gc_lifetime: Duration,
) { ) {
async fn init( let addr = conn.remote_address();
conn: Connecting,
users: Arc<HashMap<Uuid, Vec<u8>>>, let init = async {
udp_relay_ipv6: bool,
zero_rtt_handshake: bool,
task_negotiation_timeout: Duration,
max_external_pkt_size: usize,
) -> Result<Connection, Error> {
let conn = if zero_rtt_handshake { let conn = if zero_rtt_handshake {
match conn.into_0rtt() { match conn.into_0rtt() {
Ok((conn, _)) => conn, Ok((conn, _)) => conn,
Err(conn) => { Err(conn) => conn.await?,
log::info!("0-RTT handshake failed, fallback to 1-RTT handshake");
conn.await?
}
} }
} else { } else {
conn.await? conn.await?
}; };
Ok(Connection::new( Ok::<_, Error>(Self::new(
conn, conn,
users, users,
udp_relay_ipv6, udp_relay_ipv6,
task_negotiation_timeout, task_negotiation_timeout,
max_external_pkt_size, max_external_pkt_size,
)) ))
} };
let addr = conn.remote_address(); match init.await {
match init(
conn,
users,
udp_relay_ipv6,
zero_rtt_handshake,
task_negotiation_timeout,
max_external_pkt_size,
)
.await
{
Ok(conn) => { Ok(conn) => {
log::info!("[{addr}] connection established"); log::info!("[{addr}] connection established");
@ -321,44 +302,43 @@ impl Connection {
.set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32)); .set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32));
} }
async fn pre_process( let pre_process = async {
conn: &Connection, let task = time::timeout(
recv: RecvStream, self.task_negotiation_timeout,
task_negotiation_timeout: Duration, self.model.accept_uni_stream(recv),
) -> Result<Task, Error> { )
let task = time::timeout(task_negotiation_timeout, conn.model.accept_uni_stream(recv)) .await
.await .map_err(|_| Error::TaskNegotiationTimeout)??;
.map_err(|_| Error::TaskNegotiationTimeout)??;
if let Task::Authenticate(auth) = &task { if let Task::Authenticate(auth) = &task {
if conn.is_authed() { if self.is_authed() {
return Err(Error::DuplicatedAuth); return Err(Error::DuplicatedAuth);
} else if conn } else if self
.users .users
.get(&auth.uuid()) .get(&auth.uuid())
.map_or(false, |password| auth.validate(password)) .map_or(false, |password| auth.validate(password))
{ {
conn.set_authed(); self.set_authed();
} else { } else {
return Err(Error::AuthFailed(auth.uuid())); return Err(Error::AuthFailed(auth.uuid()));
} }
} }
tokio::select! { tokio::select! {
() = conn.authed() => {} () = self.authed() => {}
err = conn.inner.closed() => return Err(Error::Connection(err)), err = self.inner.closed() => return Err(Error::Connection(err)),
}; };
let same_pkt_src = matches!(task, Task::Packet(_)) let same_pkt_src = matches!(task, Task::Packet(_))
&& matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Native)); && matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Native));
if same_pkt_src { if same_pkt_src {
return Err(Error::UnexpectedPacketSource); return Err(Error::UnexpectedPacketSource);
} }
Ok(task) Ok(task)
} };
match pre_process(&self, recv, self.task_negotiation_timeout).await { match pre_process.await {
Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()), Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()),
Ok(Task::Packet(pkt)) => { Ok(Task::Packet(pkt)) => {
let assoc_id = pkt.assoc_id(); let assoc_id = pkt.assoc_id();
@ -407,28 +387,23 @@ impl Connection {
.set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32)); .set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32));
} }
async fn pre_process( let pre_process = async {
conn: &Connection,
send: SendStream,
recv: RecvStream,
task_negotiation_timeout: Duration,
) -> Result<Task, Error> {
let task = time::timeout( let task = time::timeout(
task_negotiation_timeout, self.task_negotiation_timeout,
conn.model.accept_bi_stream(send, recv), self.model.accept_bi_stream(send, recv),
) )
.await .await
.map_err(|_| Error::TaskNegotiationTimeout)??; .map_err(|_| Error::TaskNegotiationTimeout)??;
tokio::select! { tokio::select! {
() = conn.authed() => {} () = self.authed() => {}
err = conn.inner.closed() => return Err(Error::Connection(err)), err = self.inner.closed() => return Err(Error::Connection(err)),
}; };
Ok(task) Ok(task)
} };
match pre_process(&self, send, recv, self.task_negotiation_timeout).await { match pre_process.await {
Ok(Task::Connect(conn)) => { Ok(Task::Connect(conn)) => {
let target_addr = conn.addr().to_string(); let target_addr = conn.addr().to_string();
log::info!("[{addr}] [connect] [{target_addr}]"); log::info!("[{addr}] [connect] [{target_addr}]");
@ -450,24 +425,24 @@ impl Connection {
let addr = self.inner.remote_address(); let addr = self.inner.remote_address();
log::debug!("[{addr}] incoming datagram"); log::debug!("[{addr}] incoming datagram");
async fn pre_process(conn: &Connection, dg: Bytes) -> Result<Task, Error> { let pre_process = async {
let task = conn.model.accept_datagram(dg)?; let task = self.model.accept_datagram(dg)?;
tokio::select! { tokio::select! {
() = conn.authed() => {} () = self.authed() => {}
err = conn.inner.closed() => return Err(Error::Connection(err)), err = self.inner.closed() => return Err(Error::Connection(err)),
}; };
let same_pkt_src = matches!(task, Task::Packet(_)) let same_pkt_src = matches!(task, Task::Packet(_))
&& matches!(conn.get_udp_relay_mode(), Some(UdpRelayMode::Quic)); && matches!(self.get_udp_relay_mode(), Some(UdpRelayMode::Quic));
if same_pkt_src { if same_pkt_src {
return Err(Error::UnexpectedPacketSource); return Err(Error::UnexpectedPacketSource);
} }
Ok(task) Ok(task)
} };
match pre_process(&self, dg).await { match pre_process.await {
Ok(Task::Packet(pkt)) => { Ok(Task::Packet(pkt)) => {
let assoc_id = pkt.assoc_id(); let assoc_id = pkt.assoc_id();
let pkt_id = pkt.pkt_id(); let pkt_id = pkt.pkt_id();