From 358fcd95f7b1b8e03a73eba092176346ba489634 Mon Sep 17 00:00:00 2001 From: EAimTY Date: Wed, 25 Jan 2023 17:34:59 +0900 Subject: [PATCH] packet assembling mechanism --- tuic/Cargo.toml | 3 +- tuic/src/protocol/address.rs | 8 +- tuic/src/protocol/heartbeat.rs | 2 +- tuic/src/prototype/authenticate.rs | 7 +- tuic/src/prototype/connect.rs | 15 +- tuic/src/prototype/dissociate.rs | 7 +- tuic/src/prototype/heartbeat.rs | 7 +- tuic/src/prototype/mod.rs | 270 +++++++++++++++++++++++++---- tuic/src/prototype/packet.rs | 77 +++++++- 9 files changed, 329 insertions(+), 67 deletions(-) diff --git a/tuic/Cargo.toml b/tuic/Cargo.toml index cdefb9c..3aab101 100644 --- a/tuic/Cargo.toml +++ b/tuic/Cargo.toml @@ -4,10 +4,11 @@ version = "0.1.0" edition = "2021" [features] -prototype = ["parking_lot"] +prototype = ["parking_lot", "thiserror"] [dependencies] parking_lot = { version = "0.12.1", default-features = false, optional = true } +thiserror = { version = "1.0.38", default-features = false, optional = true } [dev-dependencies] tuic = { path = ".", features = ["prototype"] } diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index a45ac40..78eaefc 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -16,10 +16,10 @@ use std::{ /// /// The address type can be one of the following: /// -/// 0xff: None -/// 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) -/// 0x01: IPv4 address -/// 0x02: IPv6 address +/// - 0xff: None +/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) +/// - 0x01: IPv4 address +/// - 0x02: IPv6 address /// /// The port number is encoded in 2 bytes after the Domain name / IP address. #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] diff --git a/tuic/src/protocol/heartbeat.rs b/tuic/src/protocol/heartbeat.rs index 4694444..91087e8 100644 --- a/tuic/src/protocol/heartbeat.rs +++ b/tuic/src/protocol/heartbeat.rs @@ -27,7 +27,7 @@ impl Command for Heartbeat { } impl From for () { - fn from(hb: Heartbeat) -> Self { + fn from(_: Heartbeat) -> Self { () } } diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index be03ee8..118ee20 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -1,10 +1,7 @@ -use super::side::{self, Side, SideMarker}; +use super::side::{self, Side}; use crate::protocol::{Authenticate as AuthenticateHeader, Header}; -pub struct Authenticate -where - M: SideMarker, -{ +pub struct Authenticate { inner: Side, _marker: M, } diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index a3643a1..0e40479 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -1,13 +1,10 @@ use super::{ - side::{self, Side, SideMarker}, + side::{self, Side}, TaskRegister, }; use crate::protocol::{Address, Connect as ConnectHeader, Header}; -pub struct Connect -where - M: SideMarker, -{ +pub struct Connect { inner: Side, _marker: M, } @@ -40,8 +37,7 @@ struct Rx { } impl Connect { - pub(super) fn new(task_reg: TaskRegister, header: ConnectHeader) -> Self { - let (addr,) = header.into(); + pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self { Self { inner: Side::Rx(Rx { addr, @@ -50,4 +46,9 @@ impl Connect { _marker: side::Rx, } } + + pub fn addr(&self) -> &Address { + let Side::Rx(rx) = &self.inner else { unreachable!() }; + &rx.addr + } } diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs index c1088f5..bf7f728 100644 --- a/tuic/src/prototype/dissociate.rs +++ b/tuic/src/prototype/dissociate.rs @@ -1,10 +1,7 @@ -use super::side::{self, Side, SideMarker}; +use super::side::{self, Side}; use crate::protocol::{Dissociate as DissociateHeader, Header}; -pub struct Dissociate -where - M: SideMarker, -{ +pub struct Dissociate { inner: Side, _marker: M, } diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs index 17ecdfc..b5b96eb 100644 --- a/tuic/src/prototype/heartbeat.rs +++ b/tuic/src/prototype/heartbeat.rs @@ -1,10 +1,7 @@ -use super::side::{self, Side, SideMarker}; +use super::side::{self, Side}; use crate::protocol::{Header, Heartbeat as HeartbeatHeader}; -pub struct Heartbeat -where - M: SideMarker, -{ +pub struct Heartbeat { inner: Side, _marker: M, } diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index 86ca1d7..b79cd7e 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -1,4 +1,4 @@ -use crate::protocol::{Address, Connect as ConnectHeader}; +use crate::protocol::{Address, Connect as ConnectHeader, Packet as PacketHeader}; use parking_lot::Mutex; use std::{ collections::HashMap, @@ -6,7 +6,9 @@ use std::{ atomic::{AtomicU16, Ordering}, Arc, Weak, }, + time::{Duration, Instant}, }; +use thiserror::Error; mod authenticate; mod connect; @@ -22,18 +24,21 @@ pub use self::{ packet::{Fragment, Packet}, }; -pub struct Connection { - udp_sessions: Mutex, +pub struct Connection { + udp_sessions: Arc>>, task_connect_count: TaskCount, task_associate_count: TaskCount, } -impl Connection { +impl Connection +where + B: AsRef<[u8]>, +{ pub fn new() -> Self { let task_associate_count = TaskCount::new(); Self { - udp_sessions: Mutex::new(UdpSessions::new(task_associate_count.clone())), + udp_sessions: Arc::new(Mutex::new(UdpSessions::new(task_associate_count.clone()))), task_connect_count: TaskCount::new(), task_associate_count, } @@ -44,11 +49,12 @@ impl Connection { } pub fn send_connect(&self, addr: Address) -> Connect { - Connect::::new(self.task_connect_count.reg(), addr) + Connect::::new(self.task_connect_count.register(), addr) } pub fn recv_connect(&self, header: ConnectHeader) -> Connect { - Connect::::new(self.task_connect_count.reg(), header) + let (addr,) = header.into(); + Connect::::new(self.task_connect_count.register(), addr) } pub fn send_packet( @@ -56,8 +62,23 @@ impl Connection { assoc_id: u16, addr: Address, max_pkt_size: usize, - ) -> Packet { - self.udp_sessions.lock().send(assoc_id, addr, max_pkt_size) + ) -> Packet { + self.udp_sessions + .lock() + .send_packet(assoc_id, addr, max_pkt_size) + } + + pub fn recv_packet(&self, header: PacketHeader) -> Packet { + let (assoc_id, pkt_id, frag_total, frag_id, size, addr) = header.into(); + self.udp_sessions.lock().recv_packet( + self.udp_sessions.clone(), + assoc_id, + pkt_id, + frag_total, + frag_id, + size, + addr, + ) } pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate { @@ -75,6 +96,10 @@ impl Connection { pub fn task_associate_count(&self) -> usize { self.task_associate_count.get() } + + pub fn collect_garbage(&self, timeout: Duration) { + self.udp_sessions.lock().collect_garbage(timeout); + } } #[derive(Clone)] @@ -86,7 +111,7 @@ impl TaskCount { Self(Arc::new(())) } - fn reg(&self) -> TaskRegister { + fn register(&self) -> TaskRegister { TaskRegister(Arc::downgrade(&self.0)) } @@ -95,12 +120,25 @@ impl TaskCount { } } -struct UdpSessions { - sessions: HashMap, +pub mod side { + pub struct Tx; + pub struct Rx; + + pub(super) enum Side { + Tx(T), + Rx(R), + } +} + +struct UdpSessions { + sessions: HashMap>, task_associate_count: TaskCount, } -impl UdpSessions { +impl UdpSessions +where + B: AsRef<[u8]>, +{ fn new(task_associate_count: TaskCount) -> Self { Self { sessions: HashMap::new(), @@ -108,52 +146,224 @@ impl UdpSessions { } } - fn send<'a>(&mut self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { + fn send_packet<'a>( + &mut self, + assoc_id: u16, + addr: Address, + max_pkt_size: usize, + ) -> Packet { self.sessions .entry(assoc_id) - .or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) - .send(assoc_id, addr, max_pkt_size) + .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .send_packet(assoc_id, addr, max_pkt_size) + } + + fn recv_packet<'a>( + &mut self, + sessions: Arc>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Packet { + self.sessions + .entry(assoc_id) + .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) } fn dissociate(&mut self, assoc_id: u16) -> Dissociate { self.sessions.remove(&assoc_id); Dissociate::new(assoc_id) } + + fn insert( + &mut self, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + data: B, + ) -> Result, AssembleError> + where + A: Assembled, + { + self.sessions + .entry(assoc_id) + .or_insert_with(|| UdpSession::new(self.task_associate_count.register())) + .insert(pkt_id, frag_total, frag_id, size, addr, data) + } + + fn collect_garbage(&mut self, timeout: Duration) { + for (_, session) in self.sessions.iter_mut() { + session.collect_garbage(timeout); + } + } } -struct UdpSession { +struct UdpSession { + pkt_buf: HashMap>, next_pkt_id: AtomicU16, _task_reg: TaskRegister, } -impl UdpSession { +impl UdpSession +where + B: AsRef<[u8]>, +{ fn new(task_reg: TaskRegister) -> Self { Self { + pkt_buf: HashMap::new(), next_pkt_id: AtomicU16::new(0), _task_reg: task_reg, } } - fn send<'a>(&self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet { - Packet::new( + fn send_packet( + &self, + assoc_id: u16, + addr: Address, + max_pkt_size: usize, + ) -> Packet { + Packet::::new( assoc_id, self.next_pkt_id.fetch_add(1, Ordering::AcqRel), addr, max_pkt_size, ) } -} -pub mod side { - pub struct Tx; - pub struct Rx; + fn recv_packet( + &self, + sessions: Arc>>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Packet { + Packet::::new(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr) + } - pub trait SideMarker {} - impl SideMarker for Tx {} - impl SideMarker for Rx {} + fn insert( + &mut self, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + data: B, + ) -> Result, AssembleError> + where + A: Assembled, + { + let res = self + .pkt_buf + .entry(pkt_id) + .or_insert_with(|| PacketBuffer::new(frag_total)) + .insert(frag_total, frag_id, size, addr, data)?; - pub(super) enum Side { - Tx(T), - Rx(R), + if res.is_some() { + self.pkt_buf.remove(&pkt_id); + } + + Ok(res) + } + + fn collect_garbage(&mut self, timeout: Duration) { + self.pkt_buf.retain(|_, buf| buf.c_time.elapsed() < timeout); } } + +struct PacketBuffer { + buf: Vec>, + frag_total: u8, + frag_received: u8, + addr: Address, + c_time: Instant, +} + +impl PacketBuffer +where + B: AsRef<[u8]>, +{ + fn new(frag_total: u8) -> Self { + let mut buf = Vec::with_capacity(frag_total as usize); + buf.resize_with(frag_total as usize, || None); + + Self { + buf, + frag_total, + frag_received: 0, + addr: Address::None, + c_time: Instant::now(), + } + } + + fn insert( + &mut self, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + data: B, + ) -> Result, AssembleError> + where + A: Assembled, + { + if data.as_ref().len() != size as usize { + return Err(AssembleError::InvalidFragmentSize); + } + + if frag_id >= frag_total { + return Err(AssembleError::InvalidFragmentId); + } + + if (frag_id == 0 && addr.is_none()) || (frag_id != 0 && !addr.is_none()) { + return Err(AssembleError::InvalidAddress); + } + + if self.buf[frag_id as usize].is_some() { + return Err(AssembleError::DuplicateFragment); + } + + self.buf[frag_id as usize] = Some(data); + self.frag_received += 1; + + if frag_id == 0 { + self.addr = addr; + } + + if self.frag_received == self.frag_total { + let iter = self.buf.iter_mut().map(|x| x.take().unwrap()); + Ok(Some((A::assemble(iter)?, self.addr.take()))) + } else { + Ok(None) + } + } +} + +pub trait Assembled +where + Self: Sized, + B: AsRef<[u8]>, +{ + fn assemble(buf: impl IntoIterator) -> Result; +} + +#[derive(Debug, Error)] +pub enum AssembleError { + #[error("invalid fragment size")] + InvalidFragmentSize, + #[error("invalid fragment id")] + InvalidFragmentId, + #[error("invalid address")] + InvalidAddress, + #[error("duplicate fragment")] + DuplicateFragment, +} diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index 4a139c6..56ae3bc 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -1,11 +1,13 @@ -use super::side::{self, Side, SideMarker}; +use super::{ + side::{self, Side}, + AssembleError, Assembled, UdpSessions, +}; use crate::protocol::{Address, Header, Packet as PacketHeader}; +use parking_lot::Mutex; +use std::sync::Arc; -pub struct Packet -where - M: SideMarker, -{ - inner: Side, +pub struct Packet { + inner: Side>, _marker: M, } @@ -16,9 +18,10 @@ pub struct Tx { max_pkt_size: usize, } -pub struct Rx; - -impl Packet { +impl Packet +where + B: AsRef<[u8]>, +{ pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self { Self { inner: Side::Tx(Tx { @@ -37,6 +40,62 @@ impl Packet { } } +pub struct Rx { + sessions: Arc>>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, +} + +impl Packet +where + B: AsRef<[u8]>, +{ + pub(super) fn new( + sessions: Arc>>, + assoc_id: u16, + pkt_id: u16, + frag_total: u8, + frag_id: u8, + size: u16, + addr: Address, + ) -> Self { + Self { + inner: Side::Rx(Rx { + sessions, + assoc_id, + pkt_id, + frag_total, + frag_id, + size, + addr, + }), + _marker: side::Rx, + } + } + + pub fn assemble(self, data: B) -> Result, AssembleError> + where + A: Assembled, + { + let Side::Rx(rx) = self.inner else { unreachable!() }; + let mut sessions = rx.sessions.lock(); + + sessions.insert( + rx.assoc_id, + rx.pkt_id, + rx.frag_total, + rx.frag_id, + rx.size, + rx.addr, + data, + ) + } +} + pub struct Fragment<'a> { assoc_id: u16, pkt_id: u16,