1
0

packet assembling mechanism

This commit is contained in:
EAimTY 2023-01-25 17:34:59 +09:00
parent 6cfb00dd81
commit 358fcd95f7
9 changed files with 329 additions and 67 deletions

View File

@ -4,10 +4,11 @@ version = "0.1.0"
edition = "2021"
[features]
prototype = ["parking_lot"]
prototype = ["parking_lot", "thiserror"]
[dependencies]
parking_lot = { version = "0.12.1", default-features = false, optional = true }
thiserror = { version = "1.0.38", default-features = false, optional = true }
[dev-dependencies]
tuic = { path = ".", features = ["prototype"] }

View File

@ -16,10 +16,10 @@ use std::{
///
/// The address type can be one of the following:
///
/// 0xff: None
/// 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name)
/// 0x01: IPv4 address
/// 0x02: IPv6 address
/// - 0xff: None
/// - 0x00: Fully-qualified domain name (the first byte indicates the length of the domain name)
/// - 0x01: IPv4 address
/// - 0x02: IPv6 address
///
/// The port number is encoded in 2 bytes after the Domain name / IP address.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]

View File

@ -27,7 +27,7 @@ impl Command for Heartbeat {
}
impl From<Heartbeat> for () {
fn from(hb: Heartbeat) -> Self {
fn from(_: Heartbeat) -> Self {
()
}
}

View File

@ -1,10 +1,7 @@
use super::side::{self, Side, SideMarker};
use super::side::{self, Side};
use crate::protocol::{Authenticate as AuthenticateHeader, Header};
pub struct Authenticate<M>
where
M: SideMarker,
{
pub struct Authenticate<M> {
inner: Side<Tx, Rx>,
_marker: M,
}

View File

@ -1,13 +1,10 @@
use super::{
side::{self, Side, SideMarker},
side::{self, Side},
TaskRegister,
};
use crate::protocol::{Address, Connect as ConnectHeader, Header};
pub struct Connect<M>
where
M: SideMarker,
{
pub struct Connect<M> {
inner: Side<Tx, Rx>,
_marker: M,
}
@ -40,8 +37,7 @@ struct Rx {
}
impl Connect<side::Rx> {
pub(super) fn new(task_reg: TaskRegister, header: ConnectHeader) -> Self {
let (addr,) = header.into();
pub(super) fn new(task_reg: TaskRegister, addr: Address) -> Self {
Self {
inner: Side::Rx(Rx {
addr,
@ -50,4 +46,9 @@ impl Connect<side::Rx> {
_marker: side::Rx,
}
}
pub fn addr(&self) -> &Address {
let Side::Rx(rx) = &self.inner else { unreachable!() };
&rx.addr
}
}

View File

@ -1,10 +1,7 @@
use super::side::{self, Side, SideMarker};
use super::side::{self, Side};
use crate::protocol::{Dissociate as DissociateHeader, Header};
pub struct Dissociate<M>
where
M: SideMarker,
{
pub struct Dissociate<M> {
inner: Side<Tx, Rx>,
_marker: M,
}

View File

@ -1,10 +1,7 @@
use super::side::{self, Side, SideMarker};
use super::side::{self, Side};
use crate::protocol::{Header, Heartbeat as HeartbeatHeader};
pub struct Heartbeat<M>
where
M: SideMarker,
{
pub struct Heartbeat<M> {
inner: Side<Tx, Rx>,
_marker: M,
}

View File

@ -1,4 +1,4 @@
use crate::protocol::{Address, Connect as ConnectHeader};
use crate::protocol::{Address, Connect as ConnectHeader, Packet as PacketHeader};
use parking_lot::Mutex;
use std::{
collections::HashMap,
@ -6,7 +6,9 @@ use std::{
atomic::{AtomicU16, Ordering},
Arc, Weak,
},
time::{Duration, Instant},
};
use thiserror::Error;
mod authenticate;
mod connect;
@ -22,18 +24,21 @@ pub use self::{
packet::{Fragment, Packet},
};
pub struct Connection {
udp_sessions: Mutex<UdpSessions>,
pub struct Connection<B> {
udp_sessions: Arc<Mutex<UdpSessions<B>>>,
task_connect_count: TaskCount,
task_associate_count: TaskCount,
}
impl Connection {
impl<B> Connection<B>
where
B: AsRef<[u8]>,
{
pub fn new() -> Self {
let task_associate_count = TaskCount::new();
Self {
udp_sessions: Mutex::new(UdpSessions::new(task_associate_count.clone())),
udp_sessions: Arc::new(Mutex::new(UdpSessions::new(task_associate_count.clone()))),
task_connect_count: TaskCount::new(),
task_associate_count,
}
@ -44,11 +49,12 @@ impl Connection {
}
pub fn send_connect(&self, addr: Address) -> Connect<side::Tx> {
Connect::<side::Tx>::new(self.task_connect_count.reg(), addr)
Connect::<side::Tx>::new(self.task_connect_count.register(), addr)
}
pub fn recv_connect(&self, header: ConnectHeader) -> Connect<side::Rx> {
Connect::<side::Rx>::new(self.task_connect_count.reg(), header)
let (addr,) = header.into();
Connect::<side::Rx>::new(self.task_connect_count.register(), addr)
}
pub fn send_packet(
@ -56,8 +62,23 @@ impl Connection {
assoc_id: u16,
addr: Address,
max_pkt_size: usize,
) -> Packet<side::Tx> {
self.udp_sessions.lock().send(assoc_id, addr, max_pkt_size)
) -> Packet<side::Tx, B> {
self.udp_sessions
.lock()
.send_packet(assoc_id, addr, max_pkt_size)
}
pub fn recv_packet(&self, header: PacketHeader) -> Packet<side::Rx, B> {
let (assoc_id, pkt_id, frag_total, frag_id, size, addr) = header.into();
self.udp_sessions.lock().recv_packet(
self.udp_sessions.clone(),
assoc_id,
pkt_id,
frag_total,
frag_id,
size,
addr,
)
}
pub fn send_dissociate(&self, assoc_id: u16) -> Dissociate<side::Tx> {
@ -75,6 +96,10 @@ impl Connection {
pub fn task_associate_count(&self) -> usize {
self.task_associate_count.get()
}
pub fn collect_garbage(&self, timeout: Duration) {
self.udp_sessions.lock().collect_garbage(timeout);
}
}
#[derive(Clone)]
@ -86,7 +111,7 @@ impl TaskCount {
Self(Arc::new(()))
}
fn reg(&self) -> TaskRegister {
fn register(&self) -> TaskRegister {
TaskRegister(Arc::downgrade(&self.0))
}
@ -95,12 +120,25 @@ impl TaskCount {
}
}
struct UdpSessions {
sessions: HashMap<u16, UdpSession>,
pub mod side {
pub struct Tx;
pub struct Rx;
pub(super) enum Side<T, R> {
Tx(T),
Rx(R),
}
}
struct UdpSessions<B> {
sessions: HashMap<u16, UdpSession<B>>,
task_associate_count: TaskCount,
}
impl UdpSessions {
impl<B> UdpSessions<B>
where
B: AsRef<[u8]>,
{
fn new(task_associate_count: TaskCount) -> Self {
Self {
sessions: HashMap::new(),
@ -108,52 +146,224 @@ impl UdpSessions {
}
}
fn send<'a>(&mut self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet<side::Tx> {
fn send_packet<'a>(
&mut self,
assoc_id: u16,
addr: Address,
max_pkt_size: usize,
) -> Packet<side::Tx, B> {
self.sessions
.entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.reg()))
.send(assoc_id, addr, max_pkt_size)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.send_packet(assoc_id, addr, max_pkt_size)
}
fn recv_packet<'a>(
&mut self,
sessions: Arc<Mutex<Self>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
) -> Packet<side::Rx, B> {
self.sessions
.entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.recv_packet(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr)
}
fn dissociate(&mut self, assoc_id: u16) -> Dissociate<side::Tx> {
self.sessions.remove(&assoc_id);
Dissociate::new(assoc_id)
}
fn insert<A>(
&mut self,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
data: B,
) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
self.sessions
.entry(assoc_id)
.or_insert_with(|| UdpSession::new(self.task_associate_count.register()))
.insert(pkt_id, frag_total, frag_id, size, addr, data)
}
struct UdpSession {
fn collect_garbage(&mut self, timeout: Duration) {
for (_, session) in self.sessions.iter_mut() {
session.collect_garbage(timeout);
}
}
}
struct UdpSession<B> {
pkt_buf: HashMap<u16, PacketBuffer<B>>,
next_pkt_id: AtomicU16,
_task_reg: TaskRegister,
}
impl UdpSession {
impl<B> UdpSession<B>
where
B: AsRef<[u8]>,
{
fn new(task_reg: TaskRegister) -> Self {
Self {
pkt_buf: HashMap::new(),
next_pkt_id: AtomicU16::new(0),
_task_reg: task_reg,
}
}
fn send<'a>(&self, assoc_id: u16, addr: Address, max_pkt_size: usize) -> Packet<side::Tx> {
Packet::new(
fn send_packet(
&self,
assoc_id: u16,
addr: Address,
max_pkt_size: usize,
) -> Packet<side::Tx, B> {
Packet::<side::Tx, B>::new(
assoc_id,
self.next_pkt_id.fetch_add(1, Ordering::AcqRel),
addr,
max_pkt_size,
)
}
fn recv_packet(
&self,
sessions: Arc<Mutex<UdpSessions<B>>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
) -> Packet<side::Rx, B> {
Packet::<side::Rx, B>::new(sessions, assoc_id, pkt_id, frag_total, frag_id, size, addr)
}
pub mod side {
pub struct Tx;
pub struct Rx;
fn insert<A>(
&mut self,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
data: B,
) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
let res = self
.pkt_buf
.entry(pkt_id)
.or_insert_with(|| PacketBuffer::new(frag_total))
.insert(frag_total, frag_id, size, addr, data)?;
pub trait SideMarker {}
impl SideMarker for Tx {}
impl SideMarker for Rx {}
if res.is_some() {
self.pkt_buf.remove(&pkt_id);
}
pub(super) enum Side<T, R> {
Tx(T),
Rx(R),
Ok(res)
}
fn collect_garbage(&mut self, timeout: Duration) {
self.pkt_buf.retain(|_, buf| buf.c_time.elapsed() < timeout);
}
}
struct PacketBuffer<B> {
buf: Vec<Option<B>>,
frag_total: u8,
frag_received: u8,
addr: Address,
c_time: Instant,
}
impl<B> PacketBuffer<B>
where
B: AsRef<[u8]>,
{
fn new(frag_total: u8) -> Self {
let mut buf = Vec::with_capacity(frag_total as usize);
buf.resize_with(frag_total as usize, || None);
Self {
buf,
frag_total,
frag_received: 0,
addr: Address::None,
c_time: Instant::now(),
}
}
fn insert<A>(
&mut self,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
data: B,
) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
if data.as_ref().len() != size as usize {
return Err(AssembleError::InvalidFragmentSize);
}
if frag_id >= frag_total {
return Err(AssembleError::InvalidFragmentId);
}
if (frag_id == 0 && addr.is_none()) || (frag_id != 0 && !addr.is_none()) {
return Err(AssembleError::InvalidAddress);
}
if self.buf[frag_id as usize].is_some() {
return Err(AssembleError::DuplicateFragment);
}
self.buf[frag_id as usize] = Some(data);
self.frag_received += 1;
if frag_id == 0 {
self.addr = addr;
}
if self.frag_received == self.frag_total {
let iter = self.buf.iter_mut().map(|x| x.take().unwrap());
Ok(Some((A::assemble(iter)?, self.addr.take())))
} else {
Ok(None)
}
}
}
pub trait Assembled<B>
where
Self: Sized,
B: AsRef<[u8]>,
{
fn assemble(buf: impl IntoIterator<Item = B>) -> Result<Self, AssembleError>;
}
#[derive(Debug, Error)]
pub enum AssembleError {
#[error("invalid fragment size")]
InvalidFragmentSize,
#[error("invalid fragment id")]
InvalidFragmentId,
#[error("invalid address")]
InvalidAddress,
#[error("duplicate fragment")]
DuplicateFragment,
}

View File

@ -1,11 +1,13 @@
use super::side::{self, Side, SideMarker};
use super::{
side::{self, Side},
AssembleError, Assembled, UdpSessions,
};
use crate::protocol::{Address, Header, Packet as PacketHeader};
use parking_lot::Mutex;
use std::sync::Arc;
pub struct Packet<M>
where
M: SideMarker,
{
inner: Side<Tx, Rx>,
pub struct Packet<M, B> {
inner: Side<Tx, Rx<B>>,
_marker: M,
}
@ -16,9 +18,10 @@ pub struct Tx {
max_pkt_size: usize,
}
pub struct Rx;
impl Packet<side::Tx> {
impl<B> Packet<side::Tx, B>
where
B: AsRef<[u8]>,
{
pub(super) fn new(assoc_id: u16, pkt_id: u16, addr: Address, max_pkt_size: usize) -> Self {
Self {
inner: Side::Tx(Tx {
@ -37,6 +40,62 @@ impl Packet<side::Tx> {
}
}
pub struct Rx<B> {
sessions: Arc<Mutex<UdpSessions<B>>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
}
impl<B> Packet<side::Rx, B>
where
B: AsRef<[u8]>,
{
pub(super) fn new(
sessions: Arc<Mutex<UdpSessions<B>>>,
assoc_id: u16,
pkt_id: u16,
frag_total: u8,
frag_id: u8,
size: u16,
addr: Address,
) -> Self {
Self {
inner: Side::Rx(Rx {
sessions,
assoc_id,
pkt_id,
frag_total,
frag_id,
size,
addr,
}),
_marker: side::Rx,
}
}
pub fn assemble<A>(self, data: B) -> Result<Option<(A, Address)>, AssembleError>
where
A: Assembled<B>,
{
let Side::Rx(rx) = self.inner else { unreachable!() };
let mut sessions = rx.sessions.lock();
sessions.insert(
rx.assoc_id,
rx.pkt_id,
rx.frag_total,
rx.frag_id,
rx.size,
rx.addr,
data,
)
}
}
pub struct Fragment<'a> {
assoc_id: u16,
pkt_id: u16,