From 6229a08e61f227409901abcc2a51e27b8c0952cd Mon Sep 17 00:00:00 2001 From: EAimTY Date: Fri, 27 Jan 2023 13:10:30 +0900 Subject: [PATCH] implement async (un)marshal for protocol --- tuic-quinn/src/lib.rs | 8 +- tuic/Cargo.toml | 5 +- tuic/src/lib.rs | 2 +- tuic/src/marshal.rs | 86 ++++++++++++++++-- tuic/src/model/authenticate.rs | 4 +- tuic/src/model/dissociate.rs | 4 +- tuic/src/model/packet.rs | 12 +-- tuic/src/protocol/address.rs | 95 -------------------- tuic/src/protocol/authenticate.rs | 14 ++- tuic/src/protocol/connect.rs | 10 +-- tuic/src/protocol/dissociate.rs | 14 ++- tuic/src/protocol/heartbeat.rs | 10 +-- tuic/src/protocol/mod.rs | 114 +++++++++++++++++++++--- tuic/src/protocol/packet.rs | 40 ++++----- tuic/src/unmarshal.rs | 140 ++++++++++++++++++++++++++++-- 15 files changed, 372 insertions(+), 186 deletions(-) delete mode 100644 tuic/src/protocol/address.rs diff --git a/tuic-quinn/src/lib.rs b/tuic-quinn/src/lib.rs index 3c8b2ea..d84a33c 100644 --- a/tuic-quinn/src/lib.rs +++ b/tuic-quinn/src/lib.rs @@ -87,7 +87,7 @@ impl<'conn, Side> Connection<'conn, Side> { model: PacketModel, mut recv: &mut RecvStream, ) -> Result, Error> { - let mut buf = vec![0; *model.size() as usize]; + let mut buf = vec![0; model.size() as usize]; AsyncReadExt::read_exact(&mut recv, &mut buf).await?; let mut asm = Vec::new(); @@ -183,7 +183,7 @@ impl<'conn> Connection<'conn, side::Client> { Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; - let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), @@ -206,7 +206,7 @@ impl<'conn> Connection<'conn, side::Server> { match Header::async_unmarshal(&mut recv).await? { Header::Authenticate(auth) => { let model = self.model.recv_authenticate(auth); - Ok(Task::Authenticate(*model.token())) + Ok(Task::Authenticate(model.token())) } Header::Connect(_) => Err(Error::BadCommand("connect")), Header::Packet(pkt) => { @@ -251,7 +251,7 @@ impl<'conn> Connection<'conn, side::Server> { Header::Packet(pkt) => { let model = self.model.recv_packet(pkt); let pos = dg.position() as usize; - let buf = dg.into_inner().slice(pos..pos + *model.size() as usize); + let buf = dg.into_inner().slice(pos..pos + model.size() as usize); Ok(Task::Packet(self.accept_packet_native(model, buf).await?)) } Header::Dissociate(_) => Err(Error::BadCommand("dissociate")), diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index f19b12a..2774890 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,11 +4,12 @@ version = "0.1.0" edition = "2021" [features] -async_marshal = ["futures-io"] +async_marshal = ["bytes", "futures-util"] model = ["parking_lot", "thiserror"] [dependencies] -futures-io = { version = "0.3.25", default-features = false, features = ["std"], optional = true } +bytes = { version = "1.3.0", default-features = false, features = ["std"], optional = true } +futures-util = { version = "0.3.25", default-features = false, features = ["io", "std"], optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true } thiserror = { version = "1.0.38", default-features = false, optional = true } diff --git a/tuic/src/lib.rs b/tuic/src/lib.rs index 05c95ad..cfa0ac6 100644 --- a/tuic/src/lib.rs +++ b/tuic/src/lib.rs @@ -3,7 +3,7 @@ mod protocol; pub use self::protocol::{ - Address, Authenticate, Command, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, + Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, }; #[cfg(feature = "async_marshal")] diff --git a/tuic/src/marshal.rs b/tuic/src/marshal.rs index 2cbeba4..2c756bf 100644 --- a/tuic/src/marshal.rs +++ b/tuic/src/marshal.rs @@ -1,9 +1,85 @@ -use crate::protocol::Header; -use futures_io::AsyncWrite; -use std::io::Error as IoError; +use crate::protocol::{ + Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, +}; +use bytes::BufMut; +use futures_util::{AsyncWrite, AsyncWriteExt}; +use std::{io::Error as IoError, net::SocketAddr}; impl Header { - pub async fn async_marshal(&self, s: &mut impl AsyncWrite) -> Result<(), IoError> { - todo!() + pub async fn async_marshal(&self, s: &mut (impl AsyncWrite + Unpin)) -> Result<(), IoError> { + let mut buf = vec![0; self.len()]; + self.write(&mut buf); + s.write_all(&buf).await + } + + pub fn write(&self, buf: &mut impl BufMut) { + buf.put_u8(VERSION); + buf.put_u8(self.type_code()); + + match self { + Self::Authenticate(auth) => auth.write(buf), + Self::Connect(conn) => conn.write(buf), + Self::Packet(packet) => packet.write(buf), + Self::Dissociate(dissociate) => dissociate.write(buf), + Self::Heartbeat(heartbeat) => heartbeat.write(buf), + } } } + +impl Address { + fn write(&self, buf: &mut impl BufMut) { + buf.put_u8(self.type_code()); + + match self { + Self::None => {} + Self::DomainAddress(domain, port) => { + buf.put_u8(domain.len() as u8); + buf.put_slice(domain.as_bytes()); + buf.put_u16(*port); + } + Self::SocketAddress(SocketAddr::V4(addr)) => { + buf.put_slice(&addr.ip().octets()); + buf.put_u16(addr.port()); + } + Self::SocketAddress(SocketAddr::V6(addr)) => { + for seg in addr.ip().segments() { + buf.put_u16(seg); + } + buf.put_u16(addr.port()); + } + } + } +} + +impl Authenticate { + fn write(&self, buf: &mut impl BufMut) { + buf.put_slice(&self.token()); + } +} + +impl Connect { + fn write(&self, buf: &mut impl BufMut) { + self.addr().write(buf); + } +} + +impl Packet { + fn write(&self, buf: &mut impl BufMut) { + buf.put_u16(self.assoc_id()); + buf.put_u16(self.pkt_id()); + buf.put_u8(self.frag_total()); + buf.put_u8(self.frag_id()); + buf.put_u16(self.size()); + self.addr().write(buf); + } +} + +impl Dissociate { + fn write(&self, buf: &mut impl BufMut) { + buf.put_u16(self.assoc_id()); + } +} + +impl Heartbeat { + fn write(&self, _buf: &mut impl BufMut) {} +} diff --git a/tuic/src/model/authenticate.rs b/tuic/src/model/authenticate.rs index ca54f39..e02940d 100644 --- a/tuic/src/model/authenticate.rs +++ b/tuic/src/model/authenticate.rs @@ -38,8 +38,8 @@ impl Authenticate { } } - pub fn token(&self) -> &[u8; 8] { + pub fn token(&self) -> [u8; 8] { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.token + rx.token } } diff --git a/tuic/src/model/dissociate.rs b/tuic/src/model/dissociate.rs index e6e4072..79321b0 100644 --- a/tuic/src/model/dissociate.rs +++ b/tuic/src/model/dissociate.rs @@ -38,8 +38,8 @@ impl Dissociate { } } - pub fn assoc_id(&self) -> &u16 { + pub fn assoc_id(&self) -> u16 { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.assoc_id + rx.assoc_id } } diff --git a/tuic/src/model/packet.rs b/tuic/src/model/packet.rs index ea336e3..6bc5fe1 100644 --- a/tuic/src/model/packet.rs +++ b/tuic/src/model/packet.rs @@ -39,9 +39,9 @@ impl Packet { Fragments::new(tx.assoc_id, tx.pkt_id, tx.addr, tx.max_pkt_size, payload) } - pub fn assoc_id(&self) -> &u16 { + pub fn assoc_id(&self) -> u16 { let Side::Tx(tx) = &self.inner else { unreachable!() }; - &tx.assoc_id + tx.assoc_id } pub fn addr(&self) -> &Address { @@ -102,9 +102,9 @@ where ) } - pub fn assoc_id(&self) -> &u16 { + pub fn assoc_id(&self) -> u16 { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.assoc_id + rx.assoc_id } pub fn addr(&self) -> &Address { @@ -112,9 +112,9 @@ where &rx.addr } - pub fn size(&self) -> &u16 { + pub fn size(&self) -> u16 { let Side::Rx(rx) = &self.inner else { unreachable!() }; - &rx.size + rx.size } } diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs deleted file mode 100644 index 78eaefc..0000000 --- a/tuic/src/protocol/address.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::{ - fmt::{Display, Formatter, Result as FmtResult}, - mem, - net::SocketAddr, -}; - -/// Address -/// -/// ```plain -/// +------+----------+ -/// | TYPE | ADDR | -/// +------+----------+ -/// | 1 | Variable | -/// +------+----------+ -/// ``` -/// -/// The address type can be one of the following: -/// -/// - 0xff: None -/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) -/// - 0x01: IPv4 address -/// - 0x02: IPv6 address -/// -/// The port number is encoded in 2 bytes after the Domain name / IP address. -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] -pub enum Address { - None, - DomainAddress(String, u16), - SocketAddress(SocketAddr), -} - -impl Address { - pub const TYPE_CODE_NONE: u8 = 0xff; - pub const TYPE_CODE_DOMAIN: u8 = 0x00; - pub const TYPE_CODE_IPV4: u8 = 0x01; - pub const TYPE_CODE_IPV6: u8 = 0x02; - - pub fn type_code(&self) -> u8 { - match self { - Self::None => Self::TYPE_CODE_NONE, - Self::DomainAddress(_, _) => Self::TYPE_CODE_DOMAIN, - Self::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => Self::TYPE_CODE_IPV4, - SocketAddr::V6(_) => Self::TYPE_CODE_IPV6, - }, - } - } - - pub fn len(&self) -> usize { - 1 + match self { - Address::None => 0, - Address::DomainAddress(addr, _) => 1 + addr.len() + 2, - Address::SocketAddress(addr) => match addr { - SocketAddr::V4(_) => 1 * 4 + 2, - SocketAddr::V6(_) => 2 * 8 + 2, - }, - } - } - - pub fn take(&mut self) -> Self { - mem::take(self) - } - - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } - - pub fn is_domain(&self) -> bool { - matches!(self, Self::DomainAddress(_, _)) - } - - pub fn is_ipv4(&self) -> bool { - matches!(self, Self::SocketAddress(SocketAddr::V4(_))) - } - - pub fn is_ipv6(&self) -> bool { - matches!(self, Self::SocketAddress(SocketAddr::V6(_))) - } -} - -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - match self { - Self::None => write!(f, "none"), - Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), - Self::SocketAddress(addr) => write!(f, "{addr}"), - } - } -} - -impl Default for Address { - fn default() -> Self { - Self::None - } -} diff --git a/tuic/src/protocol/authenticate.rs b/tuic/src/protocol/authenticate.rs index eebbafd..b17f791 100644 --- a/tuic/src/protocol/authenticate.rs +++ b/tuic/src/protocol/authenticate.rs @@ -1,5 +1,3 @@ -use super::Command; - // +-------+ // | TOKEN | // +-------+ @@ -11,23 +9,21 @@ pub struct Authenticate { } impl Authenticate { - pub(super) const TYPE_CODE: u8 = 0x00; + const TYPE_CODE: u8 = 0x00; pub const fn new(token: [u8; 8]) -> Self { Self { token } } - pub fn token(&self) -> &[u8; 8] { - &self.token + pub fn token(&self) -> [u8; 8] { + self.token } -} -impl Command for Authenticate { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { 8 } } diff --git a/tuic/src/protocol/connect.rs b/tuic/src/protocol/connect.rs index 20a4d96..6139c83 100644 --- a/tuic/src/protocol/connect.rs +++ b/tuic/src/protocol/connect.rs @@ -1,4 +1,4 @@ -use super::{Address, Command}; +use super::Address; // +----------+ // | ADDR | @@ -11,7 +11,7 @@ pub struct Connect { } impl Connect { - pub(super) const TYPE_CODE: u8 = 0x01; + const TYPE_CODE: u8 = 0x01; pub const fn new(addr: Address) -> Self { Self { addr } @@ -20,14 +20,12 @@ impl Connect { pub fn addr(&self) -> &Address { &self.addr } -} -impl Command for Connect { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { self.addr.len() } } diff --git a/tuic/src/protocol/dissociate.rs b/tuic/src/protocol/dissociate.rs index 94734f5..86caa19 100644 --- a/tuic/src/protocol/dissociate.rs +++ b/tuic/src/protocol/dissociate.rs @@ -1,5 +1,3 @@ -use super::Command; - // +----------+ // | ASSOC_ID | // +----------+ @@ -11,23 +9,21 @@ pub struct Dissociate { } impl Dissociate { - pub(super) const TYPE_CODE: u8 = 0x03; + const TYPE_CODE: u8 = 0x03; pub const fn new(assoc_id: u16) -> Self { Self { assoc_id } } - pub fn assoc_id(&self) -> &u16 { - &self.assoc_id + pub fn assoc_id(&self) -> u16 { + self.assoc_id } -} -impl Command for Dissociate { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { 2 } } diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index 91087e8..dd8143a 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -1,5 +1,3 @@ -use super::Command; - // +-+ // | | // +-+ @@ -9,19 +7,17 @@ use super::Command; pub struct Heartbeat; impl Heartbeat { - pub(super) const TYPE_CODE: u8 = 0x04; + const TYPE_CODE: u8 = 0x04; pub const fn new() -> Self { Self } -} -impl Command for Heartbeat { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { + pub fn len(&self) -> usize { 0 } } diff --git a/tuic/src/protocol/mod.rs b/tuic/src/protocol/mod.rs index c890893..3ade4cb 100644 --- a/tuic/src/protocol/mod.rs +++ b/tuic/src/protocol/mod.rs @@ -1,4 +1,9 @@ -mod address; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + mem, + net::SocketAddr, +}; + mod authenticate; mod connect; mod dissociate; @@ -6,8 +11,8 @@ mod heartbeat; mod packet; pub use self::{ - address::Address, authenticate::Authenticate, connect::Connect, dissociate::Dissociate, - heartbeat::Heartbeat, packet::Packet, + authenticate::Authenticate, connect::Connect, dissociate::Dissociate, heartbeat::Heartbeat, + packet::Packet, }; pub const VERSION: u8 = 0x05; @@ -32,13 +37,13 @@ pub enum Header { } impl Header { - pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::TYPE_CODE; - pub const TYPE_CODE_CONNECT: u8 = Connect::TYPE_CODE; - pub const TYPE_CODE_PACKET: u8 = Packet::TYPE_CODE; - pub const TYPE_CODE_DISSOCIATE: u8 = Dissociate::TYPE_CODE; - pub const TYPE_CODE_HEARTBEAT: u8 = Heartbeat::TYPE_CODE; + pub const TYPE_CODE_AUTHENTICATE: u8 = Authenticate::type_code(); + pub const TYPE_CODE_CONNECT: u8 = Connect::type_code(); + pub const TYPE_CODE_PACKET: u8 = Packet::type_code(); + pub const TYPE_CODE_DISSOCIATE: u8 = Dissociate::type_code(); + pub const TYPE_CODE_HEARTBEAT: u8 = Heartbeat::type_code(); - pub fn type_code(&self) -> u8 { + pub const fn type_code(&self) -> u8 { match self { Self::Authenticate(_) => Authenticate::type_code(), Self::Connect(_) => Connect::type_code(), @@ -59,7 +64,92 @@ impl Header { } } -pub trait Command { - fn type_code() -> u8; - fn len(&self) -> usize; +/// Address +/// +/// ```plain +/// +------+----------+ +/// | TYPE | ADDR | +/// +------+----------+ +/// | 1 | Variable | +/// +------+----------+ +/// ``` +/// +/// The address type can be one of the following: +/// +/// - 0xff: None +/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) +/// - 0x01: IPv4 address +/// - 0x02: IPv6 address +/// +/// The port number is encoded in 2 bytes after the Domain name / IP address. +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +pub enum Address { + None, + DomainAddress(String, u16), + SocketAddress(SocketAddr), +} + +impl Address { + pub const TYPE_CODE_NONE: u8 = 0xff; + pub const TYPE_CODE_DOMAIN: u8 = 0x00; + pub const TYPE_CODE_IPV4: u8 = 0x01; + pub const TYPE_CODE_IPV6: u8 = 0x02; + + pub const fn type_code(&self) -> u8 { + match self { + Self::None => Self::TYPE_CODE_NONE, + Self::DomainAddress(_, _) => Self::TYPE_CODE_DOMAIN, + Self::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => Self::TYPE_CODE_IPV4, + SocketAddr::V6(_) => Self::TYPE_CODE_IPV6, + }, + } + } + + pub fn len(&self) -> usize { + 1 + match self { + Address::None => 0, + Address::DomainAddress(addr, _) => 1 + addr.len() + 2, + Address::SocketAddress(addr) => match addr { + SocketAddr::V4(_) => 1 * 4 + 2, + SocketAddr::V6(_) => 2 * 8 + 2, + }, + } + } + + pub fn take(&mut self) -> Self { + mem::take(self) + } + + pub fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + pub fn is_domain(&self) -> bool { + matches!(self, Self::DomainAddress(_, _)) + } + + pub fn is_ipv4(&self) -> bool { + matches!(self, Self::SocketAddress(SocketAddr::V4(_))) + } + + pub fn is_ipv6(&self) -> bool { + matches!(self, Self::SocketAddress(SocketAddr::V6(_))) + } +} + +impl Display for Address { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + match self { + Self::None => write!(f, "none"), + Self::DomainAddress(addr, port) => write!(f, "{addr}:{port}"), + Self::SocketAddress(addr) => write!(f, "{addr}"), + } + } +} + +impl Default for Address { + fn default() -> Self { + Self::None + } } diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 71c79e5..d9c6a5d 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -1,4 +1,4 @@ -use super::{Address, Command}; +use super::Address; // +----------+--------+------------+---------+------+----------+ // | ASSOC_ID | PKT_ID | FRAG_TOTAL | FRAG_ID | SIZE | ADDR | @@ -16,7 +16,7 @@ pub struct Packet { } impl Packet { - pub(super) const TYPE_CODE: u8 = 0x02; + const TYPE_CODE: u8 = 0x02; pub const fn new( assoc_id: u16, @@ -36,42 +36,40 @@ impl Packet { } } - pub fn assoc_id(&self) -> &u16 { - &self.assoc_id + pub fn assoc_id(&self) -> u16 { + self.assoc_id } - pub fn pkt_id(&self) -> &u16 { - &self.pkt_id + pub fn pkt_id(&self) -> u16 { + self.pkt_id } - pub fn frag_total(&self) -> &u8 { - &self.frag_total + pub fn frag_total(&self) -> u8 { + self.frag_total } - pub fn frag_id(&self) -> &u8 { - &self.frag_id + pub fn frag_id(&self) -> u8 { + self.frag_id } - pub fn size(&self) -> &u16 { - &self.size + pub fn size(&self) -> u16 { + self.size } pub fn addr(&self) -> &Address { &self.addr } - pub const fn len_without_addr() -> usize { - 2 + 2 + 1 + 1 + 2 - } -} - -impl Command for Packet { - fn type_code() -> u8 { + pub const fn type_code() -> u8 { Self::TYPE_CODE } - fn len(&self) -> usize { - 2 + 2 + 1 + 1 + 2 + self.addr.len() + pub fn len(&self) -> usize { + Self::len_without_addr() + self.addr.len() + } + + pub const fn len_without_addr() -> usize { + 2 + 2 + 1 + 1 + 2 } } diff --git a/tuic/src/unmarshal.rs b/tuic/src/unmarshal.rs index aa1fd56..de2bb1a 100644 --- a/tuic/src/unmarshal.rs +++ b/tuic/src/unmarshal.rs @@ -1,12 +1,142 @@ -use crate::protocol::Header; -use futures_io::AsyncRead; +use crate::protocol::{ + Address, Authenticate, Connect, Dissociate, Header, Heartbeat, Packet, VERSION, +}; +use futures_util::{AsyncRead, AsyncReadExt}; +use std::{io::Error as IoError, net::SocketAddr, string::FromUtf8Error}; use thiserror::Error; impl Header { - pub async fn async_unmarshal(s: &mut impl AsyncRead) -> Result { - todo!() + pub async fn async_unmarshal(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let ver = buf[0]; + + if ver != VERSION { + return Err(UnmarshalError::InvalidVersion(ver)); + } + + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let cmd = buf[0]; + + match cmd { + Header::TYPE_CODE_AUTHENTICATE => { + Authenticate::async_read(s).await.map(Self::Authenticate) + } + Header::TYPE_CODE_CONNECT => Connect::async_read(s).await.map(Self::Connect), + Header::TYPE_CODE_PACKET => Packet::async_read(s).await.map(Self::Packet), + Header::TYPE_CODE_DISSOCIATE => Dissociate::async_read(s).await.map(Self::Dissociate), + Header::TYPE_CODE_HEARTBEAT => Heartbeat::async_read(s).await.map(Self::Heartbeat), + _ => Err(UnmarshalError::InvalidCommand(cmd)), + } + } +} + +impl Address { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let type_code = buf[0]; + + match type_code { + Address::TYPE_CODE_NONE => Ok(Self::None), + Address::TYPE_CODE_DOMAIN => { + let mut buf = [0; 1]; + s.read_exact(&mut buf).await?; + let len = buf[0] as usize; + + let mut buf = vec![0; len + 2]; + s.read_exact(&mut buf).await?; + let port = u16::from_be_bytes([buf[len], buf[len + 1]]); + buf.truncate(len); + let domain = String::from_utf8(buf)?; + + Ok(Self::DomainAddress(domain, port)) + } + Address::TYPE_CODE_IPV4 => { + let mut buf = [0; 6]; + s.read_exact(&mut buf).await?; + let ip = [buf[0], buf[1], buf[2], buf[3]]; + let port = u16::from_be_bytes([buf[4], buf[5]]); + Ok(Self::SocketAddress(SocketAddr::from((ip, port)))) + } + Address::TYPE_CODE_IPV6 => { + let mut buf = [0; 18]; + s.read_exact(&mut buf).await?; + let ip = [ + u16::from_be_bytes([buf[0], buf[1]]), + u16::from_be_bytes([buf[2], buf[3]]), + u16::from_be_bytes([buf[4], buf[5]]), + u16::from_be_bytes([buf[6], buf[7]]), + u16::from_be_bytes([buf[8], buf[9]]), + u16::from_be_bytes([buf[10], buf[11]]), + u16::from_be_bytes([buf[12], buf[13]]), + u16::from_be_bytes([buf[14], buf[15]]), + ]; + let port = u16::from_be_bytes([buf[16], buf[17]]); + + Ok(Self::SocketAddress(SocketAddr::from((ip, port)))) + } + _ => Err(UnmarshalError::InvalidAddressType(type_code)), + } + } +} + +impl Authenticate { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 8]; + s.read_exact(&mut buf).await?; + Ok(Self::new(buf)) + } +} + +impl Connect { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + Ok(Self::new(Address::async_read(s).await?)) + } +} + +impl Packet { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 8]; + s.read_exact(&mut buf).await?; + + let assoc_id = u16::from_be_bytes([buf[0], buf[1]]); + let pkt_id = u16::from_be_bytes([buf[2], buf[3]]); + let frag_total = buf[4]; + let frag_id = buf[5]; + let size = u16::from_be_bytes([buf[6], buf[7]]); + let addr = Address::async_read(s).await?; + + Ok(Self::new(assoc_id, pkt_id, frag_total, frag_id, size, addr)) + } +} + +impl Dissociate { + async fn async_read(s: &mut (impl AsyncRead + Unpin)) -> Result { + let mut buf = [0; 2]; + s.read_exact(&mut buf).await?; + let assoc_id = u16::from_be_bytes(buf); + Ok(Self::new(assoc_id)) + } +} + +impl Heartbeat { + async fn async_read(_s: &mut (impl AsyncRead + Unpin)) -> Result { + Ok(Self::new()) } } #[derive(Debug, Error)] -pub enum UnmarshalError {} +pub enum UnmarshalError { + #[error(transparent)] + Io(#[from] IoError), + #[error("invalid version: {0}")] + InvalidVersion(u8), + #[error("invalid command: {0}")] + InvalidCommand(u8), + #[error("invalid address type: {0}")] + InvalidAddressType(u8), + #[error("address parsing error: {0}")] + AddressParse(#[from] FromUtf8Error), +}