rename IsAuthed
to Authenticated
& store uuid
This commit is contained in:
parent
3bf1ffe137
commit
f5326259bd
@ -21,7 +21,7 @@ use std::{
|
|||||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicBool, AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
},
|
},
|
||||||
task::{Context, Poll, Waker},
|
task::{Context, Poll, Waker},
|
||||||
@ -178,7 +178,7 @@ struct Connection {
|
|||||||
model: Model<side::Server>,
|
model: Model<side::Server>,
|
||||||
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
users: Arc<HashMap<Uuid, Vec<u8>>>,
|
||||||
udp_relay_ipv6: bool,
|
udp_relay_ipv6: bool,
|
||||||
is_authed: IsAuthed,
|
auth: Authenticated,
|
||||||
task_negotiation_timeout: Duration,
|
task_negotiation_timeout: Duration,
|
||||||
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
|
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
|
||||||
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
|
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
|
||||||
@ -235,19 +235,31 @@ impl Connection {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
match conn.accept().await {
|
let handle_incoming = async {
|
||||||
|
tokio::select! {
|
||||||
|
res = conn.inner.accept_uni() =>
|
||||||
|
tokio::spawn(conn.clone().handle_uni_stream(res?, conn.remote_uni_stream_cnt.reg())),
|
||||||
|
res = conn.inner.accept_bi() =>
|
||||||
|
tokio::spawn(conn.clone().handle_bi_stream(res?, conn.remote_bi_stream_cnt.reg())),
|
||||||
|
res = conn.inner.read_datagram() =>
|
||||||
|
tokio::spawn(conn.clone().handle_datagram(res?)),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok::<_, Error>(())
|
||||||
|
};
|
||||||
|
|
||||||
|
match handle_incoming.await {
|
||||||
Ok(()) => {}
|
Ok(()) => {}
|
||||||
Err(err) if err.is_locally_closed() => {}
|
Err(err) if err.is_locally_closed() => {}
|
||||||
Err(err) if err.is_timeout_closed() => {
|
Err(err) if err.is_timeout_closed() => {
|
||||||
log::debug!("[{addr}] connection timeout")
|
log::debug!("[{addr}] connection timeout")
|
||||||
}
|
}
|
||||||
Err(err) => log::warn!("[{addr}] {err}"),
|
Err(err) => log::warn!("[{addr}] connection error: {err}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(err) if err.is_locally_closed() => unreachable!(),
|
Err(err) if err.is_locally_closed() || err.is_timeout_closed() => unreachable!(),
|
||||||
Err(err) if err.is_timeout_closed() => log::debug!("[{addr}] connection timeout"),
|
Err(err) => log::warn!("[{addr}] connection establishing error: {err}"),
|
||||||
Err(err) => log::warn!("[{addr}] {err}"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -263,7 +275,7 @@ impl Connection {
|
|||||||
model: Model::<side::Server>::new(conn),
|
model: Model::<side::Server>::new(conn),
|
||||||
users,
|
users,
|
||||||
udp_relay_ipv6,
|
udp_relay_ipv6,
|
||||||
is_authed: IsAuthed::new(),
|
auth: Authenticated::new(),
|
||||||
task_negotiation_timeout,
|
task_negotiation_timeout,
|
||||||
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
|
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||||
udp_relay_mode: Arc::new(AtomicCell::new(None)),
|
udp_relay_mode: Arc::new(AtomicCell::new(None)),
|
||||||
@ -275,19 +287,6 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn accept(&self) -> Result<(), Error> {
|
|
||||||
tokio::select! {
|
|
||||||
res = self.inner.accept_uni() =>
|
|
||||||
tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())),
|
|
||||||
res = self.inner.accept_bi() =>
|
|
||||||
tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())),
|
|
||||||
res = self.inner.read_datagram() =>
|
|
||||||
tokio::spawn(self.clone().handle_datagram(res?)),
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
|
async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
|
||||||
let addr = self.inner.remote_address();
|
let addr = self.inner.remote_address();
|
||||||
log::debug!("[{addr}] incoming unidirectional stream");
|
log::debug!("[{addr}] incoming unidirectional stream");
|
||||||
@ -311,21 +310,21 @@ impl Connection {
|
|||||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||||
|
|
||||||
if let Task::Authenticate(auth) = &task {
|
if let Task::Authenticate(auth) = &task {
|
||||||
if self.is_authed() {
|
if self.auth.get().is_some() {
|
||||||
return Err(Error::DuplicatedAuth);
|
return Err(Error::DuplicatedAuth);
|
||||||
} else if self
|
} else if self
|
||||||
.users
|
.users
|
||||||
.get(&auth.uuid())
|
.get(&auth.uuid())
|
||||||
.map_or(false, |password| auth.validate(password))
|
.map_or(false, |password| auth.validate(password))
|
||||||
{
|
{
|
||||||
self.set_authed();
|
self.auth.set(auth.uuid());
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::AuthFailed(auth.uuid()));
|
return Err(Error::AuthFailed(auth.uuid()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
() = self.authed() => {}
|
() = self.auth.clone() => {}
|
||||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -339,14 +338,16 @@ impl Connection {
|
|||||||
};
|
};
|
||||||
|
|
||||||
match pre_process.await {
|
match pre_process.await {
|
||||||
Ok(Task::Authenticate(auth)) => log::info!("[{addr}] authenticated as {}", auth.uuid()),
|
Ok(Task::Authenticate(auth)) => {
|
||||||
|
log::info!("[{addr}] [{uuid}] [authenticate]", uuid = auth.uuid())
|
||||||
|
}
|
||||||
Ok(Task::Packet(pkt)) => {
|
Ok(Task::Packet(pkt)) => {
|
||||||
let assoc_id = pkt.assoc_id();
|
let assoc_id = pkt.assoc_id();
|
||||||
let pkt_id = pkt.pkt_id();
|
let pkt_id = pkt.pkt_id();
|
||||||
let frag_id = pkt.frag_id();
|
let frag_id = pkt.frag_id();
|
||||||
let frag_total = pkt.frag_total();
|
let frag_total = pkt.frag_total();
|
||||||
log::info!(
|
log::info!(
|
||||||
"[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}:{frag_total}]"
|
"[{addr}] [packet-from-quic] [{assoc_id}] [{pkt_id}] [{frag_id}/{frag_total}]"
|
||||||
);
|
);
|
||||||
|
|
||||||
self.set_udp_relay_mode(UdpRelayMode::Quic);
|
self.set_udp_relay_mode(UdpRelayMode::Quic);
|
||||||
@ -396,7 +397,7 @@ impl Connection {
|
|||||||
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
.map_err(|_| Error::TaskNegotiationTimeout)??;
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
() = self.authed() => {}
|
() = self.auth.clone() => {}
|
||||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -429,7 +430,7 @@ impl Connection {
|
|||||||
let task = self.model.accept_datagram(dg)?;
|
let task = self.model.accept_datagram(dg)?;
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
() = self.authed() => {}
|
() = self.auth.clone() => {}
|
||||||
err = self.inner.closed() => return Err(Error::Connection(err)),
|
err = self.inner.closed() => return Err(Error::Connection(err)),
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -544,7 +545,7 @@ impl Connection {
|
|||||||
async fn handle_auth_timeout(self, timeout: Duration) {
|
async fn handle_auth_timeout(self, timeout: Duration) {
|
||||||
time::sleep(timeout).await;
|
time::sleep(timeout).await;
|
||||||
|
|
||||||
if !self.is_authed() {
|
if self.auth.get().is_none() {
|
||||||
let addr = self.inner.remote_address();
|
let addr = self.inner.remote_address();
|
||||||
log::warn!("[{addr}] authentication timeout");
|
log::warn!("[{addr}] authentication timeout");
|
||||||
self.close();
|
self.close();
|
||||||
@ -563,18 +564,6 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_authed(&self) {
|
|
||||||
self.is_authed.set_authed();
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_authed(&self) -> bool {
|
|
||||||
self.is_authed.is_authed()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn authed(&self) -> IsAuthed {
|
|
||||||
self.is_authed.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_udp_relay_mode(&self, mode: UdpRelayMode) {
|
fn set_udp_relay_mode(&self, mode: UdpRelayMode) {
|
||||||
self.udp_relay_mode.store(Some(mode));
|
self.udp_relay_mode.store(Some(mode));
|
||||||
}
|
}
|
||||||
@ -745,40 +734,42 @@ impl Drop for UdpSession {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
struct IsAuthed {
|
struct Authenticated(Arc<AuthenticatedInner>);
|
||||||
is_authed: Arc<AtomicBool>,
|
|
||||||
broadcast: Arc<Mutex<Vec<Waker>>>,
|
struct AuthenticatedInner {
|
||||||
|
uuid: AtomicCell<Option<Uuid>>,
|
||||||
|
broadcast: Mutex<Vec<Waker>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IsAuthed {
|
impl Authenticated {
|
||||||
fn new() -> Self {
|
fn new() -> Self {
|
||||||
Self {
|
Self(Arc::new(AuthenticatedInner {
|
||||||
is_authed: Arc::new(AtomicBool::new(false)),
|
uuid: AtomicCell::new(None),
|
||||||
broadcast: Arc::new(Mutex::new(Vec::new())),
|
broadcast: Mutex::new(Vec::new()),
|
||||||
}
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_authed(&self) {
|
fn set(&self, uuid: Uuid) {
|
||||||
self.is_authed.store(true, Ordering::Release);
|
self.0.uuid.store(Some(uuid));
|
||||||
|
|
||||||
for waker in self.broadcast.lock().drain(..) {
|
for waker in self.0.broadcast.lock().drain(..) {
|
||||||
waker.wake();
|
waker.wake();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_authed(&self) -> bool {
|
fn get(&self) -> Option<Uuid> {
|
||||||
self.is_authed.load(Ordering::Relaxed)
|
self.0.uuid.load()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Future for IsAuthed {
|
impl Future for Authenticated {
|
||||||
type Output = ();
|
type Output = ();
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||||
if self.is_authed.load(Ordering::Relaxed) {
|
if self.get().is_some() {
|
||||||
Poll::Ready(())
|
Poll::Ready(())
|
||||||
} else {
|
} else {
|
||||||
self.broadcast.lock().push(cx.waker().clone());
|
self.0.broadcast.lock().push(cx.waker().clone());
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user