diff --git a/tuic-client/src/main.rs b/tuic-client/src/main.rs index b429fd8..cb79481 100644 --- a/tuic-client/src/main.rs +++ b/tuic-client/src/main.rs @@ -43,7 +43,7 @@ async fn main() { } } - match Socks5Server::set_config(cfg.local) { + match Socks5Server::set_config(cfg.local).await { Ok(()) => {} Err(err) => { eprintln!("{err}"); diff --git a/tuic-client/src/socks5.rs b/tuic-client/src/socks5.rs index 2eef0ea..1d8af58 100644 --- a/tuic-client/src/socks5.rs +++ b/tuic-client/src/socks5.rs @@ -3,7 +3,7 @@ use bytes::Bytes; use once_cell::sync::OnceCell; use parking_lot::Mutex; use quinn::VarInt; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use socket2::Socket; use socks5_proto::{Address, Reply}; use socks5_server::{ auth::{NoAuth, Password}, @@ -13,7 +13,7 @@ use socks5_server::{ use std::{ collections::HashMap, io::{Error as IoError, ErrorKind}, - net::{IpAddr, SocketAddr, TcpListener as StdTcpListener, UdpSocket as StdUdpSocket}, + net::{SocketAddr, TcpListener as StdTcpListener, UdpSocket as StdUdpSocket}, sync::{ atomic::{AtomicU16, Ordering}, Arc, @@ -38,22 +38,16 @@ pub struct Server { } impl Server { - pub fn set_config(cfg: Local) -> Result<(), Error> { + pub async fn set_config(cfg: Local) -> Result<(), Error> { let socket = { - let domain = match cfg.server.ip() { - IpAddr::V4(_) => Domain::IPV4, - IpAddr::V6(_) => Domain::IPV6, - }; - - let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))?; + let socket = Socket::from(TcpListener::bind(&cfg.server).await?.into_std()?); if let Some(dual_stack) = cfg.dual_stack { socket.set_only_v6(!dual_stack)?; } socket.set_reuse_address(true)?; - socket.bind(&SockAddr::from(cfg.server))?; - socket.listen(128)?; + TcpListener::from_std(StdTcpListener::from(socket))? }; @@ -117,30 +111,29 @@ impl Server { assoc: Associate, _addr: Address, ) -> Result<(), Error> { - async fn get_assoc_socket() -> Result<(Arc, SocketAddr), IoError> { - let domain = match SERVER.get().unwrap().addr.ip() { - IpAddr::V4(_) => Domain::IPV4, - IpAddr::V6(_) => Domain::IPV6, - }; - - let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + async fn get_assoc_socket() -> Result, IoError> { + let socket = Socket::from( + UdpSocket::bind(SERVER.get().unwrap().addr) + .await? + .into_std()?, + ); if let Some(dual_stack) = SERVER.get().unwrap().dual_stack { socket.set_only_v6(!dual_stack)?; } - socket.bind(&SockAddr::from(SERVER.get().unwrap().addr))?; - let socket = AssociatedUdpSocket::from(( UdpSocket::from_std(StdUdpSocket::from(socket))?, SERVER.get().unwrap().max_pkt_size, )); - let addr = socket.local_addr()?; - Ok((Arc::new(socket), addr)) + Ok(Arc::new(socket)) } - match get_assoc_socket().await { + match get_assoc_socket() + .await + .and_then(|socket| socket.local_addr().map(|addr| (socket, addr))) + { Ok((assoc_socket, assoc_addr)) => { let assoc = assoc .reply(Reply::Succeeded, Address::SocketAddress(assoc_addr)) diff --git a/tuic-server/src/server.rs b/tuic-server/src/server.rs index 64216ab..8dfa5bf 100644 --- a/tuic-server/src/server.rs +++ b/tuic-server/src/server.rs @@ -13,12 +13,12 @@ use quinn::{ }; use register_count::{Counter, Register}; use rustls::{version, ServerConfig as RustlsServerConfig}; -use socket2::{Domain, Protocol, SockAddr, Socket, Type}; +use socket2::Socket; use std::{ collections::{hash_map::Entry, HashMap}, future::Future, io::{Error as IoError, ErrorKind}, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket}, pin::Pin, sync::{ atomic::{AtomicBool, AtomicUsize, Ordering}, @@ -95,24 +95,16 @@ impl Server { config.transport_config(Arc::new(tp_cfg)); - let domain = match cfg.server.ip() { - IpAddr::V4(_) => Domain::IPV4, - IpAddr::V6(_) => Domain::IPV6, - }; - - let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?; + let socket = Socket::from(StdUdpSocket::bind(cfg.server)?); if let Some(dual_stack) = cfg.dual_stack { socket.set_only_v6(!dual_stack)?; } - socket.bind(&SockAddr::from(cfg.server))?; - let socket = StdUdpSocket::from(socket); - let ep = Endpoint::new( EndpointConfig::default(), Some(config), - socket, + StdUdpSocket::from(socket), Arc::new(TokioRuntime), )?; @@ -593,12 +585,14 @@ impl UdpSession { let socket_v4 = Arc::new(UdpSocket::bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))).await?); let socket_v6 = if udp_relay_ipv6 { - let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + let socket = Socket::from( + UdpSocket::bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, 0))) + .await? + .into_std()?, + ); + socket.set_only_v6(true)?; - socket.bind(&SockAddr::from(SocketAddr::from(( - Ipv6Addr::UNSPECIFIED, - 0, - ))))?; + Some(Arc::new(UdpSocket::from_std(StdUdpSocket::from(socket))?)) } else { None