1
0

packet assembling mechanism

This commit is contained in:
EAimTY 2023-01-25 17:34:59 +09:00
parent 6cfb00dd81
commit 358fcd95f7
9 changed files with 329 additions and 67 deletions

View File

@ -4,10 +4,11 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[features] [features]
prototype = ["parking_lot"] prototype = ["parking_lot", "thiserror"]
[dependencies] [dependencies]
parking_lot = { version = "0.12.1", default-features = false, optional = true } parking_lot = { version = "0.12.1", default-features = false, optional = true }
thiserror = { version = "1.0.38", default-features = false, optional = true }
[dev-dependencies] [dev-dependencies]
tuic = { path = ".", features = ["prototype"] } tuic = { path = ".", features = ["prototype"] }

View File

@ -16,10 +16,10 @@ use std::{
/// ///
/// The address type can be one of the following: /// The address type can be one of the following:
/// ///
/// 0xff: None /// - 0xff: None
/// 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name) /// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name)
/// 0x01: IPv4 address /// - 0x01: IPv4 address
/// 0x02: IPv6 address /// - 0x02: IPv6 address
/// ///
/// The port number is encoded in 2 bytes after the Domain name / IP address. /// The port number is encoded in 2 bytes after the Domain name / IP address.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]

View File

@ -27,7 +27,7 @@ impl Command for Heartbeat {
} }
impl From<Heartbeat> for () { impl From<Heartbeat> for () {
fn from(hb: Heartbeat) -> Self { fn from(_: Heartbeat) -> Self {
() ()
} }
} }

View File

@ -1,10 +1,7 @@
use super::side::{self, Side, SideMarker}; use super::side::{self, Side};
use crate::protocol::{Authenticate as AuthenticateHeader, Header}; use crate::protocol::{Authenticate as AuthenticateHeader, Header};
pub struct Authenticate<M> pub struct Authenticate<M> {
where
M: SideMarker,
{
inner: Side<Tx, Rx>, inner: Side<Tx, Rx>,
_marker: M, _marker: M,
} }

View File

