1
0

adding uni stream parsing methods

This commit is contained in:
EAimTY 2023-01-26 20:47:20 +09:00
parent 101d4427eb
commit 10cee5276e
6 changed files with 183 additions and 33 deletions

View File

@ -1,6 +1,10 @@
use self::{marshal::Marshal, side::Side};
use self::{
marshal::Marshal,
side::Side,
unmarshal::{Unmarshal, UnmarshalError},
};
use bytes::Bytes;
use futures_util::{io::Cursor, AsyncRead, AsyncWrite, AsyncWriteExt};
use futures_util::{io::Cursor, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use quinn::{
Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream,
};
@ -13,7 +17,8 @@ use thiserror::Error;
use tuic::{
model::{
side::{Rx, Tx},
Connect as ConnectModel, Connection as ConnectionModel,
AssembleError, Connect as ConnectModel, Connection as ConnectionModel,
Packet as PacketModel,
},
protocol::{Address, Header},
};
@ -33,26 +38,11 @@ pub mod side {
pub struct Connection<'conn, Side> {
conn: &'conn QuinnConnection,
model: ConnectionModel<Bytes>,
model: ConnectionModel<Vec<u8>>,
_marker: Side,
}
impl<'conn> Connection<'conn, side::Client> {
pub fn new(conn: &'conn QuinnConnection) -> Self {
Self {
conn,
model: ConnectionModel::new(),
_marker: side::Client,
}
}
pub async fn connect(&self, addr: Address) -> Result<Connect, Error> {
let (mut send, recv) = self.conn.open_bi().await?;
let model = self.model.send_connect(addr);
model.header().marshal(&mut send).await?;
Ok(Connect::new(Side::Client(model), send, recv))
}
impl<'conn, Side> Connection<'conn, Side> {
pub async fn packet_native(
&self,
pkt: impl AsRef<[u8]>,
@ -92,6 +82,92 @@ impl<'conn> Connection<'conn, side::Client> {
Ok(())
}
async fn accept_packet_quic(
&self,
model: PacketModel<Rx, Vec<u8>>,
mut recv: &mut RecvStream,
) -> Result<Option<(Vec<u8>, Address, u16)>, Error> {
let mut buf = vec![0; *model.size() as usize];
AsyncReadExt::read_exact(&mut recv, &mut buf).await?;
let mut asm = Vec::new();
Ok(model
.assemble(buf)?
.map(|pkt| pkt.assemble(&mut asm))
.map(|(addr, assoc_id)| (asm, addr, assoc_id)))
}
}
impl<'conn> Connection<'conn, side::Client> {
pub fn new(conn: &'conn QuinnConnection) -> Self {
Self {
conn,
model: ConnectionModel::new(),
_marker: side::Client,
}
}
pub async fn connect(&self, addr: Address) -> Result<Connect, Error> {
let (mut send, recv) = self.conn.open_bi().await?;
let model = self.model.send_connect(addr);
model.header().marshal(&mut send).await?;
Ok(Connect::new(Side::Client(model), send, recv))
}
pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result<Task, Error> {
match Header::unmarshal(&mut recv).await? {
Header::Authenticate(_) => Err(Error::BadCommand("authenticate")),
Header::Connect(_) => Err(Error::BadCommand("connect")),
Header::Packet(pkt) => {
let model = self.model.recv_packet(pkt);
Ok(Task::Packet(
self.accept_packet_quic(model, &mut recv).await?,
))
}
Header::Dissociate(_) => Err(Error::BadCommand("dissociate")),
Header::Heartbeat(hb) => {
let _ = self.model.recv_heartbeat(hb);
Ok(Task::Heartbeat)
}
_ => unreachable!(),
}
}
}
impl<'conn> Connection<'conn, side::Server> {
pub fn new(conn: &'conn QuinnConnection) -> Self {
Self {
conn,
model: ConnectionModel::new(),
_marker: side::Server,
}
}
pub async fn handle_uni_stream(&self, mut recv: RecvStream) -> Result<Task, Error> {
match Header::unmarshal(&mut recv).await? {
Header::Authenticate(auth) => {
let model = self.model.recv_authenticate(auth);
Ok(Task::Authenticate(*model.token()))
}
Header::Connect(_) => Err(Error::BadCommand("connect")),
Header::Packet(pkt) => {
let model = self.model.recv_packet(pkt);
Ok(Task::Packet(
self.accept_packet_quic(model, &mut recv).await?,
))
}
Header::Dissociate(dissoc) => {
let _ = self.model.recv_dissociate(dissoc);
Ok(Task::Dissociate)
}
Header::Heartbeat(hb) => {
let _ = self.model.recv_heartbeat(hb);
Ok(Task::Heartbeat)
}
_ => unreachable!(),
}
}
}
pub struct Connect {
@ -112,8 +188,8 @@ impl Connect {
pub fn addr(&self) -> &Address {
match &self.model {
Side::Client(model) => {
let Header::Connect(connect) = model.header() else { unreachable!() };
&connect.addr()
let Header::Connect(conn) = model.header() else { unreachable!() };
&conn.addr()
}
Side::Server(model) => model.addr(),
}
@ -148,6 +224,15 @@ impl AsyncWrite for Connect {
}
}
#[non_exhaustive]
pub enum Task {
Authenticate([u8; 8]),
Connect(Connect),
Packet(Option<(Vec<u8>, Address, u16)>),
Dissociate,
Heartbeat,
}
#[derive(Debug, Error)]
pub enum Error {
#[error(transparent)]
@ -156,4 +241,10 @@ pub enum Error {
Connection(#[from] ConnectionError),
#[error(transparent)]
SendDatagram(#[from] SendDatagramError),
#[error(transparent)]
Unmarshal(#[from] UnmarshalError),
#[error(transparent)]
Assemble(#[from] AssembleError),
#[error("{0}")]
BadCommand(&'static str),
}

View File

@ -1,7 +1,22 @@
use async_trait::async_trait;
use futures_util::AsyncRead;
use thiserror::Error;
use tuic::protocol::Header;
#[async_trait]
trait Unmarshal {
fn unmarshal(&self, s: &mut impl AsyncRead) -> Result<(), ()>;
pub(super) trait Unmarshal
where
Self: Sized,
{
async fn unmarshal(s: &mut impl AsyncRead) -> Result<Self, UnmarshalError>;
}
#[async_trait]
impl Unmarshal for Header {
async fn unmarshal(s: &mut impl AsyncRead) -> Result<Self, UnmarshalError> {
todo!()
}
}
#[derive(Debug, Error)]
pub enum UnmarshalError {}

View File

@ -216,7 +216,7 @@ where
self.sessions
.entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.insert(pkt_id, frag_total, frag_id, size, addr, data)
.insert(assoc_id, pkt_id, frag_total, frag_id, size, addr, data)
}
fn collect_garbage(&mut self, timeout: Duration) {
@ -273,6 +273,7 @@ where
fn insert(
&mut self,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
@ -284,7 +285,7 @@ where
.pkt_buf
.entry(pkt_id)
.or_insert_with(|| PacketBuffer::new(frag_total))
.insert(frag_total, frag_id, size, addr, data)?;
.insert(assoc_id, frag_total, frag_id, size, addr, data)?;
if res.is_some() {
self.pkt_buf.remove(&pkt_id);
@ -325,6 +326,7 @@ where
fn insert(
&mut self,
assoc_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
@ -358,6 +360,7 @@ where
Ok(Some(Assemblable::new(
mem::take(&mut self.buf),
self.addr.take(),
assoc_id,
)))
} else {
Ok(None)
@ -368,23 +371,28 @@ where
pub struct Assemblable<B> {
buf: Vec<Option<B>>,
addr: Address,
assoc_id: u16,
}
impl<B> Assemblable<B>
where
B: AsRef<[u8]>,
{
fn new(buf: Vec<Option<B>>, addr: Address) -> Self {
Self { buf, addr }
fn new(buf: Vec<Option<B>>, addr: Address, assoc_id: u16) -> Self {
Self {
buf,
addr,
assoc_id,
}
}
pub fn assemble<A>(self, buf: &mut A) -> Address
pub fn assemble<A>(self, buf: &mut A) -> (Address, u16)
where
A: Assembler<B>,
{
let data = self.buf.into_iter().map(|b| b.unwrap());
buf.assemble(data);
self.addr
(self.addr, self.assoc_id)
}
}
@ -396,6 +404,17 @@ where
fn assemble(&mut self, data: impl IntoIterator<Item = B>);
}
impl<B> Assembler<B> for Vec<u8>
where
B: AsRef<[u8]>,
{
fn assemble(&mut self, data: impl IntoIterator<Item = B>) {
for d in data {
self.extend_from_slice(d.as_ref());
}
}
}
#[derive(Debug, Error)]
pub enum AssembleError {
#[error("invalid fragment size")]

View File

@ -38,6 +38,16 @@ impl<B> Packet<side::Tx, B> {
let Side::Tx(tx) = self.inner else { unreachable!() };
Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload)
}
pub fn assoc_id(&self) -> &u16 {
let Side::Tx(tx) = &self.inner else { unreachable!() };
&tx.assoc_id
}
pub fn addr(&self) -> &Address {
let Side::Tx(tx) = &self.inner else { unreachable!() };
&tx.addr
}
}
pub struct Rx<B> {
@ -91,6 +101,21 @@ where
data,
)
}
pub fn assoc_id(&self) -> &u16 {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.assoc_id
}
pub fn addr(&self) -> &Address {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.addr
}
pub fn size(&self) -> &u16 {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.size
}
}
pub struct Fragments<'a, P>

View File

@ -33,7 +33,7 @@ impl Command for Connect {
}
impl From<Connect> for (Address,) {
fn from(connect: Connect) -> Self {
(connect.addr,)
fn from(conn: Connect) -> Self {
(conn.addr,)
}
}

View File

@ -51,7 +51,7 @@ impl Header {
pub fn len(&self) -> usize {
2 + match self {
Self::Authenticate(auth) => auth.len(),
Self::Connect(connect) => connect.len(),
Self::Connect(conn) => conn.len(),
Self::Packet(packet) => packet.len(),
Self::Dissociate(dissociate) => dissociate.len(),
Self::Heartbeat(heartbeat) => heartbeat.len(),