From 10cee5276e8a0b429faf2a1cab9acadbb1aad0bc Mon Sep 17 00:00:00 2001 From: EAimTY Date: Thu, 26 Jan 2023 20:47:20 +0900 Subject: [PATCH] adding uni stream parsing methods --- tuic-quinn/src/lib.rs | 135 +++++++++++++++++++++++++++++------ tuic-quinn/src/unmarshal.rs | 19 ++++- tuic/src/model/mod.rs | 31 ++++++-- tuic/src/model/packet.rs | 25 +++++++ tuic/src/protocol/connect.rs | 4 +- tuic/src/protocol/mod.rs | 2 +- 6 files changed, 183 insertions(+), 33 deletions(-) diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index e9a5249..b247b11 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -1,6 +1,10 @@ -use self::{marshal::Marshal, side::Side}; +use self::{ + marshal::Marshal, + side::Side, + unmarshal::{Unmarshal, UnmarshalError}, +}; use bytes::Bytes; -use futures_util::{io::Cursor, AsyncRead, AsyncWrite, AsyncWriteExt}; +use futures_util::{io::Cursor, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use quinn::{ Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream, }; @@ -13,7 +17,8 @@ use thiserror::Error; use tuic::{ model::{ side::{Rx, Tx}, - Connect as ConnectModel, Connection as ConnectionModel, + AssembleError, Connect as ConnectModel, Connection as ConnectionModel, + Packet as PacketModel, }, protocol::{Address, Header}, }; @@ -33,26 +38,11 @@ pub mod side { pub struct Connection<'conn, Side> { conn: &'conn QuinnConnection, - model: ConnectionModel, + model: ConnectionModel>, _marker: Side, } -impl<'conn> Connection<'conn, side::Client> { - pub fn new(conn: &'conn QuinnConnection) -> Self { - Self { - conn, - model: ConnectionModel::new(), - _marker: side::Client, - } - } - - pub async fn connect(&self, addr: Address) -> Result { - let (mut send, recv) = self.conn.open_bi().await?; - let model = self.model.send_connect(addr); - model.header().marshal(&mut send).await?; - Ok(Connect::new(Side::Client(model), send, recv)) - } - +impl<'conn, Side> Connection<'conn, Side> { pub async fn packet_native( &self, pkt: impl AsRef<[u8]>, @@ -92,6 +82,92 @@ impl<'conn> Connection<'conn, side::Client> { Ok(()) } + + async fn accept_packet_quic( + &self, + model: PacketModel>, + mut recv: &mut RecvStream, + ) -> Result, Address, u16)>, Error> { + let mut buf = vec![0; *model.size() as usize]; + AsyncReadExt::read_exact(&mut recv, &mut buf).await?; + let mut asm = Vec::new(); + + Ok(model + .assemble(buf)? + .map(|pkt| pkt.assemble(&mut asm)) + .map(|(addr, assoc_id)| (asm, addr, assoc_id))) + } +} + +impl<'conn> Connection<'conn, side::Client> { + pub fn new(conn: &'conn QuinnConnection) -> Self { + Self { + conn, + model: ConnectionModel::new(), + _marker: side::Client, + } + } + + pub async fn connect(&self, addr: Address) -> Result { + let (mut send, recv) = self.conn.open_bi().await?; + let model = self.model.send_connect(addr); + model.header().marshal(&mut send).await?; + Ok(Connect::new(Side::Client(model), send, recv)) + } + + pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result { + match Header::unmarshal(&mut recv).await? { + Header::Authenticate(_) => Err(Error::BadCommand("authenticate")), + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + Ok(Task::Packet( + self.accept_packet_quic(model, &mut recv).await?, + )) + } + Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), + Header::Heartbeat(hb) => { + let _ = self.model.recv_heartbeat(hb); + Ok(Task::Heartbeat) + } + _ => unreachable!(), + } + } +} + +impl<'conn> Connection<'conn, side::Server> { + pub fn new(conn: &'conn QuinnConnection) -> Self { + Self { + conn, + model: ConnectionModel::new(), + _marker: side::Server, + } + } + + pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result { + match Header::unmarshal(&mut recv).await? { + Header::Authenticate(auth) => { + let model = self.model.recv_authenticate(auth); + Ok(Task::Authenticate(*model.token())) + } + Header::Connect(_) => Err(Error::BadCommand("connect")), + Header::Packet(pkt) => { + let model = self.model.recv_packet(pkt); + Ok(Task::Packet( + self.accept_packet_quic(model, &mut recv).await?, + )) + } + Header::Dissociate(dissoc) => { + let _ = self.model.recv_dissociate(dissoc); + Ok(Task::Dissociate) + } + Header::Heartbeat(hb) => { + let _ = self.model.recv_heartbeat(hb); + Ok(Task::Heartbeat) + } + _ => unreachable!(), + } + } } pub struct Connect { @@ -112,8 +188,8 @@ impl Connect { pub fn addr(&self) -> &Address { match &self.model { Side::Client(model) => { - let Header::Connect(connect) = model.header() else { unreachable!() }; - &connect.addr() + let Header::Connect(conn) = model.header() else { unreachable!() }; + &conn.addr() } Side::Server(model) => model.addr(), } @@ -148,6 +224,15 @@ impl AsyncWrite for Connect { } } +#[non_exhaustive] +pub enum Task { + Authenticate([u8; 8]), + Connect(Connect), + Packet(Option<(Vec, Address, u16)>), + Dissociate, + Heartbeat, +} + #[derive(Debug, Error)] pub enum Error { #[error(transparent)] @@ -156,4 +241,10 @@ pub enum Error { Connection(#[from] ConnectionError), #[error(transparent)] SendDatagram(#[from] SendDatagramError), + #[error(transparent)] + Unmarshal(#[from] UnmarshalError), + #[error(transparent)] + Assemble(#[from] AssembleError), + #[error("{0}")] + BadCommand(&'static str), } diff --git a/tuic-quinn/src/unmarshal.rs b/tuic-quinn/src/unmarshal.rs index c7e6e11..1f1c9b7 100644 --- a/tuic-quinn/src/unmarshal.rs +++ b/tuic-quinn/src/unmarshal.rs @@ -1,7 +1,22 @@ use async_trait::async_trait; use futures_util::AsyncRead; +use thiserror::Error; +use tuic::protocol::Header; #[async_trait] -trait Unmarshal { - fn unmarshal(&self, s: &mut impl AsyncRead) -> Result<(), ()>; +pub(super) trait Unmarshal +where + Self: Sized, +{ + async fn unmarshal(s: &mut impl AsyncRead) -> Result; } + +#[async_trait] +impl Unmarshal for Header { + async fn unmarshal(s: &mut impl AsyncRead) -> Result { + todo!() + } +} + +#[derive(Debug, Error)] +pub enum UnmarshalError {} diff --git a/tuic/src/model/mod.rs b/tuic/src/model/mod.rs index f526e41..47b29a6 100644 --- a/tuic/src/model/mod.rs +++ b/tuic/src/model/mod.rs @@ -216,7 +216,7 @@ where self.sessions .entry(assoc_id) .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) - .insert(pkt_id, frag_total, frag_id, size, addr, data) + .insert(assoc_id, pkt_id, frag_total, frag_id, size, addr, data) } fn collect_garbage(&mut self, timeout: Duration) { @@ -273,6 +273,7 @@ where fn insert( &mut self, + assoc_id: u16, pkt_id: u16, frag_total: u8, frag_id: u8, @@ -284,7 +285,7 @@ where .pkt_buf .entry(pkt_id) .or_insert_with(|| PacketBuffer::new(frag_total)) - .insert(frag_total, frag_id, size, addr, data)?; + .insert(assoc_id, frag_total, frag_id, size, addr, data)?; if res.is_some() { self.pkt_buf.remove(&pkt_id); @@ -325,6 +326,7 @@ where fn insert( &mut self, + assoc_id: u16, frag_total: u8, frag_id: u8, size: u16, @@ -358,6 +360,7 @@ where Ok(Some(Assemblable::new( mem::take(&mut self.buf), self.addr.take(), + assoc_id, ))) } else { Ok(None) @@ -368,23 +371,28 @@ where pub struct Assemblable { buf: Vec>, addr: Address, + assoc_id: u16, } impl Assemblable where B: AsRef<[u8]>, { - fn new(buf: Vec>, addr: Address) -> Self { - Self { buf, addr } + fn new(buf: Vec>, addr: Address, assoc_id: u16) -> Self { + Self { + buf, + addr, + assoc_id, + } } - pub fn assemble(self, buf: &mut A) -> Address + pub fn assemble(self, buf: &mut A) -> (Address, u16) where A: Assembler, { let data = self.buf.into_iter().map(|b| b.unwrap()); buf.assemble(data); - self.addr + (self.addr, self.assoc_id) } } @@ -396,6 +404,17 @@ where fn assemble(&mut self, data: impl IntoIterator); } +impl Assembler for Vec +where + B: AsRef<[u8]>, +{ + fn assemble(&mut self, data: impl IntoIterator) { + for d in data { + self.extend_from_slice(d.as_ref()); + } + } +} + #[derive(Debug, Error)] pub enum AssembleError { #[error("invalid fragment size")] diff --git a/tuic/src/model/packet.rs b/tuic/src/model/packet.rs index 181fc7c..ea336e3 100644 --- a/tuic/src/model/packet.rs +++ b/tuic/src/model/packet.rs @@ -38,6 +38,16 @@ impl Packet { let Side::Tx(tx) = self.inner else { unreachable!() }; Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) } + + pub fn assoc_id(&self) -> &u16 { + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.assoc_id + } + + pub fn addr(&self) -> &Address { + let Side::Tx(tx) = &self.inner else { unreachable!() }; + &tx.addr + } } pub struct Rx { @@ -91,6 +101,21 @@ where data, ) } + + pub fn assoc_id(&self) -> &u16 { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.assoc_id + } + + pub fn addr(&self) -> &Address { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.addr + } + + pub fn size(&self) -> &u16 { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.size + } } pub struct Fragments<'a, P> diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index 1814558..20a4d96 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -33,7 +33,7 @@ impl Command for Connect { } impl From for (Address,) { - fn from(connect: Connect) -> Self { - (connect.addr,) + fn from(conn: Connect) -> Self { + (conn.addr,) } } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index 8368314..c890893 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -51,7 +51,7 @@ impl Header { pub fn len(&self) -> usize { 2 + match self { Self::Authenticate(auth) => auth.len(), - Self::Connect(connect) => connect.len(), + Self::Connect(conn) => conn.len(), Self::Packet(packet) => packet.len(), Self::Dissociate(dissociate) => dissociate.len(), Self::Heartbeat(heartbeat) => heartbeat.len(),