@ -1,13 +1,10 @@
use super::{ use super::{
side::{self, Side, SideMarker}, side::{self, Side},
TaskRegister, TaskRegister,
}; };
use crate::protocol::{Address, Connect as ConnectHeader, Header}; use crate::protocol::{Address, Connect as ConnectHeader, Header};
pub struct Connect<M> pub struct Connect<M> {
where
M: SideMarker,
{
inner: Side<Tx, Rx>, inner: Side<Tx, Rx>,
_marker: M, _marker: M,
} }
@ -40,8 +37,7 @@ struct Rx {
} }
impl Connect<side::Rx> { impl Connect<side::Rx> {
pub(super) fn new(task_reg: TaskRegister, header: ConnectHeader) -> Self { pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self {
let (addr,) = header.into();
Self { Self {
inner: Side::Rx(Rx { inner: Side::Rx(Rx {
addr, addr,
@ -50,4 +46,9 @@ impl Connect<side::Rx> {
_marker: side::Rx, _marker: side::Rx,
} }
} }
pub fn addr(&self) -> &Address {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.addr
}
} }

View File

@ -1,10 +1,7 @@
use super::side::{self, Side, SideMarker}; use super::side::{self, Side};
use crate::protocol::{Dissociate as DissociateHeader, Header}; use crate::protocol::{Dissociate as DissociateHeader, Header};
pub struct Dissociate<M> pub struct Dissociate<M> {
where
M: SideMarker,
{
inner: Side<Tx, Rx>, inner: Side<Tx, Rx>,
_marker: M, _marker: M,
} }

View File

@ -1,10 +1,7 @@
use super::side::{self, Side, SideMarker}; use super::side::{self, Side};
use crate::protocol::{Header, Heartbeat as HeartbeatHeader}; use crate::protocol::{Header, Heartbeat as HeartbeatHeader};
pub struct Heartbeat<M> pub struct Heartbeat<M> {
where
M: SideMarker,
{
inner: Side<Tx, Rx>, inner: Side<Tx, Rx>,
_marker: M, _marker: M,
} }

View File

@ -1,4 +1,4 @@
use crate::protocol::{Address, Connect as ConnectHeader}; use crate::protocol::{Address, Connect as ConnectHeader, Packet as PacketHeader};
use parking_lot::Mutex; use parking_lot::Mutex;
use std::{ use std::{
collections::HashMap, collections::HashMap,
@ -6,7 +6,9 @@ use std::{
atomic::{AtomicU16, Ordering}, atomic::{AtomicU16, Ordering},
Arc, Weak, Arc, Weak,
}, },
time::{Duration, Instant},
}; };
use thiserror::Error;
mod authenticate; mod authenticate;
mod connect; mod connect;
@ -22,18 +24,21 @@ pub use self::{
packet::{Fragment, Packet}, packet::{Fragment, Packet},
}; };
pub struct Connection { pub struct Connection<B> {
udp_sessions: Mutex<UdpSessions>, udp_sessions: Arc<Mutex<UdpSessions<B>>>,
task_connect_count: TaskCount, task_connect_count: TaskCount,
task_associate_count: TaskCount, task_associate_count: TaskCount,
} }
impl Connection { impl<B> Connection<B>
where
B: AsRef<[u8]>,
{
pub fn new() -> Self { pub fn new() -> Self {
let task_associate_count = TaskCount::new(); let task_associate_count = TaskCount::new();
Self { Self {
udp_sessions: Mutex::new(UdpSessions::new(task_associate_count.clone())), udp_sessions: Arc::new(Mutex::new(UdpSessions::new(task_associate_count.clone()))),
task_connect_count: TaskCount::new(), task_connect_count: TaskCount::new(),
task_associate_count, task_associate_count,
} }
@ -44,11 +49,12 @@ impl Connection {
} }
pub fn send_connect(&self, addr: Address) -> Connect<side::Tx> { pub fn send_connect(&self, addr: Address) -> Connect<side::Tx> {
Connect::<side::Tx>::new(self.task_connect_count.reg(), addr) Connect::<side::Tx>::new(self.task_connect_count.register(), addr)
} }
pub fn recv_connect(&self, header: ConnectHeader) -> Connect<side::Rx> { pub fn recv_connect(&self, header: ConnectHeader) -> Connect<side::Rx> {
Connect::<side::Rx>::new(self.task_connect_count.reg(), header) let (addr,) = header.into();
Connect::<side::Rx>::new(self.task_connect_count.register(), addr)
} }
pub fn send_packet( pub fn send_packet(
@ -56,8 +62,23 @@ impl Connection {
assoc_id: u16, assoc_id: u16,
addr: Address, addr: Address,
max_pkt_size: usize, max_pkt_size: usize,
) -> Packet<side::Tx> { ) -> Packet<side::Tx, B> {
self.udp_sessions.lock().send(assoc_id, addr, max_pkt_size) self.udp_sessions
.lock()
.send_packet(assoc_id, addr, max_pkt_size)
}
pub fn recv_packet(&self, header: PacketHeader) -> Packet<side::Rx, B> {
let (assoc_id, pkt_id, frag_total, frag_id, size, addr) = header.into();
self.udp_sessions.lock().recv_packet(
self.udp_sessions.clone(),
assoc_id,
pkt_id,
frag_total,
frag_id,
size,
addr,
)
} }
pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate<side::Tx> { pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate<side::Tx> {
@ -75,6 +96,10 @@ impl Connection {
pub fn task_associate_count(&self) -> usize { pub fn task_associate_count(&self) -> usize {
self.task_associate_count.get() self.task_associate_count.get()
} }
pub fn collect_garbage(&self, timeout: Duration) {
self.udp_sessions.lock().collect_garbage(timeout);
}
} }
#[derive(Clone)] #[derive(Clone)]
@ -86,7 +111,7 @@ impl TaskCount {
Self(Arc::new(())) Self(Arc::new(()))
} }
fn reg(&self) -> TaskRegister { fn register(&self) -> TaskRegister {
TaskRegister(Arc::downgrade(&self.0)) TaskRegister(Arc::downgrade(&self.0))
} }
@ -95,12 +120,25 @@ impl TaskCount {
} }
} }
struct UdpSessions { pub mod side {
sessions: HashMap<u16, UdpSession>, pub struct Tx;
pub struct Rx;
pub(super) enum Side<T, R> {
Tx(T),
Rx(R),
}
}
struct UdpSessions<B> {
sessions: HashMap<u16, UdpSession<B>>,
task_associate_count: TaskCount, task_associate_count: TaskCount,
} }
impl UdpSessions { impl<B> UdpSessions<B>
where
B: AsRef<[u8]>,
{
fn new(task_associate_count: TaskCount) -> Self { fn new(task_associate_count: TaskCount) -> Self {
Self { Self {
sessions: HashMap::new(), sessions: HashMap::new(),
@ -108,52 +146,224 @@ impl UdpSessions {
} }
} }
fn send<'a>(&mut self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet<side::Tx> { fn send_packet<'a>(
&mut self,
assoc_id: u16,
addr: Address,
max_pkt_size: usize,
) -> Packet<side::Tx, B> {
self.sessions self.sessions
.entry(assoc_id) .entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.reg())) .or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.send(assoc_id, addr, max_pkt_size) .send_packet(assoc_id, addr, max_pkt_size)
}
fn recv_packet<'a>(
&mut self,
sessions: Arc<Mutex<Self>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
) -> Packet<side::Rx, B> {
self.sessions
.entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr)
} }
fn dissociate(&mut self, assoc_id: u16) -> Dissociate<side::Tx> { fn dissociate(&mut self, assoc_id: u16) -> Dissociate<side::Tx> {
self.sessions.remove(&assoc_id); self.sessions.remove(&assoc_id);
Dissociate::new(assoc_id) Dissociate::new(assoc_id)
} }
fn insert<A>(
&mut self,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
data: B,
) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
self.sessions
.entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.insert(pkt_id, frag_total, frag_id, size, addr, data)
} }
struct UdpSession { fn collect_garbage(&mut self, timeout: Duration) {
for (_, session) in self.sessions.iter_mut() {
session.collect_garbage(timeout);
}
}
}
struct UdpSession<B> {
pkt_buf: HashMap<u16, PacketBuffer<B>>,
next_pkt_id: AtomicU16, next_pkt_id: AtomicU16,
_task_reg: TaskRegister, _task_reg: TaskRegister,
} }
impl UdpSession { impl<B> UdpSession<B>
where
B: AsRef<[u8]>,
{
fn new(task_reg: TaskRegister) -> Self { fn new(task_reg: TaskRegister) -> Self {
Self { Self {
pkt_buf: HashMap::new(),
next_pkt_id: AtomicU16::new(0), next_pkt_id: AtomicU16::new(0),
_task_reg: task_reg, _task_reg: task_reg,
} }
} }
fn send<'a>(&self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet<side::Tx> { fn send_packet(
Packet::new( &self,
assoc_id: u16,
addr: Address,
max_pkt_size: usize,
) -> Packet<side::Tx, B> {
Packet::<side::Tx, B>::new(
assoc_id, assoc_id,
self.next_pkt_id.fetch_add(1, Ordering::AcqRel), self.next_pkt_id.fetch_add(1, Ordering::AcqRel),
addr, addr,
max_pkt_size, max_pkt_size,
) )
} }
fn recv_packet(
&self,
sessions: Arc<Mutex<UdpSessions<B>>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
) -> Packet<side::Rx, B> {
Packet::<side::Rx, B>::new(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr)
} }
pub mod side { fn insert<A>(
pub struct Tx; &mut self,
pub struct Rx; pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
data: B,
) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
let res = self
.pkt_buf
.entry(pkt_id)
.or_insert_with(|| PacketBuffer::new(frag_total))
.insert(frag_total, frag_id, size, addr, data)?;
pub trait SideMarker {} if res.is_some() {
impl SideMarker for Tx {} self.pkt_buf.remove(&pkt_id);
impl SideMarker for Rx {} }
pub(super) enum Side<T, R> { Ok(res)
Tx(T), }
Rx(R),
fn collect_garbage(&mut self, timeout: Duration) {
self.pkt_buf.retain(|_, buf| buf.c_time.elapsed() < timeout);
} }
} }
struct PacketBuffer<B> {
buf: Vec<Option<B>>,
frag_total: u8,
frag_received: u8,
addr: Address,
c_time: Instant,
}
impl<B> PacketBuffer<B>
where
B: AsRef<[u8]>,
{
fn new(frag_total: u8) -> Self {
let mut buf = Vec::with_capacity(frag_total as usize);
buf.resize_with(frag_total as usize, || None);
Self {
buf,
frag_total,
frag_received: 0,
addr: Address::None,
c_time: Instant::now(),
}
}
fn insert<A>(
&mut self,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
data: B,
) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
if data.as_ref().len() != size as usize {
return Err(AssembleError::InvalidFragmentSize);
}
if frag_id >= frag_total {
return Err(AssembleError::InvalidFragmentId);
}
if (frag_id == 0 && addr.is_none()) || (frag_id != 0 && !addr.is_none()) {
return Err(AssembleError::InvalidAddress);
}
if self.buf[frag_id as usize].is_some() {
return Err(AssembleError::DuplicateFragment);
}
self.buf[frag_id as usize] = Some(data);
self.frag_received += 1;
if frag_id == 0 {
self.addr = addr;
}
if self.frag_received == self.frag_total {
let iter = self.buf.iter_mut().map(|x| x.take().unwrap());
Ok(Some((A::assemble(iter)?, self.addr.take())))
} else {
Ok(None)
}
}
}
pub trait Assembled<B>
where
Self: Sized,
B: AsRef<[u8]>,
{
fn assemble(buf: impl IntoIterator<Item = B>) -> Result<Self, AssembleError>;
}
#[derive(Debug, Error)]
pub enum AssembleError {
#[error("invalid fragment size")]
InvalidFragmentSize,
#[error("invalid fragment id")]
InvalidFragmentId,
#[error("invalid address")]
InvalidAddress,
#[error("duplicate fragment")]
DuplicateFragment,
}

