diff --git a/client/src/relay/connection.rs b/client/src/relay/connection.rs index 75eb25e..850405b 100644 --- a/client/src/relay/connection.rs +++ b/client/src/relay/connection.rs @@ -1,7 +1,7 @@ use super::{ incoming::{self, Sender as IncomingSender}, request::Wait as WaitRequest, - stream::{BiStream, IncomingUniStreams, Register as StreamRegister, SendStream}, + stream::{BiStream, IncomingUniStreams, RecvStream, Register as StreamRegister, SendStream}, Address, ServerAddr, UdpRelayMode, }; use bytes::Bytes; @@ -22,7 +22,6 @@ use std::{ time::Duration, }; use tokio::{ - io::AsyncWriteExt, net, sync::{mpsc::Sender as MpscSender, Mutex as AsyncMutex, OwnedMutexGuard}, time, @@ -190,7 +189,7 @@ impl Connection { let mut send = conn.get_send_stream().await?; let cmd = Command::new_authenticate(token_digest); cmd.write_to(&mut send).await?; - let _ = send.shutdown().await; + send.finish().await?; Ok(()) } @@ -205,7 +204,7 @@ impl Connection { let mut send = conn.get_send_stream().await?; let cmd = Command::new_heartbeat(); cmd.write_to(&mut send).await?; - let _ = send.shutdown().await; + send.finish().await?; Ok(()) } @@ -233,7 +232,11 @@ impl Connection { pub async fn get_bi_stream(&self) -> Result { let (send, recv) = self.controller.open_bi().await?; let reg = (*self.stream_reg).clone(); // clone inner, not itself - Ok(BiStream::new(send, recv, reg)) + + Ok(BiStream::new( + SendStream::new(send, reg.clone()), + RecvStream::new(recv, reg), + )) } pub fn send_datagram(&self, data: Bytes) -> Result<()> { diff --git a/client/src/relay/stream.rs b/client/src/relay/stream.rs index 97d384a..535c71c 100644 --- a/client/src/relay/stream.rs +++ b/client/src/relay/stream.rs @@ -4,7 +4,7 @@ use quinn::{ SendStream as QuinnSendStream, }; use std::{ - io::{IoSlice, Result}, + io::{Error, IoSlice, Result}, pin::Pin, result::Result as StdResult, sync::{Arc, Weak}, @@ -12,69 +12,6 @@ use std::{ }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -pub struct BiStream { - send: QuinnSendStream, - recv: QuinnRecvStream, - _reg: Register, -} - -impl BiStream { - #[inline] - pub fn new(send: QuinnSendStream, recv: QuinnRecvStream, reg: Register) -> Self { - Self { - send, - recv, - _reg: reg, - } - } -} - -impl AsyncRead for BiStream { - #[inline] - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.recv).poll_read(cx, buf) - } -} - -impl AsyncWrite for BiStream { - #[inline] - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.send).poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.send).poll_write_vectored(cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.send.is_write_vectored() - } - - #[inline] - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send).poll_flush(cx) - } - - #[inline] - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.send).poll_shutdown(cx) - } -} - pub struct SendStream { send: QuinnSendStream, _reg: Register, @@ -85,6 +22,11 @@ impl SendStream { pub fn new(send: QuinnSendStream, reg: Register) -> Self { Self { send, _reg: reg } } + + #[inline] + pub async fn finish(&mut self) -> Result<()> { + self.send.finish().await.map_err(Error::from) + } } impl AsyncWrite for SendStream { @@ -145,6 +87,69 @@ impl AsyncRead for RecvStream { } } +pub struct BiStream { + send: SendStream, + recv: RecvStream, +} + +impl BiStream { + #[inline] + pub fn new(send: SendStream, recv: RecvStream) -> Self { + Self { send, recv } + } + + #[inline] + pub async fn finish(&mut self) -> Result<()> { + self.send.finish().await + } +} + +impl AsyncRead for BiStream { + #[inline] + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.recv).poll_read(cx, buf) + } +} + +impl AsyncWrite for BiStream { + #[inline] + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.send).poll_write(cx, buf) + } + + #[inline] + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.send).poll_write_vectored(cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.send.is_write_vectored() + } + + #[inline] + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.send).poll_flush(cx) + } + + #[inline] + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.send).poll_shutdown(cx) + } +} + pub struct IncomingUniStreams { incoming: QuinnIncomingUniStreams, reg: Registry, diff --git a/client/src/relay/task.rs b/client/src/relay/task.rs index eac934d..2e95d1d 100644 --- a/client/src/relay/task.rs +++ b/client/src/relay/task.rs @@ -15,7 +15,7 @@ impl Connection { let resp = match TuicCommand::read_from(&mut stream).await { Ok(resp) => resp, Err(err) => { - let _ = stream.shutdown().await; + stream.finish().await?; return Err(err); } }; @@ -23,7 +23,7 @@ impl Connection { if let TuicCommand::Response(true) = resp { Ok(Some(stream)) } else { - let _ = stream.shutdown().await; + stream.finish().await?; Ok(None) } } @@ -67,7 +67,8 @@ impl Connection { UdpRelayMode::Quic(()) => { let mut send = conn.get_send_stream().await?; cmd.write_to(&mut send).await?; - let _ = send.shutdown().await; + send.write_all(&pkt).await?; + send.finish().await?; } } @@ -107,7 +108,7 @@ impl Connection { let mut send = conn.get_send_stream().await?; cmd.write_to(&mut send).await?; - let _ = send.shutdown().await; + send.finish().await?; Ok(()) } diff --git a/server/src/connection/task.rs b/server/src/connection/task.rs index b789700..2affc7b 100644 --- a/server/src/connection/task.rs +++ b/server/src/connection/task.rs @@ -13,7 +13,7 @@ use std::{ }; use thiserror::Error; use tokio::{ - io::{self, AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}, + io::{self, AsyncRead, AsyncWrite, ReadBuf}, net::{self, TcpStream}, }; use tuic_protocol::{Address, Command}; @@ -47,7 +47,7 @@ pub async fn connect( } else { let resp = Command::new_response(false); resp.write_to(&mut send).await?; - let _ = send.shutdown().await; + send.finish().await?; }; Ok(()) @@ -92,7 +92,7 @@ pub async fn packet_to_uni_stream( let cmd = Command::new_packet(assoc_id, pkt.len() as u16, addr); cmd.write_to(&mut stream).await?; stream.write_all(&pkt).await?; - let _ = stream.shutdown().await; + stream.finish().await?; Ok(()) }