rename IsAuthed
to Authenticated
& store uuid
This commit is contained in:
parent
3bf1ffe137
commit
f5326259bd
@ -21,7 +21,7 @@ use std::{
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
||||
pin::Pin,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering},
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
task::{Context, Poll, Waker},
|
||||
@ -178,7 +178,7 @@ struct Connection {
|
||||
model: Model<side::Server>,
|
||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||
udp_relay_ipv6: bool,
|
||||
is_authed: IsAuthed,
|
||||
auth: Authenticated,
|
||||
task_negotiation_timeout: Duration,
|
||||
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
|
||||
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
|
||||
@ -235,19 +235,31 @@ impl Connection {
|
||||
break;
|
||||
}
|
||||
|
||||
match conn.accept().await {
|
||||
let handle_incoming = async {
|
||||
tokio::select! {
|
||||
res = conn.inner.accept_uni() =>
|
||||
tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())),
|
||||
res = conn.inner.accept_bi() =>
|
||||
tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())),
|
||||
res = conn.inner.read_datagram() =>
|
||||
tokio::spawn(conn.clone().handle_datagram(res?)),
|
||||
};
|
||||
|
||||
Ok::<_, Error>(())
|
||||
};
|
||||
|
||||
match handle_incoming.await {
|
||||
Ok(()) => {}
|
||||
Err(err) if err.is_locally_closed() => {}
|
||||
Err(err) if err.is_timeout_closed() => {
|
||||
log::debug!("[{addr}] connection timeout")
|
||||
}
|
||||
Err(err) => log::warn!("[{addr}] {err}"),
|
||||
Err(err) => log::warn!("[{addr}] connection error: {err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) if err.is_locally_closed() => unreachable!(),
|
||||
Err(err) if err.is_timeout_closed() => log::debug!("[{addr}] connection timeout"),
|
||||
Err(err) => log::warn!("[{addr}] {err}"),
|
||||
Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(),
|
||||
Err(err) => log::warn!("[{addr}] connection establishing error: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
@ -263,7 +275,7 @@ impl Connection {
|
||||
model: Model::<side::Server>::new(conn),
|
||||
users,
|
||||
udp_relay_ipv6,
|
||||
is_authed: IsAuthed::new(),
|
||||
auth: Authenticated::new(),
|
||||
task_negotiation_timeout,
|
||||
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
udp_relay_mode: Arc::new(AtomicCell::new(None)),
|
||||
@ -275,19 +287,6 @@ impl Connection {
|
||||
}
|
||||
}
|
||||
|
||||
async fn accept(&self) -> Result<(), Error> {
|
||||
tokio::select! {
|
||||
res = self.inner.accept_uni() =>
|
||||
tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())),
|
||||
res = self.inner.accept_bi() =>
|
||||
tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())),
|
||||
res = self.inner.read_datagram() =>
|
||||
tokio::spawn(self.clone().handle_datagram(res?)),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
|
||||
let addr = self.inner.remote_address();
|
||||
log::debug!("[{addr}] incoming unidirectional stream");
|
||||
@ -311,21 +310,21 @@ impl Connection {
|
||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||
|
||||
if let Task::Authenticate(auth) = &task {
|
||||
if self.is_authed() {
|
||||
if self.auth.get().is_some() {
|
||||
return Err(Error::DuplicatedAuth);
|
||||
} else if self
|
||||
.users
|
||||
.get(&auth.uuid())
|
||||
.map_or(false, |password| auth.validate(password))
|
||||
{
|
||||
self.set_authed();
|
||||
self.auth.set(auth.uuid());
|
||||
} else {
|
||||
return Err(Error::AuthFailed(auth.uuid()));
|
||||
}
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
() = self.authed() => {}
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
@ -339,14 +338,16 @@ impl Connection {
|
||||
};
|
||||
|
||||
match pre_process.await {
|
||||
Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()),
|
||||
Ok(Task::Authenticate(auth)) => {
|
||||
log::info!("[{addr}] [{uuid}] [authenticate]", uuid = auth.uuid())
|
||||
}
|
||||
Ok(Task::Packet(pkt)) => {
|
||||
let assoc_id = pkt.assoc_id();
|
||||
let pkt_id = pkt.pkt_id();
|
||||
let frag_id = pkt.frag_id();
|
||||
let frag_total = pkt.frag_total();
|
||||
log::info!(
|
||||
"[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}]"
|
||||
"[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}/{frag_total}]"
|
||||
);
|
||||
|
||||
self.set_udp_relay_mode(UdpRelayMode::Quic);
|
||||
@ -396,7 +397,7 @@ impl Connection {
|
||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||
|
||||
tokio::select! {
|
||||
() = self.authed() => {}
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
@ -429,7 +430,7 @@ impl Connection {
|
||||
let task = self.model.accept_datagram(dg)?;
|
||||
|
||||
tokio::select! {
|
||||
() = self.authed() => {}
|
||||
() = self.auth.clone() => {}
|
||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||
};
|
||||
|
||||
@ -544,7 +545,7 @@ impl Connection {
|
||||
async fn handle_auth_timeout(self, timeout: Duration) {
|
||||
time::sleep(timeout).await;
|
||||
|
||||
if !self.is_authed() {
|
||||
if self.auth.get().is_none() {
|
||||
let addr = self.inner.remote_address();
|
||||
log::warn!("[{addr}] authentication timeout");
|
||||
self.close();
|
||||
@ -563,18 +564,6 @@ impl Connection {
|
||||
}
|
||||
}
|
||||
|
||||
fn set_authed(&self) {
|
||||
self.is_authed.set_authed();
|
||||
}
|
||||
|
||||
fn is_authed(&self) -> bool {
|
||||
self.is_authed.is_authed()
|
||||
}
|
||||
|
||||
fn authed(&self) -> IsAuthed {
|
||||
self.is_authed.clone()
|
||||
}
|
||||
|
||||
fn set_udp_relay_mode(&self, mode: UdpRelayMode) {
|
||||
self.udp_relay_mode.store(Some(mode));
|
||||
}
|
||||
@ -745,40 +734,42 @@ impl Drop for UdpSession {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct IsAuthed {
|
||||
is_authed: Arc<AtomicBool>,
|
||||
broadcast: Arc<Mutex<Vec<Waker>>>,
|
||||
struct Authenticated(Arc<AuthenticatedInner>);
|
||||
|
||||
struct AuthenticatedInner {
|
||||
uuid: AtomicCell<Option<Uuid>>,
|
||||
broadcast: Mutex<Vec<Waker>>,
|
||||
}
|
||||
|
||||
impl IsAuthed {
|
||||
impl Authenticated {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
is_authed: Arc::new(AtomicBool::new(false)),
|
||||
broadcast: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
Self(Arc::new(AuthenticatedInner {
|
||||
uuid: AtomicCell::new(None),
|
||||
broadcast: Mutex::new(Vec::new()),
|
||||
}))
|
||||
}
|
||||
|
||||
fn set_authed(&self) {
|
||||
self.is_authed.store(true, Ordering::Release);
|
||||
fn set(&self, uuid: Uuid) {
|
||||
self.0.uuid.store(Some(uuid));
|
||||
|
||||
for waker in self.broadcast.lock().drain(..) {
|
||||
for waker in self.0.broadcast.lock().drain(..) {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn is_authed(&self) -> bool {
|
||||
self.is_authed.load(Ordering::Relaxed)
|
||||
fn get(&self) -> Option<Uuid> {
|
||||
self.0.uuid.load()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for IsAuthed {
|
||||
impl Future for Authenticated {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
if self.is_authed.load(Ordering::Relaxed) {
|
||||
if self.get().is_some() {
|
||||
Poll::Ready(())
|
||||
} else {
|
||||
self.broadcast.lock().push(cx.waker().clone());
|
||||
self.0.broadcast.lock().push(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user