1
0

fix udp packet relay error

This commit is contained in:
EAimTY 2022-07-01 19:56:08 +09:00
parent 1147c5a7df
commit 9267ec45cf
4 changed files with 85 additions and 76 deletions

View File

@ -1,7 +1,7 @@
use super::{ use super::{
incoming::{self, Sender as IncomingSender}, incoming::{self, Sender as IncomingSender},
request::Wait as WaitRequest, request::Wait as WaitRequest,
stream::{BiStream, IncomingUniStreams, Register as StreamRegister, SendStream}, stream::{BiStream, IncomingUniStreams, RecvStream, Register as StreamRegister, SendStream},
Address, ServerAddr, UdpRelayMode, Address, ServerAddr, UdpRelayMode,
}; };
use bytes::Bytes; use bytes::Bytes;
@ -22,7 +22,6 @@ use std::{
time::Duration, time::Duration,
}; };
use tokio::{ use tokio::{
io::AsyncWriteExt,
net, net,
sync::{mpsc::Sender as MpscSender, Mutex as AsyncMutex, OwnedMutexGuard}, sync::{mpsc::Sender as MpscSender, Mutex as AsyncMutex, OwnedMutexGuard},
time, time,
@ -190,7 +189,7 @@ impl Connection {
let mut send = conn.get_send_stream().await?; let mut send = conn.get_send_stream().await?;
let cmd = Command::new_authenticate(token_digest); let cmd = Command::new_authenticate(token_digest);
cmd.write_to(&mut send).await?; cmd.write_to(&mut send).await?;
let _ = send.shutdown().await; send.finish().await?;
Ok(()) Ok(())
} }
@ -205,7 +204,7 @@ impl Connection {
let mut send = conn.get_send_stream().await?; let mut send = conn.get_send_stream().await?;
let cmd = Command::new_heartbeat(); let cmd = Command::new_heartbeat();
cmd.write_to(&mut send).await?; cmd.write_to(&mut send).await?;
let _ = send.shutdown().await; send.finish().await?;
Ok(()) Ok(())
} }
@ -233,7 +232,11 @@ impl Connection {
pub async fn get_bi_stream(&self) -> Result<BiStream> { pub async fn get_bi_stream(&self) -> Result<BiStream> {
let (send, recv) = self.controller.open_bi().await?; let (send, recv) = self.controller.open_bi().await?;
let reg = (*self.stream_reg).clone(); // clone inner, not itself let reg = (*self.stream_reg).clone(); // clone inner, not itself
Ok(BiStream::new(send, recv, reg))
Ok(BiStream::new(
SendStream::new(send, reg.clone()),
RecvStream::new(recv, reg),
))
} }
pub fn send_datagram(&self, data: Bytes) -> Result<()> { pub fn send_datagram(&self, data: Bytes) -> Result<()> {

View File

@ -4,7 +4,7 @@ use quinn::{
SendStream as QuinnSendStream, SendStream as QuinnSendStream,
}; };
use std::{ use std::{
io::{IoSlice, Result}, io::{Error, IoSlice, Result},
pin::Pin, pin::Pin,
result::Result as StdResult, result::Result as StdResult,
sync::{Arc, Weak}, sync::{Arc, Weak},
@ -12,69 +12,6 @@ use std::{
}; };
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
pub struct BiStream {
send: QuinnSendStream,
recv: QuinnRecvStream,
_reg: Register,
}
impl BiStream {
#[inline]
pub fn new(send: QuinnSendStream, recv: QuinnRecvStream, reg: Register) -> Self {
Self {
send,
recv,
_reg: reg,
}
}
}
impl AsyncRead for BiStream {
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.recv).poll_read(cx, buf)
}
}
impl AsyncWrite for BiStream {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.send.is_write_vectored()
}
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_flush(cx)
}
#[inline]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_shutdown(cx)
}
}
pub struct SendStream { pub struct SendStream {
send: QuinnSendStream, send: QuinnSendStream,
_reg: Register, _reg: Register,
@ -85,6 +22,11 @@ impl SendStream {
pub fn new(send: QuinnSendStream, reg: Register) -> Self { pub fn new(send: QuinnSendStream, reg: Register) -> Self {
Self { send, _reg: reg } Self { send, _reg: reg }
} }
#[inline]
pub async fn finish(&mut self) -> Result<()> {
self.send.finish().await.map_err(Error::from)
}
} }
impl AsyncWrite for SendStream { impl AsyncWrite for SendStream {
@ -145,6 +87,69 @@ impl AsyncRead for RecvStream {
} }
} }
pub struct BiStream {
send: SendStream,
recv: RecvStream,
}
impl BiStream {
#[inline]
pub fn new(send: SendStream, recv: RecvStream) -> Self {
Self { send, recv }
}
#[inline]
pub async fn finish(&mut self) -> Result<()> {
self.send.finish().await
}
}
impl AsyncRead for BiStream {
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.recv).poll_read(cx, buf)
}
}
impl AsyncWrite for BiStream {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write(cx, buf)
}
#[inline]
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
Pin::new(&mut self.send).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.send.is_write_vectored()
}
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_flush(cx)
}
#[inline]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.send).poll_shutdown(cx)
}
}
pub struct IncomingUniStreams { pub struct IncomingUniStreams {
incoming: QuinnIncomingUniStreams, incoming: QuinnIncomingUniStreams,
reg: Registry, reg: Registry,

View File

@ -15,7 +15,7 @@ impl Connection {
let resp = match TuicCommand::read_from(&mut stream).await { let resp = match TuicCommand::read_from(&mut stream).await {
Ok(resp) => resp, Ok(resp) => resp,
Err(err) => { Err(err) => {
let _ = stream.shutdown().await; stream.finish().await?;
return Err(err); return Err(err);
} }
}; };
@ -23,7 +23,7 @@ impl Connection {
if let TuicCommand::Response(true) = resp { if let TuicCommand::Response(true) = resp {
Ok(Some(stream)) Ok(Some(stream))
} else { } else {
let _ = stream.shutdown().await; stream.finish().await?;
Ok(None) Ok(None)
} }
} }
@ -67,7 +67,8 @@ impl Connection {
UdpRelayMode::Quic(()) => { UdpRelayMode::Quic(()) => {
let mut send = conn.get_send_stream().await?; let mut send = conn.get_send_stream().await?;
cmd.write_to(&mut send).await?; cmd.write_to(&mut send).await?;
let _ = send.shutdown().await; send.write_all(&pkt).await?;
send.finish().await?;
} }
} }
@ -107,7 +108,7 @@ impl Connection {
let mut send = conn.get_send_stream().await?; let mut send = conn.get_send_stream().await?;
cmd.write_to(&mut send).await?; cmd.write_to(&mut send).await?;
let _ = send.shutdown().await; send.finish().await?;
Ok(()) Ok(())
} }

View File

@ -13,7 +13,7 @@ use std::{
}; };
use thiserror::Error; use thiserror::Error;
use tokio::{ use tokio::{
io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}, io::{self, AsyncRead, AsyncWrite, ReadBuf},
net::{self, TcpStream}, net::{self, TcpStream},
}; };
use tuic_protocol::{Address, Command}; use tuic_protocol::{Address, Command};
@ -47,7 +47,7 @@ pub async fn connect(
} else { } else {
let resp = Command::new_response(false); let resp = Command::new_response(false);
resp.write_to(&mut send).await?; resp.write_to(&mut send).await?;
let _ = send.shutdown().await; send.finish().await?;
}; };
Ok(()) Ok(())
@ -92,7 +92,7 @@ pub async fn packet_to_uni_stream(
let cmd = Command::new_packet(assoc_id, pkt.len() as u16, addr); let cmd = Command::new_packet(assoc_id, pkt.len() as u16, addr);
cmd.write_to(&mut stream).await?; cmd.write_to(&mut stream).await?;
stream.write_all(&pkt).await?; stream.write_all(&pkt).await?;
let _ = stream.shutdown().await; stream.finish().await?;
Ok(()) Ok(())
} }