1
0

rename IsAuthed to Authenticated & store uuid

This commit is contained in:
EAimTY 2023-05-29 19:56:19 +09:00
parent 3bf1ffe137
commit f5326259bd

View File

@ -21,7 +21,7 @@ use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
pin::Pin, pin::Pin,
sync::{ sync::{
atomic::{AtomicBool, AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Arc, Arc,
}, },
task::{Context, Poll, Waker}, task::{Context, Poll, Waker},
@ -178,7 +178,7 @@ struct Connection {
model: Model<side::Server>, model: Model<side::Server>,
users: Arc<HashMap<Uuid, Vec<u8>>>, users: Arc<HashMap<Uuid, Vec<u8>>>,
udp_relay_ipv6: bool, udp_relay_ipv6: bool,
is_authed: IsAuthed, auth: Authenticated,
task_negotiation_timeout: Duration, task_negotiation_timeout: Duration,
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>, udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>, udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
@ -235,19 +235,31 @@ impl Connection {
break; break;
} }
match conn.accept().await { let handle_incoming = async {
tokio::select! {
res = conn.inner.accept_uni() =>
tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())),
res = conn.inner.accept_bi() =>
tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())),
res = conn.inner.read_datagram() =>
tokio::spawn(conn.clone().handle_datagram(res?)),
};
Ok::<_, Error>(())
};
match handle_incoming.await {
Ok(()) => {} Ok(()) => {}
Err(err) if err.is_locally_closed() => {} Err(err) if err.is_locally_closed() => {}
Err(err) if err.is_timeout_closed() => { Err(err) if err.is_timeout_closed() => {
log::debug!("[{addr}] connection timeout") log::debug!("[{addr}] connection timeout")
} }
Err(err) => log::warn!("[{addr}] {err}"), Err(err) => log::warn!("[{addr}] connection error: {err}"),
} }
} }
} }
Err(err) if err.is_locally_closed() => unreachable!(), Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(),
Err(err) if err.is_timeout_closed() => log::debug!("[{addr}] connection timeout"), Err(err) => log::warn!("[{addr}] connection establishing error: {err}"),
Err(err) => log::warn!("[{addr}] {err}"),
} }
} }
@ -263,7 +275,7 @@ impl Connection {
model: Model::<side::Server>::new(conn), model: Model::<side::Server>::new(conn),
users, users,
udp_relay_ipv6, udp_relay_ipv6,
is_authed: IsAuthed::new(), auth: Authenticated::new(),
task_negotiation_timeout, task_negotiation_timeout,
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
udp_relay_mode: Arc::new(AtomicCell::new(None)), udp_relay_mode: Arc::new(AtomicCell::new(None)),
@ -275,19 +287,6 @@ impl Connection {
} }
} }
async fn accept(&self) -> Result<(), Error> {
tokio::select! {
res = self.inner.accept_uni() =>
tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())),
res = self.inner.accept_bi() =>
tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())),
res = self.inner.read_datagram() =>
tokio::spawn(self.clone().handle_datagram(res?)),
};
Ok(())
}
async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
let addr = self.inner.remote_address(); let addr = self.inner.remote_address();
log::debug!("[{addr}] incoming unidirectional stream"); log::debug!("[{addr}] incoming unidirectional stream");
@ -311,21 +310,21 @@ impl Connection {
.map_err(|_| Error::TaskNegotiationTimeout)??; .map_err(|_| Error::TaskNegotiationTimeout)??;
if let Task::Authenticate(auth) = &task { if let Task::Authenticate(auth) = &task {
if self.is_authed() { if self.auth.get().is_some() {
return Err(Error::DuplicatedAuth); return Err(Error::DuplicatedAuth);
} else if self } else if self
.users .users
.get(&auth.uuid()) .get(&auth.uuid())
.map_or(false, |password| auth.validate(password)) .map_or(false, |password| auth.validate(password))
{ {
self.set_authed(); self.auth.set(auth.uuid());
} else { } else {
return Err(Error::AuthFailed(auth.uuid())); return Err(Error::AuthFailed(auth.uuid()));
} }
} }
tokio::select! { tokio::select! {
() = self.authed() => {} () = self.auth.clone() => {}
err = self.inner.closed() => return Err(Error::Connection(err)), err = self.inner.closed() => return Err(Error::Connection(err)),
}; };
@ -339,14 +338,16 @@ impl Connection {
}; };
match pre_process.await { match pre_process.await {
Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()), Ok(Task::Authenticate(auth)) => {
log::info!("[{addr}] [{uuid}] [authenticate]", uuid = auth.uuid())
}
Ok(Task::Packet(pkt)) => { Ok(Task::Packet(pkt)) => {
let assoc_id = pkt.assoc_id(); let assoc_id = pkt.assoc_id();
let pkt_id = pkt.pkt_id(); let pkt_id = pkt.pkt_id();
let frag_id = pkt.frag_id(); let frag_id = pkt.frag_id();
let frag_total = pkt.frag_total(); let frag_total = pkt.frag_total();
log::info!( log::info!(
"[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}]" "[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}/{frag_total}]"
); );
self.set_udp_relay_mode(UdpRelayMode::Quic); self.set_udp_relay_mode(UdpRelayMode::Quic);
@ -396,7 +397,7 @@ impl Connection {
.map_err(|_| Error::TaskNegotiationTimeout)??; .map_err(|_| Error::TaskNegotiationTimeout)??;
tokio::select! { tokio::select! {
() = self.authed() => {} () = self.auth.clone() => {}
err = self.inner.closed() => return Err(Error::Connection(err)), err = self.inner.closed() => return Err(Error::Connection(err)),
}; };
@ -429,7 +430,7 @@ impl Connection {
let task = self.model.accept_datagram(dg)?; let task = self.model.accept_datagram(dg)?;
tokio::select! { tokio::select! {
() = self.authed() => {} () = self.auth.clone() => {}
err = self.inner.closed() => return Err(Error::Connection(err)), err = self.inner.closed() => return Err(Error::Connection(err)),
}; };
@ -544,7 +545,7 @@ impl Connection {
async fn handle_auth_timeout(self, timeout: Duration) { async fn handle_auth_timeout(self, timeout: Duration) {
time::sleep(timeout).await; time::sleep(timeout).await;
if !self.is_authed() { if self.auth.get().is_none() {
let addr = self.inner.remote_address(); let addr = self.inner.remote_address();
log::warn!("[{addr}] authentication timeout"); log::warn!("[{addr}] authentication timeout");
self.close(); self.close();
@ -563,18 +564,6 @@ impl Connection {
} }
} }
fn set_authed(&self) {
self.is_authed.set_authed();
}
fn is_authed(&self) -> bool {
self.is_authed.is_authed()
}
fn authed(&self) -> IsAuthed {
self.is_authed.clone()
}
fn set_udp_relay_mode(&self, mode: UdpRelayMode) { fn set_udp_relay_mode(&self, mode: UdpRelayMode) {
self.udp_relay_mode.store(Some(mode)); self.udp_relay_mode.store(Some(mode));
} }
@ -745,40 +734,42 @@ impl Drop for UdpSession {
} }
#[derive(Clone)] #[derive(Clone)]
struct IsAuthed { struct Authenticated(Arc<AuthenticatedInner>);
is_authed: Arc<AtomicBool>,
broadcast: Arc<Mutex<Vec<Waker>>>, struct AuthenticatedInner {
uuid: AtomicCell<Option<Uuid>>,
broadcast: Mutex<Vec<Waker>>,
} }
impl IsAuthed { impl Authenticated {
fn new() -> Self { fn new() -> Self {
Self { Self(Arc::new(AuthenticatedInner {
is_authed: Arc::new(AtomicBool::new(false)), uuid: AtomicCell::new(None),
broadcast: Arc::new(Mutex::new(Vec::new())), broadcast: Mutex::new(Vec::new()),
} }))
} }
fn set_authed(&self) { fn set(&self, uuid: Uuid) {
self.is_authed.store(true, Ordering::Release); self.0.uuid.store(Some(uuid));
for waker in self.broadcast.lock().drain(..) { for waker in self.0.broadcast.lock().drain(..) {
waker.wake(); waker.wake();
} }
} }
fn is_authed(&self) -> bool { fn get(&self) -> Option<Uuid> {
self.is_authed.load(Ordering::Relaxed) self.0.uuid.load()
} }
} }
impl Future for IsAuthed { impl Future for Authenticated {
type Output = (); type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.is_authed.load(Ordering::Relaxed) { if self.get().is_some() {
Poll::Ready(()) Poll::Ready(())
} else { } else {
self.broadcast.lock().push(cx.waker().clone()); self.0.broadcast.lock().push(cx.waker().clone());
Poll::Pending Poll::Pending
} }
} }