1
0

adding uni stream parsing methods

This commit is contained in:
EAimTY 2023-01-26 20:47:20 +09:00
parent 101d4427eb
commit 10cee5276e
6 changed files with 183 additions and 33 deletions

View File

@ -1,6 +1,10 @@
use self::{marshal::Marshal, side::Side}; use self::{
marshal::Marshal,
side::Side,
unmarshal::{Unmarshal, UnmarshalError},
};
use bytes::Bytes; use bytes::Bytes;
use futures_util::{io::Cursor, AsyncRead, AsyncWrite, AsyncWriteExt}; use futures_util::{io::Cursor, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use quinn::{ use quinn::{
Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream, Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream,
}; };
@ -13,7 +17,8 @@ use thiserror::Error;
use tuic::{ use tuic::{
model::{ model::{
side::{Rx, Tx}, side::{Rx, Tx},
Connect as ConnectModel, Connection as ConnectionModel, AssembleError, Connect as ConnectModel, Connection as ConnectionModel,
Packet as PacketModel,
}, },
protocol::{Address, Header}, protocol::{Address, Header},
}; };
@ -33,26 +38,11 @@ pub mod side {
pub struct Connection<'conn, Side> { pub struct Connection<'conn, Side> {
conn: &'conn QuinnConnection, conn: &'conn QuinnConnection,
model: ConnectionModel<Bytes>, model: ConnectionModel<Vec<u8>>,
_marker: Side, _marker: Side,
} }
impl<'conn> Connection<'conn, side::Client> { impl<'conn, Side> Connection<'conn, Side> {
pub fn new(conn: &'conn QuinnConnection) -> Self {
Self {
conn,
model: ConnectionModel::new(),
_marker: side::Client,
}
}
pub async fn connect(&self, addr: Address) -> Result<Connect, Error> {
let (mut send, recv) = self.conn.open_bi().await?;
let model = self.model.send_connect(addr);
model.header().marshal(&mut send).await?;
Ok(Connect::new(Side::Client(model), send, recv))
}
pub async fn packet_native( pub async fn packet_native(
&self, &self,
pkt: impl AsRef<[u8]>, pkt: impl AsRef<[u8]>,
@ -92,6 +82,92 @@ impl<'conn> Connection<'conn, side::Client> {
Ok(()) Ok(())
} }
async fn accept_packet_quic(
&self,
model: PacketModel<Rx, Vec<u8>>,
mut recv: &mut RecvStream,
) -> Result<Option<(Vec<u8>, Address, u16)>, Error> {
let mut buf = vec![0; *model.size() as usize];
AsyncReadExt::read_exact(&mut recv, &mut buf).await?;
let mut asm = Vec::new();
Ok(model
.assemble(buf)?
.map(|pkt| pkt.assemble(&mut asm))
.map(|(addr, assoc_id)| (asm, addr, assoc_id)))
}
}
impl<'conn> Connection<'conn, side::Client> {
pub fn new(conn: &'conn QuinnConnection) -> Self {
Self {
conn,
model: ConnectionModel::new(),
_marker: side::Client,
}
}
pub async fn connect(&self, addr: Address) -> Result<Connect, Error> {
let (mut send, recv) = self.conn.open_bi().await?;
let model = self.model.send_connect(addr);
model.header().marshal(&mut send).await?;
Ok(Connect::new(Side::Client(model), send, recv))
}
pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result<Task, Error> {
match Header::unmarshal(&mut recv).await? {
Header::Authenticate(_) => Err(Error::BadCommand("authenticate")),
Header::Connect(_) => Err(Error::BadCommand("connect")),
Header::Packet(pkt) => {
let model = self.model.recv_packet(pkt);
Ok(Task::Packet(
self.accept_packet_quic(model, &mut recv).await?,
))
}
Header::Dissociate(_) => Err(Error::BadCommand("dissociate")),
Header::Heartbeat(hb) => {
let _ = self.model.recv_heartbeat(hb);
Ok(Task::Heartbeat)
}
_ => unreachable!(),
}
}
}
impl<'conn> Connection<'conn, side::Server> {
pub fn new(conn: &'conn QuinnConnection) -> Self {
Self {
conn,
model: ConnectionModel::new(),
_marker: side::Server,
}
}
pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result<Task, Error> {
match Header::unmarshal(&mut recv).await? {
Header::Authenticate(auth) => {
let model = self.model.recv_authenticate(auth);
Ok(Task::Authenticate(*model.token()))
}
Header::Connect(_) => Err(Error::BadCommand("connect")),
Header::Packet(pkt) => {
let model = self.model.recv_packet(pkt);
Ok(Task::Packet(
self.accept_packet_quic(model, &mut recv).await?,
))
}
Header::Dissociate(dissoc) => {
let _ = self.model.recv_dissociate(dissoc);
Ok(Task::Dissociate)
}
Header::Heartbeat(hb) => {
let _ = self.model.recv_heartbeat(hb);
Ok(Task::Heartbeat)
}
_ => unreachable!(),
}
}
} }
pub struct Connect { pub struct Connect {
@ -112,8 +188,8 @@ impl Connect {
pub fn addr(&self) -> &Address { pub fn addr(&self) -> &Address {
match &self.model { match &self.model {
Side::Client(model) => { Side::Client(model) => {
let Header::Connect(connect) = model.header() else { unreachable!() }; let Header::Connect(conn) = model.header() else { unreachable!() };
&connect.addr() &conn.addr()
} }
Side::Server(model) => model.addr(), Side::Server(model) => model.addr(),
} }
@ -148,6 +224,15 @@ impl AsyncWrite for Connect {
} }
} }
#[non_exhaustive]
pub enum Task {
Authenticate([u8; 8]),
Connect(Connect),
Packet(Option<(Vec<u8>, Address, u16)>),
Dissociate,
Heartbeat,
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum Error { pub enum Error {
#[error(transparent)] #[error(transparent)]
@ -156,4 +241,10 @@ pub enum Error {
Connection(#[from] ConnectionError), Connection(#[from] ConnectionError),
#[error(transparent)] #[error(transparent)]
SendDatagram(#[from] SendDatagramError), SendDatagram(#[from] SendDatagramError),
#[error(transparent)]
Unmarshal(#[from] UnmarshalError),
#[error(transparent)]
Assemble(#[from] AssembleError),
#[error("{0}")]
BadCommand(&'static str),
} }

View File

@ -1,7 +1,22 @@
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::AsyncRead; use futures_util::AsyncRead;
use thiserror::Error;
use tuic::protocol::Header;
#[async_trait] #[async_trait]
trait Unmarshal { pub(super) trait Unmarshal
fn unmarshal(&self, s: &mut impl AsyncRead) -> Result<(), ()>; where
Self: Sized,
{
async fn unmarshal(s: &mut impl AsyncRead) -> Result<Self, UnmarshalError>;
} }
#[async_trait]
impl Unmarshal for Header {
async fn unmarshal(s: &mut impl AsyncRead) -> Result<Self, UnmarshalError> {
todo!()
}
}
#[derive(Debug, Error)]
pub enum UnmarshalError {}

View File

@ -216,7 +216,7 @@ where
self.sessions self.sessions
.entry(assoc_id) .entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register())) .or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.insert(pkt_id, frag_total, frag_id, size, addr, data) .insert(assoc_id, pkt_id, frag_total, frag_id, size, addr, data)
} }
fn collect_garbage(&mut self, timeout: Duration) { fn collect_garbage(&mut self, timeout: Duration) {
@ -273,6 +273,7 @@ where
fn insert( fn insert(
&mut self, &mut self,
assoc_id: u16,
pkt_id: u16, pkt_id: u16,
frag_total: u8, frag_total: u8,
frag_id: u8, frag_id: u8,
@ -284,7 +285,7 @@ where
.pkt_buf .pkt_buf
.entry(pkt_id) .entry(pkt_id)
.or_insert_with(|| PacketBuffer::new(frag_total)) .or_insert_with(|| PacketBuffer::new(frag_total))
.insert(frag_total, frag_id, size, addr, data)?; .insert(assoc_id, frag_total, frag_id, size, addr, data)?;
if res.is_some() { if res.is_some() {
self.pkt_buf.remove(&pkt_id); self.pkt_buf.remove(&pkt_id);
@ -325,6 +326,7 @@ where
fn insert( fn insert(
&mut self, &mut self,
assoc_id: u16,
frag_total: u8, frag_total: u8,
frag_id: u8, frag_id: u8,
size: u16, size: u16,
@ -358,6 +360,7 @@ where
Ok(Some(Assemblable::new( Ok(Some(Assemblable::new(
mem::take(&mut self.buf), mem::take(&mut self.buf),
self.addr.take(), self.addr.take(),
assoc_id,
))) )))
} else { } else {
Ok(None) Ok(None)
@ -368,23 +371,28 @@ where
pub struct Assemblable<B> { pub struct Assemblable<B> {
buf: Vec<Option<B>>, buf: Vec<Option<B>>,
addr: Address, addr: Address,
assoc_id: u16,
} }
impl<B> Assemblable<B> impl<B> Assemblable<B>
where where
B: AsRef<[u8]>, B: AsRef<[u8]>,
{ {
fn new(buf: Vec<Option<B>>, addr: Address) -> Self { fn new(buf: Vec<Option<B>>, addr: Address, assoc_id: u16) -> Self {
Self { buf, addr } Self {
buf,
addr,
assoc_id,
}
} }
pub fn assemble<A>(self, buf: &mut A) -> Address pub fn assemble<A>(self, buf: &mut A) -> (Address, u16)
where where
A: Assembler<B>, A: Assembler<B>,
{ {
let data = self.buf.into_iter().map(|b| b.unwrap()); let data = self.buf.into_iter().map(|b| b.unwrap());
buf.assemble(data); buf.assemble(data);
self.addr (self.addr, self.assoc_id)
} }
} }
@ -396,6 +404,17 @@ where
fn assemble(&mut self, data: impl IntoIterator<Item = B>); fn assemble(&mut self, data: impl IntoIterator<Item = B>);
} }
impl<B> Assembler<B> for Vec<u8>
where
B: AsRef<[u8]>,
{
fn assemble(&mut self, data: impl IntoIterator<Item = B>) {
for d in data {
self.extend_from_slice(d.as_ref());
}
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum AssembleError { pub enum AssembleError {
#[error("invalid fragment size")] #[error("invalid fragment size")]

View File

@ -38,6 +38,16 @@ impl<B> Packet<side::Tx, B> {
let Side::Tx(tx) = self.inner else { unreachable!() }; let Side::Tx(tx) = self.inner else { unreachable!() };
Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload)
} }
pub fn assoc_id(&self) -> &u16 {
let Side::Tx(tx) = &self.inner else { unreachable!() };
&tx.assoc_id
}
pub fn addr(&self) -> &Address {
let Side::Tx(tx) = &self.inner else { unreachable!() };
&tx.addr
}
} }
pub struct Rx<B> { pub struct Rx<B> {
@ -91,6 +101,21 @@ where
data, data,
) )
} }
pub fn assoc_id(&self) -> &u16 {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.assoc_id
}
pub fn addr(&self) -> &Address {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.addr
}
pub fn size(&self) -> &u16 {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.size
}
} }
pub struct Fragments<'a, P> pub struct Fragments<'a, P>

View File

@ -33,7 +33,7 @@ impl Command for Connect {
} }
impl From<Connect> for (Address,) { impl From<Connect> for (Address,) {
fn from(connect: Connect) -> Self { fn from(conn: Connect) -> Self {
(connect.addr,) (conn.addr,)
} }
} }

View File

@ -51,7 +51,7 @@ impl Header {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
2 + match self { 2 + match self {
Self::Authenticate(auth) => auth.len(), Self::Authenticate(auth) => auth.len(),
Self::Connect(connect) => connect.len(), Self::Connect(conn) => conn.len(),
Self::Packet(packet) => packet.len(), Self::Packet(packet) => packet.len(),
Self::Dissociate(dissociate) => dissociate.len(), Self::Dissociate(dissociate) => dissociate.len(),
Self::Heartbeat(heartbeat) => heartbeat.len(), Self::Heartbeat(heartbeat) => heartbeat.len(),