diff --git a/tuic-server/src/main.rs b/tuic-server/src/main.rs index d52e01a..26c7dc5 100644 --- a/tuic-server/src/main.rs +++ b/tuic-server/src/main.rs @@ -2,9 +2,10 @@ use self::{ config::{Config, ConfigError}, server::Server, }; -use quinn::{crypto::ExportKeyingMaterialError, ConnectionError}; -use std::{env, io::Error as IoError, process}; +use quinn::ConnectionError; +use std::{env, io::Error as IoError, net::SocketAddr, process}; use thiserror::Error; +use tuic::Address; use tuic_quinn::Error as ModelError; mod config; @@ -50,4 +51,6 @@ pub enum Error { AuthFailed, #[error("received packet from unexpected source")] UnexpectedPacketSource, + #[error("{0} resolved to {1} but IPv6 UDP relay disabled")] + UdpRelayIpv6Disabled(Address, SocketAddr), } diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 9101202..6969bc4 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -4,7 +4,10 @@ use crossbeam_utils::atomic::AtomicCell; use parking_lot::Mutex; use quinn::{Connecting, Connection as QuinnConnection, Endpoint, RecvStream, SendStream, VarInt}; use std::{ + collections::{hash_map::Entry, HashMap}, future::Future, + io::{Error as IoError, ErrorKind}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{ atomic::{AtomicBool, Ordering}, @@ -12,11 +15,22 @@ use std::{ }, task::{Context, Poll, Waker}, }; -use tuic_quinn::{side, Connection as Model, Task}; +use tokio::{ + io::{self, AsyncWriteExt}, + net::{self, TcpStream, UdpSocket}, + sync::{ + oneshot::{self, Receiver, Sender}, + Mutex as AsyncMutex, + }, +}; +use tokio_util::compat::FuturesAsyncReadCompatExt; +use tuic::Address; +use tuic_quinn::{side, Connect, Connection as Model, Packet, Task}; pub struct Server { ep: Endpoint, token: Arc<[u8]>, + udp_relay_ipv6: bool, zero_rtt_handshake: bool, } @@ -31,6 +45,7 @@ impl Server { tokio::spawn(Connection::init( conn, self.token.clone(), + self.udp_relay_ipv6, self.zero_rtt_handshake, )); } @@ -42,13 +57,20 @@ struct Connection { inner: QuinnConnection, model: Model, token: Arc<[u8]>, + udp_relay_ipv6: bool, is_authed: IsAuthed, + udp_sessions: Arc>>, udp_relay_mode: Arc>>, } impl Connection { - pub async fn init(conn: Connecting, token: Arc<[u8]>, zero_rtt_handshake: bool) { - match Self::handshake(conn, token, zero_rtt_handshake).await { + pub async fn init( + conn: Connecting, + token: Arc<[u8]>, + udp_relay_ipv6: bool, + zero_rtt_handshake: bool, + ) { + match Self::handshake(conn, token, udp_relay_ipv6, zero_rtt_handshake).await { Ok(conn) => loop { if conn.is_closed() { break; @@ -66,6 +88,7 @@ impl Connection { async fn handshake( conn: Connecting, token: Arc<[u8]>, + udp_relay_ipv6: bool, zero_rtt_handshake: bool, ) -> Result { let conn = if zero_rtt_handshake { @@ -84,7 +107,9 @@ impl Connection { inner: conn.clone(), model: Model::::new(conn), token, + udp_relay_ipv6, is_authed: IsAuthed::new(), + udp_sessions: Arc::new(AsyncMutex::new(HashMap::new())), udp_relay_mode: Arc::new(AtomicCell::new(None)), }) } @@ -135,8 +160,17 @@ impl Connection { } match pre_process(&self, recv).await { - Ok(Task::Packet(pkt)) => todo!(), - Ok(Task::Dissociate(assoc_id)) => todo!(), + Ok(Task::Packet(pkt)) => { + self.set_udp_relay_mode(UdpRelayMode::Quic); + match self.handle_packet(pkt).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + Ok(Task::Dissociate(assoc_id)) => match self.handle_dissociate(assoc_id).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + }, Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); @@ -163,7 +197,10 @@ impl Connection { } match pre_process(&self, send, recv).await { - Ok(Task::Connect(conn)) => todo!(), + Ok(Task::Connect(conn)) => match self.handle_connect(conn).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + }, Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); @@ -192,8 +229,14 @@ impl Connection { } match pre_process(&self, dg).await { - Ok(Task::Packet(pkt)) => todo!(), - Ok(Task::Heartbeat) => todo!(), + Ok(Task::Packet(pkt)) => { + self.set_udp_relay_mode(UdpRelayMode::Native); + match self.handle_packet(pkt).await { + Ok(()) => {} + Err(err) => eprintln!("{err}"), + } + } + Ok(Task::Heartbeat) => {} Ok(_) => unreachable!(), Err(err) => { eprintln!("{err}"); @@ -203,6 +246,77 @@ impl Connection { } } + async fn handle_connect(&self, conn: Connect) -> Result<(), Error> { + let mut stream = None; + let mut last_err = None; + + match resolve_dns(conn.addr()).await { + Ok(addrs) => { + for addr in addrs { + match TcpStream::connect(addr).await { + Ok(s) => { + stream = Some(s); + break; + } + Err(err) => last_err = Some(err), + } + } + } + Err(err) => last_err = Some(err), + } + + if let Some(mut stream) = stream { + let mut conn = conn.compat(); + let res = io::copy_bidirectional(&mut conn, &mut stream).await; + let _ = conn.shutdown().await; + let _ = stream.shutdown().await; + res?; + Ok(()) + } else { + let _ = conn.compat().shutdown().await; + Err(last_err + .unwrap_or_else(|| IoError::new(ErrorKind::NotFound, "no address resolved")))? + } + } + + async fn handle_packet(&self, pkt: Packet) -> Result<(), Error> { + let Some((pkt, addr, assoc_id)) = pkt.accept().await? else { + return Ok(()); + }; + + let (socket_v4, socket_v6) = match self.udp_sessions.lock().await.entry(assoc_id) { + Entry::Occupied(mut entry) => { + let session = entry.get_mut(); + (session.socket_v4.clone(), session.socket_v6.clone()) + } + Entry::Vacant(entry) => { + let session = entry + .insert(UdpSession::new(assoc_id, self.clone(), self.udp_relay_ipv6).await?); + (session.socket_v4.clone(), session.socket_v6.clone()) + } + }; + + let Some(socket_addr) = resolve_dns(&addr).await?.next() else { + Err(IoError::new(ErrorKind::NotFound, "no address resolved"))? + }; + + let socket = match socket_addr { + SocketAddr::V4(_) => socket_v4, + SocketAddr::V6(_) => { + socket_v6.ok_or_else(|| Error::UdpRelayIpv6Disabled(addr, socket_addr))? + } + }; + + socket.send_to(&pkt, socket_addr).await?; + + Ok(()) + } + + async fn handle_dissociate(&self, assoc_id: u16) -> Result<(), Error> { + self.udp_sessions.lock().await.remove(&assoc_id); + Ok(()) + } + fn set_authed(&self) { self.is_authed.set_authed(); } @@ -228,6 +342,69 @@ impl Connection { } } +async fn resolve_dns(addr: &Address) -> Result, IoError> { + match addr { + Address::None => Err(IoError::new(ErrorKind::InvalidInput, "empty address")), + Address::DomainAddress(domain, port) => Ok(net::lookup_host((domain.as_str(), *port)) + .await? + .collect::>() + .into_iter()), + Address::SocketAddress(addr) => Ok(vec![*addr].into_iter()), + } +} + +struct UdpSession { + socket_v4: Arc, + socket_v6: Option>, + cancel: Option>, +} + +impl UdpSession { + async fn new(assoc_id: u16, conn: Connection, udp_relay_ipv6: bool) -> Result { + let socket_v4 = + Arc::new(UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))).await?); + let socket_v6 = if udp_relay_ipv6 { + Some(Arc::new( + UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))).await?, + )) + } else { + None + }; + + let (tx, rx) = oneshot::channel(); + + tokio::spawn(Self::listen_incoming( + assoc_id, + conn, + socket_v4.clone(), + socket_v6.clone(), + rx, + )); + + Ok(Self { + socket_v4, + socket_v6, + cancel: Some(tx), + }) + } + + async fn listen_incoming( + assoc_id: u16, + conn: Connection, + socket_v4: Arc, + socket_v6: Option>, + cancel: Receiver<()>, + ) { + todo!() + } +} + +impl Drop for UdpSession { + fn drop(&mut self) { + let _ = self.cancel.take().unwrap().send(()); + } +} + #[derive(Clone)] struct IsAuthed { is_authed: Arc,