1
0

auto increase max concurrent stream count

This commit is contained in:
EAimTY 2023-02-04 16:17:54 +09:00
parent 8306a7f061
commit 7781f5c62a

View File

@ -3,6 +3,7 @@ use bytes::Bytes;
use crossbeam_utils::atomic::AtomicCell; use crossbeam_utils::atomic::AtomicCell;
use parking_lot::Mutex; use parking_lot::Mutex;
use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt};
use register_count::{Counter, Register};
use std::{ use std::{
collections::{hash_map::Entry, HashMap}, collections::{hash_map::Entry, HashMap},
future::Future, future::Future,
@ -10,7 +11,7 @@ use std::{
net::{Ipv4Addr, Ipv6Addr, SocketAddr}, net::{Ipv4Addr, Ipv6Addr, SocketAddr},
pin::Pin, pin::Pin,
sync::{ sync::{
atomic::{AtomicBool, Ordering}, atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Arc,
}, },
task::{Context, Poll, Waker}, task::{Context, Poll, Waker},
@ -27,6 +28,8 @@ use tokio_util::compat::FuturesAsyncReadCompatExt;
use tuic::Address; use tuic::Address;
use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; use tuic_quinn::{side, Connect, Connection as Model, Packet, Task};
const DEFAULT_CONCURRENT_STREAMS: usize = 32;
pub struct Server { pub struct Server {
ep: Endpoint, ep: Endpoint,
token: Arc<[u8]>, token: Arc<[u8]>,
@ -61,6 +64,10 @@ struct Connection {
is_authed: IsAuthed, is_authed: IsAuthed,
udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>, udp_sessions: Arc<AsyncMutex<HashMap<u16, UdpSession>>>,
udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>, udp_relay_mode: Arc<AtomicCell<Option<UdpRelayMode>>>,
remote_uni_stream_cnt: Counter,
remote_bi_stream_cnt: Counter,
max_concurrent_uni_streams: Arc<AtomicUsize>,
max_concurrent_bi_streams: Arc<AtomicUsize>,
} }
impl Connection { impl Connection {
@ -111,20 +118,37 @@ impl Connection {
is_authed: IsAuthed::new(), is_authed: IsAuthed::new(),
udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())),
udp_relay_mode: Arc::new(AtomicCell::new(None)), udp_relay_mode: Arc::new(AtomicCell::new(None)),
remote_uni_stream_cnt: Counter::new(),
remote_bi_stream_cnt: Counter::new(),
max_concurrent_uni_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
max_concurrent_bi_streams: Arc::new(AtomicUsize::new(DEFAULT_CONCURRENT_STREAMS)),
}) })
} }
async fn accept(&self) -> Result<(), Error> { async fn accept(&self) -> Result<(), Error> {
tokio::select! { tokio::select! {
res = self.inner.accept_uni() => tokio::spawn(self.clone().handle_uni_stream(res?)), res = self.inner.accept_uni() =>
res = self.inner.accept_bi() => tokio::spawn(self.clone().handle_bi_stream(res?)), tokio::spawn(self.clone().handle_uni_stream(res?, self.remote_uni_stream_cnt.reg())),
res = self.inner.read_datagram() => tokio::spawn(self.clone().handle_datagram(res?)), res = self.inner.accept_bi() =>
tokio::spawn(self.clone().handle_bi_stream(res?, self.remote_bi_stream_cnt.reg())),
res = self.inner.read_datagram() =>
tokio::spawn(self.clone().handle_datagram(res?)),
}; };
Ok(()) Ok(())
} }
async fn handle_uni_stream(self, recv: RecvStream) { async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) {
let max = self.max_concurrent_uni_streams.load(Ordering::Relaxed);
if self.remote_uni_stream_cnt.count() == max {
self.max_concurrent_uni_streams
.store(max * 2, Ordering::Relaxed);
self.inner
.set_max_concurrent_uni_streams(VarInt::from((max * 2) as u32));
}
async fn pre_process(conn: &Connection, recv: RecvStream) -> Result<Task, Error> { async fn pre_process(conn: &Connection, recv: RecvStream) -> Result<Task, Error> {
let task = conn.model.accept_uni_stream(recv).await?; let task = conn.model.accept_uni_stream(recv).await?;
@ -133,6 +157,7 @@ impl Connection {
return Err(Error::DuplicatedAuth); return Err(Error::DuplicatedAuth);
} else { } else {
let mut buf = [0; 32]; let mut buf = [0; 32];
conn.inner conn.inner
.export_keying_material(&mut buf, &conn.token, &conn.token) .export_keying_material(&mut buf, &conn.token, &conn.token)
.map_err(|_| Error::ExportKeyingMaterial)?; .map_err(|_| Error::ExportKeyingMaterial)?;
@ -180,7 +205,17 @@ impl Connection {
} }
} }
async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream)) { async fn handle_bi_stream(self, (send, recv): (SendStream, RecvStream), _reg: Register) {
let max = self.max_concurrent_bi_streams.load(Ordering::Relaxed);
if self.remote_bi_stream_cnt.count() == max {
self.max_concurrent_bi_streams
.store(max * 2, Ordering::Relaxed);
self.inner
.set_max_concurrent_bi_streams(VarInt::from((max * 2) as u32));
}
async fn pre_process( async fn pre_process(
conn: &Connection, conn: &Connection,
send: SendStream, send: SendStream,
@ -432,7 +467,7 @@ impl UdpSession {
socket_v4: &UdpSocket, socket_v4: &UdpSocket,
socket_v6: Option<&UdpSocket>, socket_v6: Option<&UdpSocket>,
) -> Result<(Bytes, SocketAddr), IoError> { ) -> Result<(Bytes, SocketAddr), IoError> {
async fn read_packet(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> { async fn read_pkt(socket: &UdpSocket) -> Result<(Bytes, SocketAddr), IoError> {
let mut buf = vec![0u8; 65535]; let mut buf = vec![0u8; 65535];
let (n, addr) = socket.recv_from(&mut buf).await?; let (n, addr) = socket.recv_from(&mut buf).await?;
buf.truncate(n); buf.truncate(n);
@ -441,11 +476,11 @@ impl UdpSession {
if let Some(socket_v6) = socket_v6 { if let Some(socket_v6) = socket_v6 {
tokio::select! { tokio::select! {
res = read_packet(socket_v4) => res, res = read_pkt(socket_v4) => res,
res = read_packet(socket_v6) => res, res = read_pkt(socket_v6) => res,
} }
} else { } else {
read_packet(socket_v4).await read_pkt(socket_v4).await
} }
} }
} }