View File

@ -1,11 +1,13 @@
use super::side::{self, Side, SideMarker}; use super::{
side::{self, Side},
AssembleError, Assembled, UdpSessions,
};
use crate::protocol::{Address, Header, Packet as PacketHeader}; use crate::protocol::{Address, Header, Packet as PacketHeader};
use parking_lot::Mutex;
use std::sync::Arc;
pub struct Packet<M> pub struct Packet<M, B> {
where inner: Side<Tx, Rx<B>>,
M: SideMarker,
{
inner: Side<Tx, Rx>,
_marker: M, _marker: M,
} }
@ -16,9 +18,10 @@ pub struct Tx {
max_pkt_size: usize, max_pkt_size: usize,
} }
pub struct Rx; impl<B> Packet<side::Tx, B>
where
impl Packet<side::Tx> { B: AsRef<[u8]>,
{
pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self { pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self {
Self { Self {
inner: Side::Tx(Tx { inner: Side::Tx(Tx {
@ -37,6 +40,62 @@ impl Packet<side::Tx> {
} }
} }
pub struct Rx<B> {
sessions: Arc<Mutex<UdpSessions<B>>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
}
impl<B> Packet<side::Rx, B>
where
B: AsRef<[u8]>,
{
pub(super) fn new(
sessions: Arc<Mutex<UdpSessions<B>>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
) -> Self {
Self {
inner: Side::Rx(Rx {
sessions,
assoc_id,
pkt_id,
frag_total,
frag_id,
size,
addr,
}),
_marker: side::Rx,
}
}
pub fn assemble<A>(self, data: B) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
let Side::Rx(rx) = self.inner else { unreachable!() };
let mut sessions = rx.sessions.lock();
sessions.insert(
rx.assoc_id,
rx.pkt_id,
rx.frag_total,
rx.frag_id,
rx.size,
rx.addr,
data,
)
}
}
pub struct Fragment<'a> { pub struct Fragment<'a> {
assoc_id: u16, assoc_id: u16,
pkt_id: u16, pkt_id: u16,