diff --git a/tuic-client/src/connection.rs b/tuic-client/src/connection.rs index de19db5..7eb0c4c 100644 --- a/tuic-client/src/connection.rs +++ b/tuic-client/src/connection.rs @@ -1,6 +1,6 @@ use crate::{ config::Relay, - socks5::Server as Socks5Server, + socks5::UDP_SESSIONS as SOCKS5_UDP_SESSIONS, utils::{self, CongestionControl, ServerAddr, UdpRelayMode}, Error, }; @@ -29,7 +29,7 @@ use tokio::{ time, }; use tuic::Address; -use tuic_quinn::{side, Connect, Connection as Model, Task}; +use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; use uuid::Uuid; static ENDPOINT: OnceCell<Mutex<Endpoint>> = OnceCell::new(); @@ -158,12 +158,7 @@ impl Endpoint { let conn = if zero_rtt_handshake { match conn.into_0rtt() { Ok((conn, _)) => conn, - Err(conn) => { - log::info!( - "[connection] 0-RTT handshake failed, fallback to 1-RTT handshake" - ); - conn.await? - } + Err(conn) => conn.await?, } } else { conn.await? @@ -185,8 +180,6 @@ impl Endpoint { match res { Ok(conn) => { - log::info!("[connection] connection established"); - return Ok(Connection::new( conn, self.udp_relay_mode, @@ -279,6 +272,8 @@ impl Connection { } async fn init(self, heartbeat: Duration, gc_interval: Duration, gc_lifetime: Duration) { + log::info!("[relay] connection established"); + tokio::spawn(self.clone().authenticate()); tokio::spawn(self.clone().heartbeat(heartbeat)); tokio::spawn(self.clone().collect_garbage(gc_interval, gc_lifetime)); @@ -300,7 +295,7 @@ impl Connection { }; }; - log::error!("[connection] {err}"); + log::warn!("[relay] connection error: {err}"); } pub async fn connect(&self, addr: Address) -> Result<Connect, Error> { @@ -362,88 +357,85 @@ impl Connection { } async fn handle_uni_stream(self, recv: RecvStream, _reg: Register) { - log::debug!("[connection] incoming unidirectional stream"); + log::debug!("[relay] incoming unidirectional stream"); + let res = match self.model.accept_uni_stream(recv).await { - Err(err) => Err(Error::from(err)), + Err(err) => Err(Error::Model(err)), Ok(Task::Packet(pkt)) => match self.udp_relay_mode { - UdpRelayMode::Quic => match pkt.accept().await { - Ok(Some((pkt, addr, assoc_id))) => { - let addr = match addr { - Address::None => unreachable!(), - Address::DomainAddress(domain, port) => { - Socks5Address::DomainAddress(domain, port) - } - Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), - }; - Socks5Server::recv_pkt(pkt, addr, assoc_id).await; - Ok(()) - } - Ok(None) => Ok(()), - Err(err) => Err(Error::from(err)), - }, + UdpRelayMode::Quic => { + log::debug!( + "[relay] [packet] [{assoc_id:#06x}] [from-quic] [{pkt_id:#06x}] {frag_id}/{frag_total}", + assoc_id = pkt.assoc_id(), + pkt_id = pkt.pkt_id(), + frag_id = pkt.frag_id(), + frag_total = pkt.frag_total(), + ); + Self::handle_packet(pkt).await; + Ok(()) + } UdpRelayMode::Native => Err(Error::WrongPacketSource), }, - _ => unreachable!(), + _ => unreachable!(), // already filtered in `tuic_quinn` }; match res { Ok(()) => {} - Err(err) => log::error!("[connection] {err}"), + Err(err) => log::warn!("[relay] incoming unidirectional stream error: {err}"), } } async fn handle_bi_stream(self, send: SendStream, recv: RecvStream, _reg: Register) { - log::debug!("[connection] incoming bidirectional stream"); + log::debug!("[relay] incoming bidirectional stream"); + let res = match self.model.accept_bi_stream(send, recv).await { - Err(err) => Err(Error::from(err)), - _ => unreachable!(), + Err(err) => Err(Error::Model(err)), + _ => unreachable!(), // already filtered in `tuic_quinn` }; match res { Ok(()) => {} - Err(err) => log::error!("[connection] {err}"), + Err(err) => log::warn!("[relay] incoming bidirectional stream error: {err}"), } } async fn handle_datagram(self, dg: Bytes) { - log::debug!("[connection] incoming datagram"); + log::debug!("[relay] incoming datagram"); + let res = match self.model.accept_datagram(dg) { - Err(err) => Err(Error::from(err)), + Err(err) => Err(Error::Model(err)), Ok(Task::Packet(pkt)) => match self.udp_relay_mode { - UdpRelayMode::Native => match pkt.accept().await { - Ok(Some((pkt, addr, assoc_id))) => { - let addr = match addr { - Address::None => unreachable!(), - Address::DomainAddress(domain, port) => { - Socks5Address::DomainAddress(domain, port) - } - Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), - }; - Socks5Server::recv_pkt(pkt, addr, assoc_id).await; - Ok(()) - } - Ok(None) => Ok(()), - Err(err) => Err(Error::from(err)), - }, + UdpRelayMode::Native => { + log::debug!( + "[relay] [packet] [{assoc_id:#06x}] [from-native] [{pkt_id:#06x}] {frag_id}/{frag_total}", + assoc_id = pkt.assoc_id(), + pkt_id = pkt.pkt_id(), + frag_id = pkt.frag_id(), + frag_total = pkt.frag_total(), + ); + Self::handle_packet(pkt).await; + Ok(()) + } UdpRelayMode::Quic => Err(Error::WrongPacketSource), }, - _ => unreachable!(), + _ => unreachable!(), // already filtered in `tuic_quinn` }; match res { Ok(()) => {} - Err(err) => log::error!("[connection] {err}"), + Err(err) => log::warn!("[relay] incoming datagram error: {err}"), } } async fn authenticate(self) { + log::debug!("[relay] [authenticate] sending authentication"); + match self .model .authenticate(self.uuid, self.password.clone()) .await { - Ok(()) => log::info!("[connection] authentication sent"), - Err(err) => log::warn!("[connection] authentication failed: {err}"), + Ok(()) => log::info!("[relay] [authenticate] {uuid}", uuid = self.uuid), + Err(err) => log::warn!("[relay] [authenticate] authentication sending error: {err}"), } } @@ -460,12 +452,51 @@ impl Connection { } match self.model.heartbeat().await { - Ok(()) => log::info!("[connection] heartbeat"), - Err(err) => log::warn!("[connection] heartbeat error: {err}"), + Ok(()) => log::debug!("[relay] [heartbeat]"), + Err(err) => log::warn!("[relay] [heartbeat] heartbeat sending error: {err}"), } } } + async fn handle_packet(pkt: Packet) { + let assoc_id = pkt.assoc_id(); + let pkt_id = pkt.pkt_id(); + + match pkt.accept().await { + Ok(Some((pkt, addr, _))) => { + log::info!( + "[relay] [packet] [{assoc_id:#06x}] [from-native] [{pkt_id:#06x}] {addr}", + ); + + let addr = match addr { + Address::None => unreachable!(), + Address::DomainAddress(domain, port) => { + Socks5Address::DomainAddress(domain, port) + } + Address::SocketAddress(addr) => Socks5Address::SocketAddress(addr), + }; + + if let Some(session) = SOCKS5_UDP_SESSIONS + .get() + .unwrap() + .lock() + .get(&assoc_id) + .cloned() + { + if let Err(err) = session.send(pkt, addr).await { + log::warn!( + "[relay] [packet] [{assoc_id:#06x}] [from-native] [{pkt_id:#06x}] failed sending packet to socks5 client: {err}", + ); + } + } else { + log::warn!("[relay] [packet] [{assoc_id:#06x}] [from-native] [{pkt_id:#06x}] unable to find socks5 associate session"); + } + } + Ok(None) => {} + Err(err) => log::warn!("[relay] [packet] [{assoc_id:#06x}] [from-native] [{pkt_id:#06x}] packet receiving error: {err}"), + } + } + async fn collect_garbage(self, gc_interval: Duration, gc_lifetime: Duration) { loop { time::sleep(gc_interval).await; @@ -474,7 +505,7 @@ impl Connection { break; } - log::debug!("[connection] packet garbage collection"); + log::debug!("[relay] packet fragment garbage collecting event"); self.model.collect_garbage(gc_lifetime); } } diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index af84d87..730c924 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -43,7 +43,7 @@ async fn main() { } } - match Socks5Server::set_config(cfg.local).await { + match Socks5Server::set_config(cfg.local) { Ok(()) => {} Err(err) => { eprintln!("{err}"); diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 414ffd2..d1db0c3 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -13,7 +13,7 @@ use socks5_server::{ use std::{ collections::HashMap, io::{Error as IoError, ErrorKind}, - net::{SocketAddr, TcpListener as StdTcpListener, UdpSocket as StdUdpSocket}, + net::{IpAddr, SocketAddr, TcpListener as StdTcpListener, UdpSocket as StdUdpSocket}, sync::{ atomic::{AtomicU16, Ordering}, Arc, @@ -27,19 +27,45 @@ use tokio_util::compat::FuturesAsyncReadCompatExt; use tuic::Address as TuicAddress; static SERVER: OnceCell<Server> = OnceCell::new(); +pub static UDP_SESSIONS: OnceCell<Mutex<HashMap<u16, UdpSession>>> = OnceCell::new(); pub struct Server { inner: Socks5Server, dual_stack: Option<bool>, max_pkt_size: usize, next_assoc_id: AtomicU16, - udp_sessions: Mutex<HashMap<u16, Arc<AssociatedUdpSocket>>>, } impl Server { - pub async fn set_config(cfg: Local) -> Result<(), Error> { + pub fn set_config(cfg: Local) -> Result<(), Error> { + SERVER + .set(Self::new( + cfg.server, + cfg.dual_stack, + cfg.max_packet_size, + cfg.username.map(|s| s.into_bytes()), + cfg.password.map(|s| s.into_bytes()), + )?) + .map_err(|_| "failed initializing socks5 server") + .unwrap(); + + UDP_SESSIONS + .set(Mutex::new(HashMap::new())) + .map_err(|_| "failed initializing socks5 UDP session pool") + .unwrap(); + + Ok(()) + } + + fn new( + addr: SocketAddr, + dual_stack: Option<bool>, + max_pkt_size: usize, + username: Option<Vec<u8>>, + password: Option<Vec<u8>>, + ) -> Result<Self, Error> { let socket = { - let domain = match cfg.server { + let domain = match addr { SocketAddr::V4(_) => Domain::IPV4, SocketAddr::V6(_) => Domain::IPV6, }; @@ -47,7 +73,7 @@ impl Server { let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) .map_err(|err| Error::Socket("failed to create socks5 server socket", err))?; - if let Some(dual_stack) = cfg.dual_stack { + if let Some(dual_stack) = dual_stack { socket.set_only_v6(!dual_stack).map_err(|err| { Error::Socket("socks5 server dual-stack socket setting error", err) })?; @@ -62,7 +88,7 @@ impl Server { })?; socket - .bind(&SockAddr::from(cfg.server)) + .bind(&SockAddr::from(addr)) .map_err(|err| Error::Socket("failed to bind socks5 server socket", err))?; socket @@ -73,53 +99,58 @@ impl Server { .map_err(|err| Error::Socket("failed to create socks5 server socket", err))? }; - let auth: Arc<dyn Auth + Send + Sync> = match (cfg.username, cfg.password) { - (Some(username), Some(password)) => { - Arc::new(Password::new(username.into_bytes(), password.into_bytes())) - } + let auth: Arc<dyn Auth + Send + Sync> = match (username, password) { + (Some(username), Some(password)) => Arc::new(Password::new(username, password)), (None, None) => Arc::new(NoAuth), _ => return Err(Error::InvalidSocks5Auth), }; - let server = Self { + Ok(Self { inner: Socks5Server::new(socket, auth), - dual_stack: cfg.dual_stack, - max_pkt_size: cfg.max_packet_size, + dual_stack, + max_pkt_size, next_assoc_id: AtomicU16::new(0), - udp_sessions: Mutex::new(HashMap::new()), - }; - - SERVER - .set(server) - .map_err(|_| "socks5 server already initialized") - .unwrap(); - - Ok(()) + }) } pub async fn start() { - log::warn!("[socks5] server started, listening on {}", Self::addr()); + let server = SERVER.get().unwrap(); + + log::warn!( + "[socks5] server started, listening on {}", + server.inner.local_addr().unwrap() + ); loop { - match SERVER.get().unwrap().inner.accept().await { + match server.inner.accept().await { Ok((conn, addr)) => { log::debug!("[socks5] [{addr}] connection established"); + tokio::spawn(async move { - let res = match conn.handshake().await { - Ok(Connection::Associate(associate, addr)) => { - Self::handle_associate(associate, addr).await + match conn.handshake().await { + Ok(Connection::Associate(associate, _)) => { + let assoc_id = server.next_assoc_id.fetch_add(1, Ordering::Relaxed); + log::info!("[socks5] [{addr}] [associate] [{assoc_id:#06x}]"); + Self::handle_associate( + associate, + assoc_id, + server.dual_stack, + server.max_pkt_size, + ) + .await; } - Ok(Connection::Bind(bind, addr)) => Self::handle_bind(bind, addr).await, - Ok(Connection::Connect(connect, addr)) => { - Self::handle_connect(connect, addr).await + Ok(Connection::Bind(bind, _)) => { + log::info!("[socks5] [{addr}] [bind]"); + Self::handle_bind(bind).await; } - Err(err) => Err(Error::from(err)), + Ok(Connection::Connect(connect, target_addr)) => { + log::info!("[socks5] [{addr}] [connect] [{target_addr}]"); + Self::handle_connect(connect, target_addr).await; + } + Err(err) => log::warn!("[socks5] [{addr}] handshake error: {err}"), }; - match res { - Ok(()) => log::debug!("[socks5] [{addr}] connection closed"), - Err(err) => log::warn!("[socks5] [{addr}] {err}"), - } + log::debug!("[socks5] [{addr}] connection closed"); }); } Err(err) => log::warn!("[socks5] failed to establish connection: {err}"), @@ -129,87 +160,145 @@ impl Server { async fn handle_associate( assoc: Associate<associate::NeedReply>, - _addr: Address, - ) -> Result<(), Error> { - async fn get_assoc_socket() -> Result<Arc<AssociatedUdpSocket>, Error> { - let domain = match Server::addr() { - SocketAddr::V4(_) => Domain::IPV4, - SocketAddr::V6(_) => Domain::IPV6, - }; + assoc_id: u16, + dual_stack: Option<bool>, + max_pkt_size: usize, + ) { + let peer_addr = assoc.peer_addr().unwrap(); + let local_ip = assoc.local_addr().unwrap().ip(); - let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP)).map_err(|err| { - Error::Socket("failed to create socks5 server UDP associate socket", err) - })?; + match UdpSession::new(assoc_id, peer_addr, local_ip, dual_stack, max_pkt_size) { + Ok(session) => { + let local_addr = session.local_addr().unwrap(); + log::debug!( + "[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] bound to {local_addr}" + ); - if let Some(dual_stack) = Server::dual_stack() { - socket.set_only_v6(!dual_stack).map_err(|err| { - Error::Socket( - "socks5 server UDP associate dual-stack socket setting error", - err, - ) - })?; - } + let mut assoc = match assoc + .reply(Reply::Succeeded, Address::SocketAddress(local_addr)) + .await + { + Ok(assoc) => assoc, + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] command reply error: {err}"); + return; + } + }; - socket.set_nonblocking(true).map_err(|err| { - Error::Socket( - "failed setting socks5 server UDP associate socket as non-blocking", - err, - ) - })?; + UDP_SESSIONS + .get() + .unwrap() + .lock() + .insert(assoc_id, session.clone()); - socket - .bind(&SockAddr::from(Server::addr())) - .map_err(|err| { - Error::Socket("failed to bind socks5 server UDP associate socket", err) - })?; + let handle_local_incoming_pkt = async move { + loop { + let (pkt, target_addr) = match session.recv().await { + Ok(res) => res, + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] failed to receive UDP packet: {err}"); + continue; + } + }; - let socket = UdpSocket::from_std(StdUdpSocket::from(socket)).map_err(|err| { - Error::Socket("failed to create socks5 server UDP associate socket", err) - })?; + let forward = async move { + let target_addr = match target_addr { + Address::DomainAddress(domain, port) => { + TuicAddress::DomainAddress(domain, port) + } + Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), + }; - Ok(Arc::new(AssociatedUdpSocket::from(( - socket, - Server::max_pkt_size(), - )))) - } + match TuicConnection::get().await { + Ok(conn) => conn.packet(pkt, target_addr, assoc_id).await, + Err(err) => Err(err), + } + }; - match get_assoc_socket().await { - Ok(assoc_socket) => { - let assoc = assoc - .reply( - Reply::Succeeded, - Address::SocketAddress(assoc_socket.local_addr().unwrap()), - ) - .await?; - Self::send_pkt(assoc, assoc_socket).await + tokio::spawn(async move { + match forward.await { + Ok(()) => {} + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] failed relaying UDP packet: {err}"); + } + } + }); + } + }; + + match tokio::select! { + res = assoc.wait_until_closed() => res, + _ = handle_local_incoming_pkt => unreachable!(), + } { + Ok(()) => {} + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] associate connection error: {err}") + } + } + + log::debug!( + "[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] stopped associating" + ); + + UDP_SESSIONS + .get() + .unwrap() + .lock() + .remove(&assoc_id) + .unwrap(); + + let res = match TuicConnection::get().await { + Ok(conn) => conn.dissociate(assoc_id).await, + Err(err) => Err(err), + }; + + match res { + Ok(()) => {} + Err(err) => log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] failed stoping UDP relaying session: {err}"), + } } Err(err) => { - log::warn!("[socks5] failed to create associated socket: {err}"); - let mut assoc = assoc + log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] failed setting up UDP associate session: {err}"); + + match assoc .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - let _ = assoc.shutdown().await; - Ok(()) + .await + { + Ok(mut assoc) => { + let _ = assoc.shutdown().await; + } + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [associate] [{assoc_id:#06x}] command reply error: {err}") + } + } } } } - async fn handle_bind(bind: Bind<bind::NeedFirstReply>, _addr: Address) -> Result<(), Error> { - let mut conn = bind + async fn handle_bind(bind: Bind<bind::NeedFirstReply>) { + let peer_addr = bind.peer_addr().unwrap(); + log::warn!("[socks5] [{peer_addr}] [bind] command not supported"); + + match bind .reply(Reply::CommandNotSupported, Address::unspecified()) - .await?; - let _ = conn.shutdown().await; - Ok(()) + .await + { + Ok(mut bind) => { + let _ = bind.shutdown().await; + } + Err(err) => log::warn!("[socks5] [{peer_addr}] [bind] command reply error: {err}"), + } } - async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) -> Result<(), Error> { + async fn handle_connect(conn: Connect<connect::NeedReply>, addr: Address) { + let peer_addr = conn.peer_addr().unwrap(); let target_addr = match addr { Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), }; let relay = match TuicConnection::get().await { - Ok(conn) => conn.connect(target_addr).await, + Ok(conn) => conn.connect(target_addr.clone()).await, Err(err) => Err(err), }; @@ -219,140 +308,150 @@ impl Server { match conn.reply(Reply::Succeeded, Address::unspecified()).await { Ok(mut conn) => match io::copy_bidirectional(&mut conn, &mut relay).await { - Ok(_) => Ok(()), + Ok(_) => {} Err(err) => { let _ = conn.shutdown().await; let _ = relay.get_mut().reset(VarInt::from_u32(0)); - Err(Error::from(err)) + log::warn!("[socks5] [{peer_addr}] [connect] [{target_addr}] TCP stream relaying error: {err}"); } }, Err(err) => { let _ = relay.shutdown().await; - Err(Error::from(err)) + log::warn!("[socks5] [{peer_addr}] [connect] [{target_addr}] command reply error: {err}"); } } } - Err(relay_err) => { - log::error!("[connection] {relay_err}"); - let mut conn = conn + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [connect] [{target_addr}] unable to relay TCP stream: {err}"); + + match conn .reply(Reply::GeneralFailure, Address::unspecified()) - .await?; - let _ = conn.shutdown().await; - Ok(()) - } - } - } - - async fn send_pkt( - mut assoc: Associate<associate::Ready>, - assoc_socket: Arc<AssociatedUdpSocket>, - ) -> Result<(), Error> { - let assoc_id = SERVER - .get() - .unwrap() - .next_assoc_id - .fetch_add(1, Ordering::AcqRel); - - SERVER - .get() - .unwrap() - .udp_sessions - .lock() - .insert(assoc_id, assoc_socket.clone()); - - let mut connected = None; - - async fn accept_pkt( - assoc_socket: &AssociatedUdpSocket, - connected: &mut Option<SocketAddr>, - assoc_id: u16, - ) -> Result<(), Error> { - let (pkt, frag, dst_addr, src_addr) = assoc_socket.recv_from().await?; - - if let Some(connected) = connected { - if connected != &src_addr { - Err(IoError::new( - ErrorKind::Other, - format!("invalid source address: {src_addr}"), - ))?; + .await + { + Ok(mut conn) => { + let _ = conn.shutdown().await; + } + Err(err) => { + log::warn!("[socks5] [{peer_addr}] [connect] [{target_addr}] command reply error: {err}") + } } - } else { - assoc_socket.connect(src_addr).await?; - *connected = Some(src_addr); } - - if frag != 0 { - Err(IoError::new( - ErrorKind::Other, - "fragmented packet is not supported", - ))?; - } - - let target_addr = match dst_addr { - Address::DomainAddress(domain, port) => TuicAddress::DomainAddress(domain, port), - Address::SocketAddress(addr) => TuicAddress::SocketAddress(addr), - }; - - let res = match TuicConnection::get().await { - Ok(conn) => conn.packet(pkt, target_addr, assoc_id).await, - Err(err) => Err(err), - }; - - match res { - Ok(()) => {} - Err(err) => log::error!("[connection] {err}"), - } - - Ok(()) } - - let res = tokio::select! { - res = assoc.wait_until_closed() => res, - _ = async { loop { - if let Err(err) = accept_pkt(&assoc_socket, &mut connected, assoc_id).await { - log::warn!("[socks5] {err}"); - } - }} => unreachable!(), - }; - - let _ = assoc.shutdown().await; - SERVER.get().unwrap().udp_sessions.lock().remove(&assoc_id); - - let dissoc_res = match TuicConnection::get().await { - Ok(conn) => conn.dissociate(assoc_id).await, - Err(err) => Err(err), - }; - - match dissoc_res { - Ok(()) => {} - Err(err) => log::error!("[connection] [dissociate] {err}"), - } - - Ok(res?) - } - - pub async fn recv_pkt(pkt: Bytes, addr: Address, assoc_id: u16) { - let assoc_socket = { - let sessions = SERVER.get().unwrap().udp_sessions.lock(); - let Some(assoc_socket) = sessions.get(&assoc_id) else { unreachable!() }; - assoc_socket.clone() - }; - - match assoc_socket.send(pkt, 0, addr).await { - Ok(_) => {} - Err(err) => log::error!("[socks5] [send] {err}"), - } - } - - fn addr() -> SocketAddr { - SERVER.get().unwrap().inner.local_addr().unwrap() - } - - fn dual_stack() -> Option<bool> { - SERVER.get().unwrap().dual_stack - } - - fn max_pkt_size() -> usize { - SERVER.get().unwrap().max_pkt_size + } +} + +#[derive(Clone)] +pub struct UdpSession { + socket: Arc<AssociatedUdpSocket>, + assoc_id: u16, + ctrl_addr: SocketAddr, +} + +impl UdpSession { + fn new( + assoc_id: u16, + ctrl_addr: SocketAddr, + local_ip: IpAddr, + dual_stack: Option<bool>, + max_pkt_size: usize, + ) -> Result<Self, Error> { + let domain = match local_ip { + IpAddr::V4(_) => Domain::IPV4, + IpAddr::V6(_) => Domain::IPV6, + }; + + let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP)).map_err(|err| { + Error::Socket("failed to create socks5 server UDP associate socket", err) + })?; + + if let Some(dual_stack) = dual_stack { + socket.set_only_v6(!dual_stack).map_err(|err| { + Error::Socket( + "socks5 server UDP associate dual-stack socket setting error", + err, + ) + })?; + } + + socket.set_nonblocking(true).map_err(|err| { + Error::Socket( + "failed setting socks5 server UDP associate socket as non-blocking", + err, + ) + })?; + + socket + .bind(&SockAddr::from(SocketAddr::from((local_ip, 0)))) + .map_err(|err| { + Error::Socket("failed to bind socks5 server UDP associate socket", err) + })?; + + let socket = UdpSocket::from_std(StdUdpSocket::from(socket)).map_err(|err| { + Error::Socket("failed to create socks5 server UDP associate socket", err) + })?; + + Ok(Self { + socket: Arc::new(AssociatedUdpSocket::from((socket, max_pkt_size))), + assoc_id, + ctrl_addr, + }) + } + + pub async fn send(&self, pkt: Bytes, src_addr: Address) -> Result<(), Error> { + let src_addr_display = src_addr.to_string(); + + log::debug!( + "[socks5] [{ctrl_addr}] [associate] [{assoc_id:#06x}] send packet from {src_addr_display} to {dst_addr}", + ctrl_addr = self.ctrl_addr, + assoc_id = self.assoc_id, + dst_addr = self.socket.peer_addr().unwrap(), + ); + + if let Err(err) = self.socket.send(pkt, 0, src_addr).await { + log::warn!( + "[socks5] [{ctrl_addr}] [associate] [{assoc_id:#06x}] send packet from {src_addr_display} to {dst_addr} error: {err}", + ctrl_addr = self.ctrl_addr, + assoc_id = self.assoc_id, + dst_addr = self.socket.peer_addr().unwrap(), + ); + + return Err(Error::Io(err)); + } + + Ok(()) + } + + pub async fn recv(&self) -> Result<(Bytes, Address), Error> { + let (pkt, frag, dst_addr, src_addr) = self.socket.recv_from().await?; + + if let Ok(connected_addr) = self.socket.peer_addr() { + if src_addr != connected_addr { + Err(IoError::new( + ErrorKind::Other, + format!("invalid source address: {src_addr}"), + ))?; + } + } else { + self.socket.connect(src_addr).await?; + } + + if frag != 0 { + Err(IoError::new( + ErrorKind::Other, + "fragmented packet is not supported", + ))?; + } + + log::debug!( + "[socks5] [{ctrl_addr}] [associate] [{assoc_id:#06x}] receive packet from {src_addr} to {dst_addr}", + ctrl_addr = self.ctrl_addr, + assoc_id = self.assoc_id + ); + + Ok((pkt, dst_addr)) + } + + fn local_addr(&self) -> Result<SocketAddr, IoError> { + self.socket.local_addr() } }