diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index baf176c..70c6f8b 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -3,6 +3,7 @@ use bytes::Bytes; use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; +use register_count::{Counter, Register}; use std::{ collections::{hash_map::Entry, HashMap}, future::Future, @@ -10,7 +11,7 @@ use std::{ net::{Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{ - atomic::{AtomicBool, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, }, task::{Context, Poll, Waker}, @@ -27,6 +28,8 @@ use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address; use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; +const DEFAULT_CONCURRENT_STREAMS: usize = 32; + pub struct Server { ep: Endpoint, token: Arc<[u8]>, @@ -61,6 +64,10 @@ struct Connection { is_authed: IsAuthed, udp_sessions: Arc>>, udp_relay_mode: Arc>>, + remote_uni_stream_cnt: Counter, + remote_bi_stream_cnt: Counter, + max_concurrent_uni_streams: Arc, + max_concurrent_bi_streams: Arc, } impl Connection { @@ -111,20 +118,37 @@ impl Connection { is_authed: IsAuthed::new(), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), + remote_uni_stream_cnt: Counter::new(), + remote_bi_stream_cnt: Counter::new(), + max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), + max_concurrent_bi_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)), }) } async fn accept(&self) -> Result<(), Error> { tokio::select! { - res = self.inner.accept_uni() => tokio::spawn(self.clone().handle_uni_stream(res?)), - res = self.inner.accept_bi() => tokio::spawn(self.clone().handle_bi_stream(res?)), - res = self.inner.read_datagram() => tokio::spawn(self.clone().handle_datagram(res?)), + res = self.inner.accept_uni() => + tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())), + res = self.inner.accept_bi() => + tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())), + res = self.inner.read_datagram() => + tokio::spawn(self.clone().handle_datagram(res?)), }; Ok(()) } - async fn handle_uni_stream(self, recv: RecvStream) { + async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { + let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed); + + if self.remote_uni_stream_cnt.count() == max { + self.max_concurrent_uni_streams + .store(max * 2, Ordering::Relaxed); + + self.inner + .set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32)); + } + async fn pre_process(conn: &Connection, recv: RecvStream) -> Result { let task = conn.model.accept_uni_stream(recv).await?; @@ -133,6 +157,7 @@ impl Connection { return Err(Error::DuplicatedAuth); } else { let mut buf = [0; 32]; + conn.inner .export_keying_material(&mut buf, &conn.token, &conn.token) .map_err(|_| Error::ExportKeyingMaterial)?; @@ -180,7 +205,17 @@ impl Connection { } } - async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream)) { + async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream), _reg: Register) { + let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed); + + if self.remote_bi_stream_cnt.count() == max { + self.max_concurrent_bi_streams + .store(max * 2, Ordering::Relaxed); + + self.inner + .set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32)); + } + async fn pre_process( conn: &Connection, send: SendStream, @@ -432,7 +467,7 @@ impl UdpSession { socket_v4: &UdpSocket, socket_v6: Option<&UdpSocket>, ) -> Result<(Bytes, SocketAddr), IoError> { - async fn read_packet(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { + async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { let mut buf = vec![0u8; 65535]; let (n, addr) = socket.recv_from(&mut buf).await?; buf.truncate(n); @@ -441,11 +476,11 @@ impl UdpSession { if let Some(socket_v6) = socket_v6 { tokio::select! { - res = read_packet(socket_v4) => res, - res = read_packet(socket_v6) => res, + res = read_pkt(socket_v4) => res, + res = read_pkt(socket_v6) => res, } } else { - read_packet(socket_v4).await + read_pkt(socket_v4).await } } }