From 20a521f02d495890f14b843e56ef56ef890fd6be Mon Sep 17 00:00:00 2001 From: yaling888 <73897884+yaling888@users.noreply.github.com> Date: Sat, 8 Apr 2023 19:20:14 +0800 Subject: [PATCH] Feature: bind socket to interface by native API on Windows (#2662) --- component/dhcp/conn.go | 12 +++- component/dialer/bind_others.go | 47 +-------------- component/dialer/bind_windows.go | 98 ++++++++++++++++++++++++++++++++ component/dialer/dialer.go | 20 ++++++- component/dialer/fallbackbind.go | 90 +++++++++++++++++++++++++++++ component/dialer/options.go | 7 +++ 6 files changed, 224 insertions(+), 50 deletions(-) create mode 100644 component/dialer/bind_windows.go create mode 100644 component/dialer/fallbackbind.go diff --git a/component/dhcp/conn.go b/component/dhcp/conn.go index 90a9e25..5b71d3c 100644 --- a/component/dhcp/conn.go +++ b/component/dhcp/conn.go @@ -14,5 +14,15 @@ func ListenDHCPClient(ctx context.Context, ifaceName string) (net.PacketConn, er listenAddr = "255.255.255.255:68" } - return dialer.ListenPacket(ctx, "udp4", listenAddr, dialer.WithInterface(ifaceName), dialer.WithAddrReuse(true)) + options := []dialer.Option{ + dialer.WithInterface(ifaceName), + dialer.WithAddrReuse(true), + } + + // fallback bind on windows, because syscall bind can not receive broadcast + if runtime.GOOS == "windows" { + options = append(options, dialer.WithFallbackBind(true)) + } + + return dialer.ListenPacket(ctx, "udp4", listenAddr, options...) } diff --git a/component/dialer/bind_others.go b/component/dialer/bind_others.go index 51b2ef6..0b1d8b1 100644 --- a/component/dialer/bind_others.go +++ b/component/dialer/bind_others.go @@ -1,57 +1,12 @@ -//go:build !linux && !darwin +//go:build !linux && !darwin && !windows package dialer import ( "net" "strconv" - "strings" - - "github.com/Dreamacro/clash/component/iface" ) -func lookupLocalAddr(ifaceName string, network string, destination net.IP, port int) (net.Addr, error) { - ifaceObj, err := iface.ResolveInterface(ifaceName) - if err != nil { - return nil, err - } - - var addr *net.IPNet - switch network { - case "udp4", "tcp4": - addr, err = ifaceObj.PickIPv4Addr(destination) - case "tcp6", "udp6": - addr, err = ifaceObj.PickIPv6Addr(destination) - default: - if destination != nil { - if destination.To4() != nil { - addr, err = ifaceObj.PickIPv4Addr(destination) - } else { - addr, err = ifaceObj.PickIPv6Addr(destination) - } - } else { - addr, err = ifaceObj.PickIPv4Addr(destination) - } - } - if err != nil { - return nil, err - } - - if strings.HasPrefix(network, "tcp") { - return &net.TCPAddr{ - IP: addr.IP, - Port: port, - }, nil - } else if strings.HasPrefix(network, "udp") { - return &net.UDPAddr{ - IP: addr.IP, - Port: port, - }, nil - } - - return nil, iface.ErrAddrNotFound -} - func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, network string, destination net.IP) error { if !destination.IsGlobalUnicast() { return nil diff --git a/component/dialer/bind_windows.go b/component/dialer/bind_windows.go new file mode 100644 index 0000000..e9c198e --- /dev/null +++ b/component/dialer/bind_windows.go @@ -0,0 +1,98 @@ +package dialer + +import ( + "encoding/binary" + "net" + "strings" + "syscall" + "unsafe" + + "github.com/Dreamacro/clash/component/iface" + + "golang.org/x/sys/windows" +) + +const ( + IP_UNICAST_IF = 31 + IPV6_UNICAST_IF = 31 +) + +type controlFn = func(network, address string, c syscall.RawConn) error + +func bindControl(ifaceIdx int, chain controlFn) controlFn { + return func(network, address string, c syscall.RawConn) (err error) { + defer func() { + if err == nil && chain != nil { + err = chain(network, address, c) + } + }() + + ipStr, _, err := net.SplitHostPort(address) + if err == nil { + ip := net.ParseIP(ipStr) + if ip != nil && !ip.IsGlobalUnicast() { + return + } + } + + var innerErr error + err = c.Control(func(fd uintptr) { + if ipStr == "" && strings.HasPrefix(network, "udp") { + // When listening udp ":0", we should bind socket to interface4 and interface6 at the same time + // and ignore the error of bind6 + _ = bindSocketToInterface6(windows.Handle(fd), ifaceIdx) + innerErr = bindSocketToInterface4(windows.Handle(fd), ifaceIdx) + return + } + switch network { + case "tcp4", "udp4": + innerErr = bindSocketToInterface4(windows.Handle(fd), ifaceIdx) + case "tcp6", "udp6": + innerErr = bindSocketToInterface6(windows.Handle(fd), ifaceIdx) + } + }) + + if innerErr != nil { + err = innerErr + } + + return + } +} + +func bindSocketToInterface4(handle windows.Handle, ifaceIdx int) error { + // MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. + // Ref: https://learn.microsoft.com/en-us/windows/win32/winsock/ipproto-ip-socket-options + var bytes [4]byte + binary.BigEndian.PutUint32(bytes[:], uint32(ifaceIdx)) + index := *(*uint32)(unsafe.Pointer(&bytes[0])) + err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(index)) + if err != nil { + return err + } + return nil +} + +func bindSocketToInterface6(handle windows.Handle, ifaceIdx int) error { + return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, ifaceIdx) +} + +func bindIfaceToDialer(ifaceName string, dialer *net.Dialer, _ string, _ net.IP) error { + ifaceObj, err := iface.ResolveInterface(ifaceName) + if err != nil { + return err + } + + dialer.Control = bindControl(ifaceObj.Index, dialer.Control) + return nil +} + +func bindIfaceToListenConfig(ifaceName string, lc *net.ListenConfig, _, address string) (string, error) { + ifaceObj, err := iface.ResolveInterface(ifaceName) + if err != nil { + return "", err + } + + lc.Control = bindControl(ifaceObj.Index, lc.Control) + return address, nil +} diff --git a/component/dialer/dialer.go b/component/dialer/dialer.go index d2b5695..fe380d7 100644 --- a/component/dialer/dialer.go +++ b/component/dialer/dialer.go @@ -51,7 +51,15 @@ func ListenPacket(ctx context.Context, network, address string, options ...Optio lc := &net.ListenConfig{} if cfg.interfaceName != "" { - addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address) + var ( + addr string + err error + ) + if cfg.fallbackBind { + addr, err = fallbackBindIfaceToListenConfig(cfg.interfaceName, lc, network, address) + } else { + addr, err = bindIfaceToListenConfig(cfg.interfaceName, lc, network, address) + } if err != nil { return nil, err } @@ -83,8 +91,14 @@ func dialContext(ctx context.Context, network string, destination net.IP, port s dialer := &net.Dialer{} if opt.interfaceName != "" { - if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil { - return nil, err + if opt.fallbackBind { + if err := fallbackBindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil { + return nil, err + } + } else { + if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil { + return nil, err + } } } if opt.routingMark != 0 { diff --git a/component/dialer/fallbackbind.go b/component/dialer/fallbackbind.go new file mode 100644 index 0000000..961cfd3 --- /dev/null +++ b/component/dialer/fallbackbind.go @@ -0,0 +1,90 @@ +package dialer + +import ( + "net" + "strconv" + "strings" + + "github.com/Dreamacro/clash/component/iface" +) + +func lookupLocalAddr(ifaceName string, network string, destination net.IP, port int) (net.Addr, error) { + ifaceObj, err := iface.ResolveInterface(ifaceName) + if err != nil { + return nil, err + } + + var addr *net.IPNet + switch network { + case "udp4", "tcp4": + addr, err = ifaceObj.PickIPv4Addr(destination) + case "tcp6", "udp6": + addr, err = ifaceObj.PickIPv6Addr(destination) + default: + if destination != nil { + if destination.To4() != nil { + addr, err = ifaceObj.PickIPv4Addr(destination) + } else { + addr, err = ifaceObj.PickIPv6Addr(destination) + } + } else { + addr, err = ifaceObj.PickIPv4Addr(destination) + } + } + if err != nil { + return nil, err + } + + if strings.HasPrefix(network, "tcp") { + return &net.TCPAddr{ + IP: addr.IP, + Port: port, + }, nil + } else if strings.HasPrefix(network, "udp") { + return &net.UDPAddr{ + IP: addr.IP, + Port: port, + }, nil + } + + return nil, iface.ErrAddrNotFound +} + +func fallbackBindIfaceToDialer(ifaceName string, dialer *net.Dialer, network string, destination net.IP) error { + if !destination.IsGlobalUnicast() { + return nil + } + + local := uint64(0) + if dialer.LocalAddr != nil { + _, port, err := net.SplitHostPort(dialer.LocalAddr.String()) + if err == nil { + local, _ = strconv.ParseUint(port, 10, 16) + } + } + + addr, err := lookupLocalAddr(ifaceName, network, destination, int(local)) + if err != nil { + return err + } + + dialer.LocalAddr = addr + + return nil +} + +func fallbackBindIfaceToListenConfig(ifaceName string, _ *net.ListenConfig, network, address string) (string, error) { + _, port, err := net.SplitHostPort(address) + if err != nil { + port = "0" + } + + local, _ := strconv.ParseUint(port, 10, 16) + + addr, err := lookupLocalAddr(ifaceName, network, nil, int(local)) + if err != nil { + return "", err + } + + return addr.String(), nil +} diff --git a/component/dialer/options.go b/component/dialer/options.go index 2d88409..9773ebb 100644 --- a/component/dialer/options.go +++ b/component/dialer/options.go @@ -10,6 +10,7 @@ var ( type option struct { interfaceName string + fallbackBind bool addrReuse bool routingMark int } @@ -22,6 +23,12 @@ func WithInterface(name string) Option { } } +func WithFallbackBind(fallback bool) Option { + return func(opt *option) { + opt.fallbackBind = fallback + } +} + func WithAddrReuse(reuse bool) Option { return func(opt *option) { opt.addrReuse = reuse