diff --git a/tuic-quinn/Cargo.toml b/tuic-quinn/Cargo.toml index d98ceed..f990737 100644 --- a/tuic-quinn/Cargo.toml +++ b/tuic-quinn/Cargo.toml @@ -3,6 +3,10 @@ name = "tuic-quinn" version = "0.1.0" edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] +async-trait = { version = "0.1.62", default-features = false } +bytes = { version = "1.3.0", default-features = false, features = ["std"] } +futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"] } +quinn = { version = "0.9.3", default-features = false, features = ["futures-io"] } +thiserror = { version = "1.0.38", default-features = false } +tuic = { path = "../tuic", default-features = false, features = ["model"] } diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 7d12d9a..e9a5249 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -1,14 +1,159 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right -} +use self::{marshal::Marshal, side::Side}; +use bytes::Bytes; +use futures_util::{io::Cursor, AsyncRead, AsyncWrite, AsyncWriteExt}; +use quinn::{ + Connection as QuinnConnection, ConnectionError, RecvStream, SendDatagramError, SendStream, +}; +use std::{ + io::Error as IoError, + pin::Pin, + task::{Context, Poll}, +}; +use thiserror::Error; +use tuic::{ + model::{ + side::{Rx, Tx}, + Connect as ConnectModel, Connection as ConnectionModel, + }, + protocol::{Address, Header}, +}; -#[cfg(test)] -mod tests { - use super::*; +mod marshal; +mod unmarshal; - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); +pub mod side { + pub struct Client; + pub struct Server; + + pub(super) enum Side { + Client(C), + Server(S), } } + +pub struct Connection<'conn, Side> { + conn: &'conn QuinnConnection, + model: ConnectionModel, + _marker: Side, +} + +impl<'conn> Connection<'conn, side::Client> { + pub fn new(conn: &'conn QuinnConnection) -> Self { + Self { + conn, + model: ConnectionModel::new(), + _marker: side::Client, + } + } + + pub async fn connect(&self, addr: Address) -> Result { + let (mut send, recv) = self.conn.open_bi().await?; + let model = self.model.send_connect(addr); + model.header().marshal(&mut send).await?; + Ok(Connect::new(Side::Client(model), send, recv)) + } + + pub async fn packet_native( + &self, + pkt: impl AsRef<[u8]>, + assoc_id: u16, + addr: Address, + ) -> Result<(), Error> { + let Some(max_pkt_size) = self.conn.max_datagram_size() else { + return Err(Error::SendDatagram(SendDatagramError::Disabled)); + }; + + let model = self.model.send_packet(assoc_id, addr, max_pkt_size); + + for (header, frag) in model.into_fragments(pkt) { + let mut buf = Cursor::new(vec![0; header.len() + frag.len()]); + header.marshal(&mut buf).await?; + buf.write_all(frag).await.unwrap(); + self.conn.send_datagram(Bytes::from(buf.into_inner()))?; + } + + Ok(()) + } + + pub async fn packet_quic( + &self, + pkt: impl AsRef<[u8]>, + assoc_id: u16, + addr: Address, + ) -> Result<(), Error> { + let model = self.model.send_packet(assoc_id, addr, u16::MAX as usize); + let mut frags = model.into_fragments(pkt); + let (header, frag) = frags.next().unwrap(); + assert!(frags.next().is_none()); + + let mut send = self.conn.open_uni().await?; + header.marshal(&mut send).await?; + AsyncWriteExt::write_all(&mut send, frag).await?; + + Ok(()) + } +} + +pub struct Connect { + model: Side, ConnectModel>, + send: SendStream, + recv: RecvStream, +} + +impl Connect { + fn new( + model: Side, ConnectModel>, + send: SendStream, + recv: RecvStream, + ) -> Self { + Self { model, send, recv } + } + + pub fn addr(&self) -> &Address { + match &self.model { + Side::Client(model) => { + let Header::Connect(connect) = model.header() else { unreachable!() }; + &connect.addr() + } + Side::Server(model) => model.addr(), + } + } +} + +impl AsyncRead for Connect { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + AsyncRead::poll_read(Pin::new(&mut self.get_mut().recv), cx, buf) + } +} + +impl AsyncWrite for Connect { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(Pin::new(&mut self.get_mut().send), cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_flush(Pin::new(&mut self.get_mut().send), cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + AsyncWrite::poll_close(Pin::new(&mut self.get_mut().send), cx) + } +} + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] IoError), + #[error(transparent)] + Connection(#[from] ConnectionError), + #[error(transparent)] + SendDatagram(#[from] SendDatagramError), +} diff --git a/tuic-quinn/src/marshal.rs b/tuic-quinn/src/marshal.rs new file mode 100644 index 0000000..66c2d85 --- /dev/null +++ b/tuic-quinn/src/marshal.rs @@ -0,0 +1,16 @@ +use async_trait::async_trait; +use futures_util::AsyncWrite; +use std::io::Error as IoError; +use tuic::protocol::Header; + +#[async_trait] +pub(super) trait Marshal { + async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError>; +} + +#[async_trait] +impl Marshal for Header { + async fn marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError> { + todo!() + } +} diff --git a/tuic-quinn/src/unmarshal.rs b/tuic-quinn/src/unmarshal.rs new file mode 100644 index 0000000..c7e6e11 --- /dev/null +++ b/tuic-quinn/src/unmarshal.rs @@ -0,0 +1,7 @@ +use async_trait::async_trait; +use futures_util::AsyncRead; + +#[async_trait] +trait Unmarshal { + fn unmarshal(&self, s: &mut impl AsyncRead) -> Result<(), ()>; +}