1
0

fix udp packet relay error

This commit is contained in:
EAimTY 2022-07-01 19:56:08 +09:00
parent 1147c5a7df
commit 9267ec45cf
4 changed files with 85 additions and 76 deletions

View File

@ -1,7 +1,7 @@
use super::{
incoming::{self, Sender as IncomingSender},
request::Wait as WaitRequest,
stream::{BiStream, IncomingUniStreams, Register as StreamRegister, SendStream},
stream::{BiStream, IncomingUniStreams, RecvStream, Register as StreamRegister, SendStream},
Address, ServerAddr, UdpRelayMode,
};
use bytes::Bytes;
@ -22,7 +22,6 @@ use std::{
time::Duration,
};
use tokio::{
io::AsyncWriteExt,
net,
sync::{mpsc::Sender as MpscSender, Mutex as AsyncMutex, OwnedMutexGuard},
time,
@ -190,7 +189,7 @@ impl Connection {
let mut send = conn.get_send_stream().await?;
let cmd = Command::new_authenticate(token_digest);
cmd.write_to(&mut send).await?;
let _ = send.shutdown().await;
send.finish().await?;
Ok(())
}
@ -205,7 +204,7 @@ impl Connection {
let mut send = conn.get_send_stream().await?;
let cmd = Command::new_heartbeat();
cmd.write_to(&mut send).await?;
let _ = send.shutdown().await;
send.finish().await?;
Ok(())
}
@ -233,7 +232,11 @@ impl Connection {
pub async fn get_bi_stream(&self) -> Result<BiStream> {
let (send, recv) = self.controller.open_bi().await?;
let reg = (*self.stream_reg).clone(); // clone inner, not itself
Ok(BiStream::new(send, recv, reg))
Ok(BiStream::new(
SendStream::new(send, reg.clone()),
RecvStream::new(recv, reg),
))
}
pub fn send_datagram(&self, data: Bytes) -> Result<()> {

View File

@ -4,7 +4,7 @@ use quinn::{
SendStream as QuinnSendStream,
};
use std::{
io::{IoSlice, Result},
io::{Error, IoSlice, Result},
pin::Pin,
result::Result as StdResult,
sync::{Arc, Weak},
@ -12,69 +12,6 @@ use std::{
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub struct BiStream {
send: QuinnSendStream,
recv: QuinnRecvStream,
_reg: Register,
}
impl BiStream {
#[inline]
pub fn new(send: QuinnSendStream, recv: QuinnRecvStream, reg: Register) -> Self {
Self {
send,
recv,
_reg: reg,
}
}
}
impl AsyncRead for BiStream {
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.recv).poll_read(cx, buf)
}
}
impl AsyncWrite for BiStream {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.send.is_write_vectored()
}
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_flush(cx)
}
#[inline]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_shutdown(cx)
}
}
pub struct SendStream {
send: QuinnSendStream,
_reg: Register,
@ -85,6 +22,11 @@ impl SendStream {
pub fn new(send: QuinnSendStream, reg: Register) -> Self {
Self { send, _reg: reg }
}
#[inline]
pub async fn finish(&mut self) -> Result<()> {
self.send.finish().await.map_err(Error::from)
}
}
impl AsyncWrite for SendStream {
@ -145,6 +87,69 @@ impl AsyncRead for RecvStream {
}
}
pub struct BiStream {
send: SendStream,
recv: RecvStream,
}
impl BiStream {
#[inline]
pub fn new(send: SendStream, recv: RecvStream) -> Self {
Self { send, recv }
}
#[inline]
pub async fn finish(&mut self) -> Result<()> {
self.send.finish().await
}
}
impl AsyncRead for BiStream {
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.recv).poll_read(cx, buf)
}
}
impl AsyncWrite for BiStream {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.send.is_write_vectored()
}
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_flush(cx)
}
#[inline]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_shutdown(cx)
}
}
pub struct IncomingUniStreams {
incoming: QuinnIncomingUniStreams,
reg: Registry,

View File

@ -15,7 +15,7 @@ impl Connection {
let resp = match TuicCommand::read_from(&mut stream).await {
Ok(resp) => resp,
Err(err) => {
let _ = stream.shutdown().await;
stream.finish().await?;
return Err(err);
}
};
@ -23,7 +23,7 @@ impl Connection {
if let TuicCommand::Response(true) = resp {
Ok(Some(stream))
} else {
let _ = stream.shutdown().await;
stream.finish().await?;
Ok(None)
}
}
@ -67,7 +67,8 @@ impl Connection {
UdpRelayMode::Quic(()) => {
let mut send = conn.get_send_stream().await?;
cmd.write_to(&mut send).await?;
let _ = send.shutdown().await;
send.write_all(&pkt).await?;
send.finish().await?;
}
}
@ -107,7 +108,7 @@ impl Connection {
let mut send = conn.get_send_stream().await?;
cmd.write_to(&mut send).await?;
let _ = send.shutdown().await;
send.finish().await?;
Ok(())
}

View File

@ -13,7 +13,7 @@ use std::{
};
use thiserror::Error;
use tokio::{
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
io::{self, AsyncRead, AsyncWrite, ReadBuf},
net::{self, TcpStream},
};
use tuic_protocol::{Address, Command};
@ -47,7 +47,7 @@ pub async fn connect(
} else {
let resp = Command::new_response(false);
resp.write_to(&mut send).await?;
let _ = send.shutdown().await;
send.finish().await?;
};
Ok(())
@ -92,7 +92,7 @@ pub async fn packet_to_uni_stream(
let cmd = Command::new_packet(assoc_id, pkt.len() as u16, addr);
cmd.write_to(&mut stream).await?;
stream.write_all(&pkt).await?;
let _ = stream.shutdown().await;
stream.finish().await?;
Ok(())
}