diff --git a/tuic/src/protocol/address.rs b/tuic/src/protocol/address.rs index 717629b..a45ac40 100644 --- a/tuic/src/protocol/address.rs +++ b/tuic/src/protocol/address.rs @@ -1,5 +1,6 @@ use std::{ fmt::{Display, Formatter, Result as FmtResult}, + mem, net::SocketAddr, }; @@ -56,6 +57,10 @@ impl Address { } } + pub fn take(&mut self) -> Self { + mem::take(self) + } + pub fn is_none(&self) -> bool { matches!(self, Self::None) } @@ -82,3 +87,9 @@ impl Display for Address { } } } + +impl Default for Address { + fn default() -> Self { + Self::None + } +} diff --git a/tuic/src/protocol/packet.rs b/tuic/src/protocol/packet.rs index 83e4504..62a8219 100644 --- a/tuic/src/protocol/packet.rs +++ b/tuic/src/protocol/packet.rs @@ -35,6 +35,10 @@ impl Packet { addr, } } + + pub const fn len_without_addr() -> usize { + 2 + 2 + 1 + 1 + 2 + } } impl Command for Packet { diff --git a/tuic/src/prototype/authenticate.rs b/tuic/src/prototype/authenticate.rs index 06b4a94..fbccf57 100644 --- a/tuic/src/prototype/authenticate.rs +++ b/tuic/src/prototype/authenticate.rs @@ -13,4 +13,8 @@ impl Authenticate { _task_reg: task_reg, } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/connect.rs b/tuic/src/prototype/connect.rs index 7db35aa..fddac8e 100644 --- a/tuic/src/prototype/connect.rs +++ b/tuic/src/prototype/connect.rs @@ -13,4 +13,8 @@ impl Connect { _task_reg: task_reg, } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/dissociate.rs b/tuic/src/prototype/dissociate.rs index e34c576..abf5cef 100644 --- a/tuic/src/prototype/dissociate.rs +++ b/tuic/src/prototype/dissociate.rs @@ -13,4 +13,8 @@ impl Dissociate { _task_reg: task_reg, } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/heartbeat.rs b/tuic/src/prototype/heartbeat.rs index c20b0e2..369a7ec 100644 --- a/tuic/src/prototype/heartbeat.rs +++ b/tuic/src/prototype/heartbeat.rs @@ -10,4 +10,8 @@ impl Heartbeat { header: Header::Heartbeat(HeartbeatHeader::new()), } } + + pub fn header(&self) -> &Header { + &self.header + } } diff --git a/tuic/src/prototype/mod.rs b/tuic/src/prototype/mod.rs index 49cda85..b96f414 100644 --- a/tuic/src/prototype/mod.rs +++ b/tuic/src/prototype/mod.rs @@ -1,8 +1,11 @@ use crate::protocol::Address; use parking_lot::Mutex; use std::{ - collections::{hash_map::Entry, HashMap}, - sync::{Arc, Weak}, + collections::HashMap, + sync::{ + atomic::{AtomicU16, Ordering}, + Arc, Weak, + }, }; mod authenticate; @@ -102,14 +105,10 @@ impl UdpSessions { payload: &'a [u8], frag_len: usize, ) -> Packet<'a> { - match self.sessions.entry(assoc_id) { - Entry::Occupied(_) => {} - Entry::Vacant(entry) => { - entry.insert(UdpSession::new(self.local_active_task_count.reg())); - } - } - - Packet::new(assoc_id, addr, payload, frag_len) + self.sessions + .entry(assoc_id) + .or_insert_with(|| UdpSession::new(self.local_active_task_count.reg())) + .send(assoc_id, addr, payload, frag_len) } fn dissociate(&mut self, assoc_id: u16) -> Dissociate { @@ -119,13 +118,31 @@ impl UdpSessions { } struct UdpSession { + next_pkt_id: AtomicU16, _task_reg: TaskRegister, } impl UdpSession { fn new(task_reg: TaskRegister) -> Self { Self { + next_pkt_id: AtomicU16::new(0), _task_reg: task_reg, } } + + fn send<'a>( + &self, + assoc_id: u16, + addr: Address, + payload: &'a [u8], + frag_len: usize, + ) -> Packet<'a> { + Packet::new( + assoc_id, + self.next_pkt_id.fetch_add(1, Ordering::AcqRel), + addr, + payload, + frag_len, + ) + } } diff --git a/tuic/src/prototype/packet.rs b/tuic/src/prototype/packet.rs index 3e262eb..ccd17fc 100644 --- a/tuic/src/prototype/packet.rs +++ b/tuic/src/prototype/packet.rs @@ -1,19 +1,79 @@ -use crate::protocol::Address; +use crate::protocol::{Address, Header, Packet as PacketHeader}; pub struct Packet<'a> { assoc_id: u16, + pkt_id: u16, addr: Address, payload: &'a [u8], - frag_len: usize, + max_pkt_size: usize, + frag_total: u8, + next_frag_id: u8, + next_frag_start: usize, } impl<'a> Packet<'a> { - pub(super) fn new(assoc_id: u16, addr: Address, payload: &'a [u8], frag_len: usize) -> Self { + pub(super) fn new( + assoc_id: u16, + pkt_id: u16, + addr: Address, + payload: &'a [u8], + max_pkt_size: usize, + ) -> Self { + let first_frag_size = max_pkt_size - PacketHeader::len_without_addr() - addr.len(); + let frag_size_addr_none = + max_pkt_size - PacketHeader::len_without_addr() - Address::None.len(); + + let frag_total = if first_frag_size < payload.len() { + (1 + (payload.len() - first_frag_size) / frag_size_addr_none + 1) as u8 + } else { + 1u8 + }; + Self { assoc_id, + pkt_id, addr, payload, - frag_len, + max_pkt_size, + frag_total, + next_frag_id: 0, + next_frag_start: 0, } } } + +impl<'a> Iterator for Packet<'a> { + type Item = (Header, &'a [u8]); + + fn next(&mut self) -> Option { + if self.next_frag_id < self.frag_total { + let payload_size = + self.max_pkt_size - PacketHeader::len_without_addr() - self.addr.len(); + let next_frag_end = (self.next_frag_start + payload_size).min(self.payload.len()); + + let header = Header::Packet(PacketHeader::new( + self.assoc_id, + self.pkt_id, + self.frag_total, + self.next_frag_id, + (next_frag_end - self.next_frag_start) as u16, + self.addr.take(), + )); + + let payload = &self.payload[self.next_frag_start..next_frag_end]; + + self.next_frag_id += 1; + self.next_frag_start = next_frag_end; + + Some((header, payload)) + } else { + None + } + } +} + +impl ExactSizeIterator for Packet<'_> { + fn len(&self) -> usize { + self.frag_total as usize + } +}