diff --git a/adapter/inbound/http.go b/adapter/inbound/http.go index 89960cf..b27ea88 100644 --- a/adapter/inbound/http.go +++ b/adapter/inbound/http.go @@ -2,6 +2,7 @@ package inbound import ( "net" + "net/netip" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -9,7 +10,7 @@ import ( ) // NewHTTP receive normal http request and return HTTPContext -func NewHTTP(target socks5.Addr, source net.Addr, conn net.Conn) *context.ConnContext { +func NewHTTP(target socks5.Addr, source net.Addr, originTarget net.Addr, conn net.Conn) *context.ConnContext { metadata := parseSocksAddr(target) metadata.NetWork = C.TCP metadata.Type = C.HTTP @@ -17,5 +18,8 @@ func NewHTTP(target socks5.Addr, source net.Addr, conn net.Conn) *context.ConnCo metadata.SrcIP = ip metadata.SrcPort = port } + if addrPort, err := netip.ParseAddrPort(originTarget.String()); err == nil { + metadata.OriginDst = addrPort + } return context.NewConnContext(conn, metadata) } diff --git a/adapter/inbound/https.go b/adapter/inbound/https.go index e7e9221..5c2b1a7 100644 --- a/adapter/inbound/https.go +++ b/adapter/inbound/https.go @@ -3,6 +3,7 @@ package inbound import ( "net" "net/http" + "net/netip" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -16,5 +17,8 @@ func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext { metadata.SrcIP = ip metadata.SrcPort = port } + if addrPort, err := netip.ParseAddrPort(conn.LocalAddr().String()); err == nil { + metadata.OriginDst = addrPort + } return context.NewConnContext(conn, metadata) } diff --git a/adapter/inbound/packet.go b/adapter/inbound/packet.go index 80b136c..bfa1355 100644 --- a/adapter/inbound/packet.go +++ b/adapter/inbound/packet.go @@ -1,6 +1,9 @@ package inbound import ( + "net" + "net/netip" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" ) @@ -17,7 +20,7 @@ func (s *PacketAdapter) Metadata() *C.Metadata { } // NewPacket is PacketAdapter generator -func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type) *PacketAdapter { +func NewPacket(target socks5.Addr, originTarget net.Addr, packet C.UDPPacket, source C.Type) *PacketAdapter { metadata := parseSocksAddr(target) metadata.NetWork = C.UDP metadata.Type = source @@ -25,7 +28,9 @@ func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type) *PacketAda metadata.SrcIP = ip metadata.SrcPort = port } - + if addrPort, err := netip.ParseAddrPort(originTarget.String()); err == nil { + metadata.OriginDst = addrPort + } return &PacketAdapter{ UDPPacket: packet, metadata: metadata, diff --git a/adapter/inbound/socket.go b/adapter/inbound/socket.go index be71701..31bde9a 100644 --- a/adapter/inbound/socket.go +++ b/adapter/inbound/socket.go @@ -2,6 +2,7 @@ package inbound import ( "net" + "net/netip" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -17,6 +18,8 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnCo metadata.SrcIP = ip metadata.SrcPort = port } - + if addrPort, err := netip.ParseAddrPort(conn.LocalAddr().String()); err == nil { + metadata.OriginDst = addrPort + } return context.NewConnContext(conn, metadata) } diff --git a/component/process/process.go b/component/process/process.go index 67a5df6..7ca2e2b 100644 --- a/component/process/process.go +++ b/component/process/process.go @@ -2,7 +2,7 @@ package process import ( "errors" - "net" + "net/netip" ) var ( @@ -16,6 +16,6 @@ const ( UDP = "udp" ) -func FindProcessName(network string, srcIP net.IP, srcPort int) (string, error) { - return findProcessName(network, srcIP, srcPort) +func FindProcessPath(network string, from netip.AddrPort, to netip.AddrPort) (string, error) { + return findProcessPath(network, from, to) } diff --git a/component/process/process_darwin.go b/component/process/process_darwin.go index 814161c..625012c 100644 --- a/component/process/process_darwin.go +++ b/component/process/process_darwin.go @@ -2,7 +2,7 @@ package process import ( "encoding/binary" - "net" + "net/netip" "strconv" "strings" "syscall" @@ -33,7 +33,7 @@ var structSize = func() int { } }() -func findProcessName(network string, ip net.IP, port int) (string, error) { +func findProcessPath(network string, from netip.AddrPort, _ netip.AddrPort) (string, error) { var spath string switch network { case TCP: @@ -44,7 +44,7 @@ func findProcessName(network string, ip net.IP, port int) (string, error) { return "", ErrInvalidNetwork } - isIPv4 := ip.To4() != nil + isIPv4 := from.Addr().Is4() value, err := syscall.Sysctl(spath) if err != nil { @@ -65,30 +65,36 @@ func findProcessName(network string, ip net.IP, port int) (string, error) { inp, so := i, i+104 srcPort := binary.BigEndian.Uint16(buf[inp+18 : inp+20]) - if uint16(port) != srcPort { + if from.Port() != srcPort { continue } + // FIXME: add dstPort check + // xinpcb_n.inp_vflag flag := buf[inp+44] var ( - srcIP net.IP + srcIP netip.Addr + srcIPOk bool srcIsIPv4 bool ) switch { case flag&0x1 > 0 && isIPv4: // ipv4 - srcIP = net.IP(buf[inp+76 : inp+80]) + srcIP, srcIPOk = netip.AddrFromSlice(buf[inp+76 : inp+80]) srcIsIPv4 = true case flag&0x2 > 0 && !isIPv4: // ipv6 - srcIP = net.IP(buf[inp+64 : inp+80]) + srcIP, srcIPOk = netip.AddrFromSlice(buf[inp+64 : inp+80]) default: continue } + if !srcIPOk { + continue + } - if ip.Equal(srcIP) { + if from.Addr() == srcIP { // FIXME: add dstIP check // xsocket_n.so_last_pid pid := readNativeUint32(buf[so+68 : so+72]) return getExecPathFromPID(pid) diff --git a/component/process/process_freebsd.go b/component/process/process_freebsd.go new file mode 100644 index 0000000..9491397 --- /dev/null +++ b/component/process/process_freebsd.go @@ -0,0 +1,217 @@ +package process + +import ( + "encoding/binary" + "fmt" + "net/netip" + "strconv" + "strings" + "unsafe" + + "golang.org/x/sys/unix" +) + +type Xinpgen12 [64]byte // size 64 + +type InEndpoints12 struct { + FPort [2]byte + LPort [2]byte + FAddr [16]byte + LAddr [16]byte + ZoneID uint32 +} // size 40 + +type XTcpcb12 struct { + Len uint32 // offset 0 + Padding1 [20]byte // offset 4 + SocketAddr uint64 // offset 24 + Padding2 [84]byte // offset 32 + Family uint32 // offset 116 + Padding3 [140]byte // offset 120 + InEndpoints InEndpoints12 // offset 260 + Padding4 [444]byte // offset 300 +} // size 744 + +type XInpcb12 struct { + Len uint32 // offset 0 + Padding1 [12]byte // offset 4 + SocketAddr uint64 // offset 16 + Padding2 [84]byte // offset 24 + Family uint32 // offset 108 + Padding3 [140]byte // offset 112 + InEndpoints InEndpoints12 // offset 252 + Padding4 [108]byte // offset 292 +} // size 400 + +type XFile12 struct { + Size uint64 // offset 0 + Pid uint32 // offset 8 + Padding1 [44]byte // offset 12 + DataAddr uint64 // offset 56 + Padding2 [64]byte // offset 64 +} // size 128 + +var majorVersion = func() int { + releaseVersion, err := unix.Sysctl("kern.osrelease") + if err != nil { + return 0 + } + + majorVersionText, _, _ := strings.Cut(releaseVersion, ".") + + majorVersion, err := strconv.Atoi(majorVersionText) + if err != nil { + return 0 + } + + return majorVersion +}() + +func findProcessPath(network string, from netip.AddrPort, to netip.AddrPort) (string, error) { + switch majorVersion { + case 12, 13: + return findProcessPath12(network, from, to) + } + + return "", ErrPlatformNotSupport +} + +func findProcessPath12(network string, from netip.AddrPort, to netip.AddrPort) (string, error) { + switch network { + case TCP: + data, err := unix.SysctlRaw("net.inet.tcp.pcblist") + if err != nil { + return "", err + } + + if len(data) < int(unsafe.Sizeof(Xinpgen12{})) { + return "", fmt.Errorf("invalid sysctl data len: %d", len(data)) + } + + data = data[unsafe.Sizeof(Xinpgen12{}):] + + for len(data) > int(unsafe.Sizeof(XTcpcb12{}.Len)) { + tcb := (*XTcpcb12)(unsafe.Pointer(&data[0])) + if tcb.Len < uint32(unsafe.Sizeof(XTcpcb12{})) || uint32(len(data)) < tcb.Len { + break + } + + data = data[tcb.Len:] + + var ( + connFromAddr netip.Addr + connToAddr netip.Addr + ) + if tcb.Family == unix.AF_INET { + connFromAddr = netip.AddrFrom4([4]byte(tcb.InEndpoints.LAddr[12:16])) + connToAddr = netip.AddrFrom4([4]byte(tcb.InEndpoints.FAddr[12:16])) + } else if tcb.Family == unix.AF_INET6 { + connFromAddr = netip.AddrFrom16(tcb.InEndpoints.LAddr) + connToAddr = netip.AddrFrom16(tcb.InEndpoints.FAddr) + } else { + continue + } + + connFrom := netip.AddrPortFrom(connFromAddr, binary.BigEndian.Uint16(tcb.InEndpoints.LPort[:])) + connTo := netip.AddrPortFrom(connToAddr, binary.BigEndian.Uint16(tcb.InEndpoints.FPort[:])) + + if connFrom == from && connTo == to { + pid, err := findPidBySocketAddr12(tcb.SocketAddr) + if err != nil { + return "", err + } + + return findExecutableByPid(pid) + } + } + case UDP: + data, err := unix.SysctlRaw("net.inet.udp.pcblist") + if err != nil { + return "", err + } + + if len(data) < int(unsafe.Sizeof(Xinpgen12{})) { + return "", fmt.Errorf("invalid sysctl data len: %d", len(data)) + } + + data = data[unsafe.Sizeof(Xinpgen12{}):] + + for len(data) > int(unsafe.Sizeof(XInpcb12{}.Len)) { + icb := (*XInpcb12)(unsafe.Pointer(&data[0])) + if icb.Len < uint32(unsafe.Sizeof(XInpcb12{})) || uint32(len(data)) < icb.Len { + break + } + data = data[icb.Len:] + + var connFromAddr netip.Addr + + if icb.Family == unix.AF_INET { + connFromAddr = netip.AddrFrom4([4]byte(icb.InEndpoints.LAddr[12:16])) + } else if icb.Family == unix.AF_INET6 { + connFromAddr = netip.AddrFrom16(icb.InEndpoints.LAddr) + } else { + continue + } + + connFrom := netip.AddrPortFrom(connFromAddr, binary.BigEndian.Uint16(icb.InEndpoints.LPort[:])) + + if connFrom == from { + pid, err := findPidBySocketAddr12(icb.SocketAddr) + if err != nil { + return "", err + } + + return findExecutableByPid(pid) + } + } + } + + return "", ErrNotFound +} + +func findPidBySocketAddr12(socketAddr uint64) (uint32, error) { + buf, err := unix.SysctlRaw("kern.file") + if err != nil { + return 0, err + } + + filesLen := len(buf) / int(unsafe.Sizeof(XFile12{})) + files := unsafe.Slice((*XFile12)(unsafe.Pointer(&buf[0])), filesLen) + + for _, file := range files { + if file.Size != uint64(unsafe.Sizeof(XFile12{})) { + return 0, fmt.Errorf("invalid xfile size: %d", file.Size) + } + + if file.DataAddr == socketAddr { + return file.Pid, nil + } + } + + return 0, ErrNotFound +} + +func findExecutableByPid(pid uint32) (string, error) { + buf := make([]byte, unix.PathMax) + size := uint64(len(buf)) + mib := [4]uint32{ + unix.CTL_KERN, + 14, // KERN_PROC + 12, // KERN_PROC_PATHNAME + pid, + } + + _, _, errno := unix.Syscall6( + unix.SYS___SYSCTL, + uintptr(unsafe.Pointer(&mib[0])), + uintptr(len(mib)), + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&size)), + 0, + 0) + if errno != 0 || size == 0 { + return "", fmt.Errorf("sysctl: get proc name: %w", errno) + } + + return string(buf[:size-1]), nil +} diff --git a/component/process/process_freebsd_amd64.go b/component/process/process_freebsd_amd64.go deleted file mode 100644 index f3e6464..0000000 --- a/component/process/process_freebsd_amd64.go +++ /dev/null @@ -1,233 +0,0 @@ -package process - -import ( - "encoding/binary" - "fmt" - "net" - "strconv" - "strings" - "sync" - "syscall" - "unsafe" - - "github.com/Dreamacro/clash/log" -) - -// store process name for when dealing with multiple PROCESS-NAME rules -var ( - defaultSearcher *searcher - - once sync.Once -) - -func findProcessName(network string, ip net.IP, srcPort int) (string, error) { - once.Do(func() { - if err := initSearcher(); err != nil { - log.Errorln("Initialize PROCESS-NAME failed: %s", err.Error()) - log.Warnln("All PROCESS-NAME rules will be skipped") - return - } - }) - - if defaultSearcher == nil { - return "", ErrPlatformNotSupport - } - - var spath string - isTCP := network == TCP - switch network { - case TCP: - spath = "net.inet.tcp.pcblist" - case UDP: - spath = "net.inet.udp.pcblist" - default: - return "", ErrInvalidNetwork - } - - value, err := syscall.Sysctl(spath) - if err != nil { - return "", err - } - - buf := []byte(value) - pid, err := defaultSearcher.Search(buf, ip, uint16(srcPort), isTCP) - if err != nil { - return "", err - } - - return getExecPathFromPID(pid) -} - -func getExecPathFromPID(pid uint32) (string, error) { - buf := make([]byte, 2048) - size := uint64(len(buf)) - // CTL_KERN, KERN_PROC, KERN_PROC_PATHNAME, pid - mib := [4]uint32{1, 14, 12, pid} - - _, _, errno := syscall.Syscall6( - syscall.SYS___SYSCTL, - uintptr(unsafe.Pointer(&mib[0])), - uintptr(len(mib)), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&size)), - 0, - 0) - if errno != 0 || size == 0 { - return "", errno - } - - return string(buf[:size-1]), nil -} - -func readNativeUint32(b []byte) uint32 { - return *(*uint32)(unsafe.Pointer(&b[0])) -} - -type searcher struct { - // sizeof(struct xinpgen) - headSize int - // sizeof(struct xtcpcb) - tcpItemSize int - // sizeof(struct xinpcb) - udpItemSize int - udpInpOffset int - port int - ip int - vflag int - socket int - - // sizeof(struct xfile) - fileItemSize int - data int - pid int -} - -func (s *searcher) Search(buf []byte, ip net.IP, port uint16, isTCP bool) (uint32, error) { - var itemSize int - var inpOffset int - - if isTCP { - // struct xtcpcb - itemSize = s.tcpItemSize - inpOffset = 8 - } else { - // struct xinpcb - itemSize = s.udpItemSize - inpOffset = s.udpInpOffset - } - - isIPv4 := ip.To4() != nil - // skip the first xinpgen block - for i := s.headSize; i+itemSize <= len(buf); i += itemSize { - inp := i + inpOffset - - srcPort := binary.BigEndian.Uint16(buf[inp+s.port : inp+s.port+2]) - - if port != srcPort { - continue - } - - // xinpcb.inp_vflag - flag := buf[inp+s.vflag] - - var srcIP net.IP - switch { - case flag&0x1 > 0 && isIPv4: - // ipv4 - srcIP = net.IP(buf[inp+s.ip : inp+s.ip+4]) - case flag&0x2 > 0 && !isIPv4: - // ipv6 - srcIP = net.IP(buf[inp+s.ip-12 : inp+s.ip+4]) - default: - continue - } - - if !ip.Equal(srcIP) { - continue - } - - // xsocket.xso_so, interpreted as big endian anyway since it's only used for comparison - socket := binary.BigEndian.Uint64(buf[inp+s.socket : inp+s.socket+8]) - return s.searchSocketPid(socket) - } - return 0, ErrNotFound -} - -func (s *searcher) searchSocketPid(socket uint64) (uint32, error) { - value, err := syscall.Sysctl("kern.file") - if err != nil { - return 0, err - } - - buf := []byte(value) - - // struct xfile - itemSize := s.fileItemSize - for i := 0; i+itemSize <= len(buf); i += itemSize { - // xfile.xf_data - data := binary.BigEndian.Uint64(buf[i+s.data : i+s.data+8]) - if data == socket { - // xfile.xf_pid - pid := readNativeUint32(buf[i+s.pid : i+s.pid+4]) - return pid, nil - } - } - return 0, ErrNotFound -} - -func newSearcher(major int) *searcher { - var s *searcher - switch major { - case 11: - s = &searcher{ - headSize: 32, - tcpItemSize: 1304, - udpItemSize: 632, - port: 198, - ip: 228, - vflag: 116, - socket: 88, - fileItemSize: 80, - data: 56, - pid: 8, - udpInpOffset: 8, - } - case 12: - fallthrough - case 13: - s = &searcher{ - headSize: 64, - tcpItemSize: 744, - udpItemSize: 400, - port: 254, - ip: 284, - vflag: 392, - socket: 16, - fileItemSize: 128, - data: 56, - pid: 8, - } - } - return s -} - -func initSearcher() error { - osRelease, err := syscall.Sysctl("kern.osrelease") - if err != nil { - return err - } - - dot := strings.Index(osRelease, ".") - if dot != -1 { - osRelease = osRelease[:dot] - } - major, err := strconv.Atoi(osRelease) - if err != nil { - return err - } - defaultSearcher = newSearcher(major) - if defaultSearcher == nil { - return fmt.Errorf("unsupported freebsd version %d", major) - } - return nil -} diff --git a/component/process/process_freebsd_test.go b/component/process/process_freebsd_test.go new file mode 100644 index 0000000..e84fa72 --- /dev/null +++ b/component/process/process_freebsd_test.go @@ -0,0 +1,35 @@ +//go:build freebsd + +package process + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestEnforceStructValid12(t *testing.T) { + if majorVersion != 12 && majorVersion != 13 { + t.Skipf("Unsupported freebsd version: %d", majorVersion) + + return + } + + assert.Equal(t, 0, int(unsafe.Offsetof(XTcpcb12{}.Len))) + assert.Equal(t, 24, int(unsafe.Offsetof(XTcpcb12{}.SocketAddr))) + assert.Equal(t, 116, int(unsafe.Offsetof(XTcpcb12{}.Family))) + assert.Equal(t, 260, int(unsafe.Offsetof(XTcpcb12{}.InEndpoints))) + assert.Equal(t, 0, int(unsafe.Offsetof(XInpcb12{}.Len))) + assert.Equal(t, 16, int(unsafe.Offsetof(XInpcb12{}.SocketAddr))) + assert.Equal(t, 108, int(unsafe.Offsetof(XInpcb12{}.Family))) + assert.Equal(t, 252, int(unsafe.Offsetof(XInpcb12{}.InEndpoints))) + assert.Equal(t, 0, int(unsafe.Offsetof(XFile12{}.Size))) + assert.Equal(t, 8, int(unsafe.Offsetof(XFile12{}.Pid))) + assert.Equal(t, 56, int(unsafe.Offsetof(XFile12{}.DataAddr))) + assert.Equal(t, 64, int(unsafe.Sizeof(Xinpgen12{}))) + assert.Equal(t, 744, int(unsafe.Sizeof(XTcpcb12{}))) + assert.Equal(t, 400, int(unsafe.Sizeof(XInpcb12{}))) + assert.Equal(t, 40, int(unsafe.Sizeof(InEndpoints12{}))) + assert.Equal(t, 128, int(unsafe.Sizeof(XFile12{}))) +} diff --git a/component/process/process_linux.go b/component/process/process_linux.go index 1edd486..78821d7 100644 --- a/component/process/process_linux.go +++ b/component/process/process_linux.go @@ -4,24 +4,16 @@ import ( "bytes" "encoding/binary" "fmt" - "net" + "net/netip" "os" - "path/filepath" - "strings" - "syscall" - "unicode" "unsafe" + "github.com/Dreamacro/clash/common/pool" + "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -const ( - SOCK_DIAG_BY_FAMILY = 20 - inetDiagRequestSize = int(unsafe.Sizeof(inetDiagRequest{})) - inetDiagResponseSize = int(unsafe.Sizeof(inetDiagResponse{})) -) - type inetDiagRequest struct { Family byte Protocol byte @@ -57,43 +49,50 @@ type inetDiagResponse struct { INode uint32 } -func findProcessName(network string, ip net.IP, srcPort int) (string, error) { - inode, uid, err := resolveSocketByNetlink(network, ip, srcPort) +func findProcessPath(network string, from netip.AddrPort, to netip.AddrPort) (string, error) { + inode, uid, err := resolveSocketByNetlink(network, from, to) if err != nil { return "", err } - return resolveProcessNameByProcSearch(inode, uid) + return resolveProcessPathByProcSearch(inode, uid) } -func resolveSocketByNetlink(network string, ip net.IP, srcPort int) (uint32, uint32, error) { +func resolveSocketByNetlink(network string, from netip.AddrPort, to netip.AddrPort) (inode uint32, uid uint32, err error) { request := &inetDiagRequest{ States: 0xffffffff, Cookie: [2]uint32{0xffffffff, 0xffffffff}, } - if ip.To4() != nil { + if from.Addr().Is4() { request.Family = unix.AF_INET } else { request.Family = unix.AF_INET6 } - if strings.HasPrefix(network, "tcp") { + // Swap src & dst for udp + // See also https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html + switch network { + case TCP: request.Protocol = unix.IPPROTO_TCP - } else if strings.HasPrefix(network, "udp") { + + copy(request.Src[:], from.Addr().AsSlice()) + copy(request.Dst[:], to.Addr().AsSlice()) + + binary.BigEndian.PutUint16(request.SrcPort[:], from.Port()) + binary.BigEndian.PutUint16(request.DstPort[:], to.Port()) + case UDP: request.Protocol = unix.IPPROTO_UDP - } else { + + copy(request.Dst[:], from.Addr().AsSlice()) + copy(request.Src[:], to.Addr().AsSlice()) + + binary.BigEndian.PutUint16(request.DstPort[:], from.Port()) + binary.BigEndian.PutUint16(request.SrcPort[:], to.Port()) + default: return 0, 0, ErrInvalidNetwork } - if v4 := ip.To4(); v4 != nil { - copy(request.Src[:], v4) - } else { - copy(request.Src[:], ip) - } - - binary.BigEndian.PutUint16(request.SrcPort[:], uint16(srcPort)) - conn, err := netlink.Dial(unix.NETLINK_INET_DIAG, nil) if err != nil { return 0, 0, err @@ -102,10 +101,10 @@ func resolveSocketByNetlink(network string, ip net.IP, srcPort int) (uint32, uin message := netlink.Message{ Header: netlink.Header{ - Type: SOCK_DIAG_BY_FAMILY, - Flags: netlink.Request | netlink.Dump, + Type: 20, // SOCK_DIAG_BY_FAMILY + Flags: netlink.Request, }, - Data: (*(*[inetDiagRequestSize]byte)(unsafe.Pointer(request)))[:], + Data: (*(*[unsafe.Sizeof(*request)]byte)(unsafe.Pointer(request)))[:], } messages, err := conn.Execute(message) @@ -114,7 +113,7 @@ func resolveSocketByNetlink(network string, ip net.IP, srcPort int) (uint32, uin } for _, msg := range messages { - if len(msg.Data) < inetDiagResponseSize { + if len(msg.Data) < int(unsafe.Sizeof(inetDiagResponse{})) { continue } @@ -126,53 +125,82 @@ func resolveSocketByNetlink(network string, ip net.IP, srcPort int) (uint32, uin return 0, 0, ErrNotFound } -func resolveProcessNameByProcSearch(inode, uid uint32) (string, error) { - files, err := os.ReadDir("/proc") +func resolveProcessPathByProcSearch(inode, uid uint32) (string, error) { + procDir, err := os.Open("/proc") + if err != nil { + return "", err + } + defer procDir.Close() + + pids, err := procDir.Readdirnames(-1) if err != nil { return "", err } - buffer := make([]byte, unix.PathMax) - socket := fmt.Appendf(nil, "socket:[%d]", inode) + expectedSocketName := fmt.Appendf(nil, "socket:[%d]", inode) - for _, f := range files { - if !f.IsDir() || !isPid(f.Name()) { + pathBuffer := pool.Get(64) + defer pool.Put(pathBuffer) + + readlinkBuffer := pool.Get(32) + defer pool.Put(readlinkBuffer) + + copy(pathBuffer, "/proc/") + + for _, pid := range pids { + if !isPid(pid) { continue } - info, err := f.Info() + pathBuffer = append(pathBuffer[:len("/proc/")], pid...) + + stat := &unix.Stat_t{} + err = unix.Stat(string(pathBuffer), stat) if err != nil { - return "", err - } - if info.Sys().(*syscall.Stat_t).Uid != uid { + continue + } else if stat.Uid != uid { continue } - processPath := filepath.Join("/proc", f.Name()) - fdPath := filepath.Join(processPath, "fd") + pathBuffer = append(pathBuffer, "/fd/"...) + fdsPrefixLength := len(pathBuffer) - fds, err := os.ReadDir(fdPath) + fdDir, err := os.Open(string(pathBuffer)) + if err != nil { + continue + } + + fds, err := fdDir.Readdirnames(-1) + fdDir.Close() if err != nil { continue } for _, fd := range fds { - n, err := unix.Readlink(filepath.Join(fdPath, fd.Name()), buffer) + pathBuffer = pathBuffer[:fdsPrefixLength] + + pathBuffer = append(pathBuffer, fd...) + + n, err := unix.Readlink(string(pathBuffer), readlinkBuffer) if err != nil { continue } - if bytes.Equal(buffer[:n], socket) { - return os.Readlink(filepath.Join(processPath, "exe")) + if bytes.Equal(readlinkBuffer[:n], expectedSocketName) { + return os.Readlink("/proc/" + pid + "/exe") } } } - return "", fmt.Errorf("process of uid(%d),inode(%d) not found", uid, inode) + return "", fmt.Errorf("inode %d of uid %d not found", inode, uid) } -func isPid(s string) bool { - return strings.IndexFunc(s, func(r rune) bool { - return !unicode.IsDigit(r) - }) == -1 +func isPid(name string) bool { + for _, c := range name { + if c < '0' || c > '9' { + return false + } + } + + return true } diff --git a/component/process/process_other.go b/component/process/process_other.go index c9e486f..0a1a60c 100644 --- a/component/process/process_other.go +++ b/component/process/process_other.go @@ -1,9 +1,11 @@ -//go:build !darwin && !linux && !windows && (!freebsd || !amd64) +//go:build !darwin && !linux && !windows && !freebsd package process -import "net" +import ( + "net/netip" +) -func findProcessName(network string, ip net.IP, srcPort int) (string, error) { +func findProcessPath(_ string, _, _ netip.AddrPort) (string, error) { return "", ErrPlatformNotSupport } diff --git a/component/process/process_test.go b/component/process/process_test.go new file mode 100644 index 0000000..e46754d --- /dev/null +++ b/component/process/process_test.go @@ -0,0 +1,101 @@ +package process + +import ( + "net" + "net/netip" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testConn(t *testing.T, network, address string) { + l, err := net.Listen(network, address) + if err != nil { + assert.FailNow(t, "Listen failed", err) + } + defer l.Close() + + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + assert.FailNow(t, "Dial failed", err) + } + defer conn.Close() + + rConn, err := l.Accept() + if err != nil { + assert.FailNow(t, "Accept conn failed", err) + } + defer rConn.Close() + + path, err := FindProcessPath(TCP, netip.MustParseAddrPort(conn.LocalAddr().String()), netip.MustParseAddrPort(conn.RemoteAddr().String())) + if err != nil { + assert.FailNow(t, "Find process path failed", err) + } + + exePath, err := os.Executable() + if err != nil { + assert.FailNow(t, "Get executable failed", err) + } + + assert.Equal(t, exePath, path) +} + +func TestFindProcessPathTCP(t *testing.T) { + t.Run("v4", func(t *testing.T) { + testConn(t, "tcp4", "127.0.0.1:0") + }) + t.Run("v6", func(t *testing.T) { + testConn(t, "tcp6", "[::1]:0") + }) +} + +func testPacketConn(t *testing.T, network, lAddress, rAddress string) { + lConn, err := net.ListenPacket(network, lAddress) + if err != nil { + assert.FailNow(t, "ListenPacket failed", err) + } + defer lConn.Close() + + rConn, err := net.ListenPacket(network, rAddress) + if err != nil { + assert.FailNow(t, "ListenPacket failed", err) + } + defer rConn.Close() + + _, err = lConn.WriteTo([]byte{0}, rConn.LocalAddr()) + if err != nil { + assert.FailNow(t, "Send message failed", err) + } + + path, err := FindProcessPath(UDP, netip.MustParseAddrPort(lConn.LocalAddr().String()), netip.MustParseAddrPort(rConn.LocalAddr().String())) + if err != nil { + assert.FailNow(t, "Find process path", err) + } + + exePath, err := os.Executable() + if err != nil { + assert.FailNow(t, "Find executable", err) + } + + assert.Equal(t, exePath, path) +} + +func TestFindProcessPathUDP(t *testing.T) { + t.Run("v4", func(t *testing.T) { + testPacketConn(t, "udp4", "127.0.0.1:0", "127.0.0.1:0") + }) + t.Run("v6", func(t *testing.T) { + testPacketConn(t, "udp6", "[::1]:0", "[::1]:0") + }) +} + +func BenchmarkFindProcessName(b *testing.B) { + from := netip.MustParseAddrPort("127.0.0.1:11447") + to := netip.MustParseAddrPort("127.0.0.1:33669") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + FindProcessPath(TCP, from, to) + } +} diff --git a/component/process/process_windows.go b/component/process/process_windows.go index 5cbcdf7..9eacfbc 100644 --- a/component/process/process_windows.go +++ b/component/process/process_windows.go @@ -1,196 +1,206 @@ package process import ( + "errors" "fmt" - "net" - "sync" - "syscall" + "net/netip" "unsafe" - "github.com/Dreamacro/clash/log" - "golang.org/x/sys/windows" -) -const ( - tcpTableFunc = "GetExtendedTcpTable" - tcpTablePidConn = 4 - udpTableFunc = "GetExtendedUdpTable" - udpTablePid = 1 - queryProcNameFunc = "QueryFullProcessImageNameW" + "github.com/Dreamacro/clash/common/pool" ) var ( - getExTCPTable uintptr - getExUDPTable uintptr - queryProcName uintptr + modIphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") - once sync.Once + procGetExtendedTcpTable = modIphlpapi.NewProc("GetExtendedTcpTable") + procGetExtendedUdpTable = modIphlpapi.NewProc("GetExtendedUdpTable") ) -func initWin32API() error { - h, err := windows.LoadLibrary("iphlpapi.dll") - if err != nil { - return fmt.Errorf("LoadLibrary iphlpapi.dll failed: %s", err.Error()) - } - - getExTCPTable, err = windows.GetProcAddress(h, tcpTableFunc) - if err != nil { - return fmt.Errorf("GetProcAddress of %s failed: %s", tcpTableFunc, err.Error()) - } - - getExUDPTable, err = windows.GetProcAddress(h, udpTableFunc) - if err != nil { - return fmt.Errorf("GetProcAddress of %s failed: %s", udpTableFunc, err.Error()) - } - - h, err = windows.LoadLibrary("kernel32.dll") - if err != nil { - return fmt.Errorf("LoadLibrary kernel32.dll failed: %s", err.Error()) - } - - queryProcName, err = windows.GetProcAddress(h, queryProcNameFunc) - if err != nil { - return fmt.Errorf("GetProcAddress of %s failed: %s", queryProcNameFunc, err.Error()) - } - - return nil -} - -func findProcessName(network string, ip net.IP, srcPort int) (string, error) { - once.Do(func() { - err := initWin32API() - if err != nil { - log.Errorln("Initialize PROCESS-NAME failed: %s", err.Error()) - log.Warnln("All PROCESS-NAMES rules will be skipped") - return - } - }) - family := windows.AF_INET - if ip.To4() == nil { +func findProcessPath(network string, from netip.AddrPort, to netip.AddrPort) (string, error) { + family := uint32(windows.AF_INET) + if from.Addr().Is6() { family = windows.AF_INET6 } - var class int - var fn uintptr + var protocol uint32 switch network { case TCP: - fn = getExTCPTable - class = tcpTablePidConn + protocol = windows.IPPROTO_TCP case UDP: - fn = getExUDPTable - class = udpTablePid + protocol = windows.IPPROTO_UDP default: return "", ErrInvalidNetwork } - buf, err := getTransportTable(fn, family, class) + pid, err := findPidByConnectionEndpoint(family, protocol, from, to) if err != nil { return "", err } - s := newSearcher(family == windows.AF_INET, network == TCP) - - pid, err := s.Search(buf, ip, uint16(srcPort)) - if err != nil { - return "", err - } return getExecPathFromPID(pid) } -type searcher struct { - itemSize int - port int - ip int - ipSize int - pid int - tcpState int -} +func findPidByConnectionEndpoint(family uint32, protocol uint32, from netip.AddrPort, to netip.AddrPort) (uint32, error) { + buf := pool.Get(8) + defer pool.Put(buf) -func (s *searcher) Search(b []byte, ip net.IP, port uint16) (uint32, error) { - n := int(readNativeUint32(b[:4])) - itemSize := s.itemSize - for i := 0; i < n; i++ { - row := b[4+itemSize*i : 4+itemSize*(i+1)] + bufSize := len(buf) - if s.tcpState >= 0 { - tcpState := readNativeUint32(row[s.tcpState : s.tcpState+4]) - // MIB_TCP_STATE_ESTAB, only check established connections for TCP - if tcpState != 5 { - continue +loop: + for { + var ret uintptr + + switch protocol { + case windows.IPPROTO_TCP: + ret, _, _ = procGetExtendedTcpTable.Call( + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&bufSize)), + 0, + uintptr(family), + 4, // TCP_TABLE_OWNER_PID_CONNECTIONS + 0, + ) + case windows.IPPROTO_UDP: + ret, _, _ = procGetExtendedUdpTable.Call( + uintptr(unsafe.Pointer(&buf[0])), + uintptr(unsafe.Pointer(&bufSize)), + 0, + uintptr(family), + 1, // UDP_TABLE_OWNER_PID + 0, + ) + default: + return 0, errors.New("unsupported network") + } + + switch ret { + case 0: + buf = buf[:bufSize] + + break loop + case uintptr(windows.ERROR_INSUFFICIENT_BUFFER): + pool.Put(buf) + buf = pool.Get(bufSize) + + continue loop + default: + return 0, fmt.Errorf("syscall error: %d", ret) + } + } + + if len(buf) < int(unsafe.Sizeof(uint32(0))) { + return 0, fmt.Errorf("invalid table size: %d", len(buf)) + } + + entriesSize := *(*uint32)(unsafe.Pointer(&buf[0])) + + switch protocol { + case windows.IPPROTO_TCP: + if family == windows.AF_INET { + type MibTcpRowOwnerPid struct { + State uint32 + LocalAddr [4]byte + LocalPort uint32 + RemoteAddr [4]byte + RemotePort uint32 + OwningPid uint32 + } + + if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibTcpRowOwnerPid{})) { + return 0, fmt.Errorf("invalid tables size: %d", len(buf)) + } + + entries := unsafe.Slice((*MibTcpRowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize) + for _, entry := range entries { + localAddr := netip.AddrFrom4(entry.LocalAddr) + localPort := windows.Ntohs(uint16(entry.LocalPort)) + remoteAddr := netip.AddrFrom4(entry.RemoteAddr) + remotePort := windows.Ntohs(uint16(entry.RemotePort)) + + if localAddr == from.Addr() && remoteAddr == to.Addr() && localPort == from.Port() && remotePort == to.Port() { + return entry.OwningPid, nil + } + } + } else { + type MibTcp6RowOwnerPid struct { + LocalAddr [16]byte + LocalScopeID uint32 + LocalPort uint32 + RemoteAddr [16]byte + RemoteScopeID uint32 + RemotePort uint32 + State uint32 + OwningPid uint32 + } + + if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibTcp6RowOwnerPid{})) { + return 0, fmt.Errorf("invalid tables size: %d", len(buf)) + } + + entries := unsafe.Slice((*MibTcp6RowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize) + for _, entry := range entries { + localAddr := netip.AddrFrom16(entry.LocalAddr) + localPort := windows.Ntohs(uint16(entry.LocalPort)) + remoteAddr := netip.AddrFrom16(entry.RemoteAddr) + remotePort := windows.Ntohs(uint16(entry.RemotePort)) + + if localAddr == from.Addr() && remoteAddr == to.Addr() && localPort == from.Port() && remotePort == to.Port() { + return entry.OwningPid, nil + } } } + case windows.IPPROTO_UDP: + if family == windows.AF_INET { + type MibUdpRowOwnerPid struct { + LocalAddr [4]byte + LocalPort uint32 + OwningPid uint32 + } - // according to MSDN, only the lower 16 bits of dwLocalPort are used and the port number is in network endian. - // this field can be illustrated as follows depends on different machine endianess: - // little endian: [ MSB LSB 0 0 ] interpret as native uint32 is ((LSB<<8)|MSB) - // big endian: [ 0 0 MSB LSB ] interpret as native uint32 is ((MSB<<8)|LSB) - // so we need an syscall.Ntohs on the lower 16 bits after read the port as native uint32 - srcPort := syscall.Ntohs(uint16(readNativeUint32(row[s.port : s.port+4]))) - if srcPort != port { - continue + if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibUdpRowOwnerPid{})) { + return 0, fmt.Errorf("invalid tables size: %d", len(buf)) + } + + entries := unsafe.Slice((*MibUdpRowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize) + for _, entry := range entries { + localAddr := netip.AddrFrom4(entry.LocalAddr) + localPort := windows.Ntohs(uint16(entry.LocalPort)) + + if localAddr == from.Addr() && localPort == from.Port() { + return entry.OwningPid, nil + } + } + } else { + type MibUdp6RowOwnerPid struct { + LocalAddr [16]byte + LocalScopeId uint32 + LocalPort uint32 + OwningPid uint32 + } + + if uint32(len(buf))-4 < entriesSize*uint32(unsafe.Sizeof(MibUdp6RowOwnerPid{})) { + return 0, fmt.Errorf("invalid tables size: %d", len(buf)) + } + + entries := unsafe.Slice((*MibUdp6RowOwnerPid)(unsafe.Pointer(&buf[4])), entriesSize) + for _, entry := range entries { + localAddr := netip.AddrFrom16(entry.LocalAddr) + localPort := windows.Ntohs(uint16(entry.LocalPort)) + + if localAddr == from.Addr() && localPort == from.Port() { + return entry.OwningPid, nil + } + } } - - srcIP := net.IP(row[s.ip : s.ip+s.ipSize]) - // windows binds an unbound udp socket to 0.0.0.0/[::] while first sendto - if !ip.Equal(srcIP) && (!srcIP.IsUnspecified() || s.tcpState != -1) { - continue - } - - pid := readNativeUint32(row[s.pid : s.pid+4]) - return pid, nil + default: + return 0, ErrInvalidNetwork } + return 0, ErrNotFound } -func newSearcher(isV4, isTCP bool) *searcher { - var itemSize, port, ip, ipSize, pid int - tcpState := -1 - switch { - case isV4 && isTCP: - // struct MIB_TCPROW_OWNER_PID - itemSize, port, ip, ipSize, pid, tcpState = 24, 8, 4, 4, 20, 0 - case isV4 && !isTCP: - // struct MIB_UDPROW_OWNER_PID - itemSize, port, ip, ipSize, pid = 12, 4, 0, 4, 8 - case !isV4 && isTCP: - // struct MIB_TCP6ROW_OWNER_PID - itemSize, port, ip, ipSize, pid, tcpState = 56, 20, 0, 16, 52, 48 - case !isV4 && !isTCP: - // struct MIB_UDP6ROW_OWNER_PID - itemSize, port, ip, ipSize, pid = 28, 20, 0, 16, 24 - } - - return &searcher{ - itemSize: itemSize, - port: port, - ip: ip, - ipSize: ipSize, - pid: pid, - tcpState: tcpState, - } -} - -func getTransportTable(fn uintptr, family int, class int) ([]byte, error) { - for size, buf := uint32(8), make([]byte, 8); ; { - ptr := unsafe.Pointer(&buf[0]) - err, _, _ := syscall.SyscallN(fn, uintptr(ptr), uintptr(unsafe.Pointer(&size)), 0, uintptr(family), uintptr(class), 0) - - switch err { - case 0: - return buf, nil - case uintptr(syscall.ERROR_INSUFFICIENT_BUFFER): - buf = make([]byte, size) - default: - return nil, fmt.Errorf("syscall error: %d", err) - } - } -} - -func readNativeUint32(b []byte) uint32 { - return *(*uint32)(unsafe.Pointer(&b[0])) -} - func getExecPathFromPID(pid uint32) (string, error) { // kernel process starts with a colon in order to distinguish with normal processes switch pid { @@ -207,17 +217,13 @@ func getExecPathFromPID(pid uint32) (string, error) { } defer windows.CloseHandle(h) - buf := make([]uint16, syscall.MAX_LONG_PATH) + buf := make([]uint16, windows.MAX_LONG_PATH) size := uint32(len(buf)) - r1, _, err := syscall.SyscallN( - queryProcName, - uintptr(h), - uintptr(1), - uintptr(unsafe.Pointer(&buf[0])), - uintptr(unsafe.Pointer(&size)), - ) - if r1 == 0 { + + err = windows.QueryFullProcessImageName(h, 0, &buf[0], &size) + if err != nil { return "", err } - return syscall.UTF16ToString(buf[:size]), nil + + return windows.UTF16ToString(buf[:size]), nil } diff --git a/constant/metadata.go b/constant/metadata.go index cab6e37..b92cac7 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -3,6 +3,7 @@ package constant import ( "encoding/json" "net" + "net/netip" "strconv" "github.com/Dreamacro/clash/transport/socks5" @@ -72,6 +73,8 @@ type Metadata struct { DNSMode DNSMode `json:"dnsMode"` ProcessPath string `json:"processPath"` SpecialProxy string `json:"specialProxy"` + + OriginDst netip.AddrPort `json:"-"` } func (m *Metadata) RemoteAddress() string { diff --git a/listener/http/client.go b/listener/http/client.go index 873a9a3..eb7b7fd 100644 --- a/listener/http/client.go +++ b/listener/http/client.go @@ -12,7 +12,7 @@ import ( "github.com/Dreamacro/clash/transport/socks5" ) -func newClient(source net.Addr, in chan<- C.ConnContext) *http.Client { +func newClient(source net.Addr, originTarget net.Addr, in chan<- C.ConnContext) *http.Client { return &http.Client{ Transport: &http.Transport{ // from http.DefaultTransport @@ -32,7 +32,7 @@ func newClient(source net.Addr, in chan<- C.ConnContext) *http.Client { left, right := net.Pipe() - in <- inbound.NewHTTP(dstAddr, source, right) + in <- inbound.NewHTTP(dstAddr, source, originTarget, right) return left, nil }, diff --git a/listener/http/proxy.go b/listener/http/proxy.go index 989d119..59d0831 100644 --- a/listener/http/proxy.go +++ b/listener/http/proxy.go @@ -15,7 +15,7 @@ import ( ) func HandleConn(c net.Conn, in chan<- C.ConnContext, cache *cache.LruCache) { - client := newClient(c.RemoteAddr(), in) + client := newClient(c.RemoteAddr(), c.LocalAddr(), in) defer client.CloseIdleConnections() conn := N.NewBufferedConn(c) diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index 50cdcb0..c6870d9 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -41,7 +41,7 @@ func handleUpgrade(conn net.Conn, request *http.Request, in chan<- C.ConnContext left, right := net.Pipe() - in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), right) + in <- inbound.NewHTTP(dstAddr, conn.RemoteAddr(), conn.LocalAddr(), right) bufferedLeft := N.NewBufferedConn(left) defer bufferedLeft.Close() diff --git a/listener/socks/udp.go b/listener/socks/udp.go index 8bc439f..5ef4216 100644 --- a/listener/socks/udp.go +++ b/listener/socks/udp.go @@ -79,7 +79,7 @@ func handleSocksUDP(pc net.PacketConn, in chan<- *inbound.PacketAdapter, buf []b bufRef: buf, } select { - case in <- inbound.NewPacket(target, packet, C.SOCKS5): + case in <- inbound.NewPacket(target, pc.LocalAddr(), packet, C.SOCKS5): default: } } diff --git a/listener/tproxy/udp.go b/listener/tproxy/udp.go index 0e7c059..4d8f6fc 100644 --- a/listener/tproxy/udp.go +++ b/listener/tproxy/udp.go @@ -91,7 +91,7 @@ func handlePacketConn(in chan<- *inbound.PacketAdapter, buf []byte, lAddr, rAddr buf: buf, } select { - case in <- inbound.NewPacket(target, pkt, C.TPROXY): + case in <- inbound.NewPacket(target, target.UDPAddr(), pkt, C.TPROXY): default: } } diff --git a/listener/tunnel/udp.go b/listener/tunnel/udp.go index ee0ecba..1a658ba 100644 --- a/listener/tunnel/udp.go +++ b/listener/tunnel/udp.go @@ -76,7 +76,7 @@ func (l *PacketConn) handleUDP(pc net.PacketConn, in chan<- *inbound.PacketAdapt payload: buf, } - ctx := inbound.NewPacket(l.target, packet, C.TUNNEL) + ctx := inbound.NewPacket(l.target, pc.LocalAddr(), packet, C.TUNNEL) ctx.Metadata().SpecialProxy = l.proxy select { case in <- ctx: diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 5eaf421..72da03d 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -401,9 +401,10 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { if !processFound && rule.ShouldFindProcess() { processFound = true + srcIP, ok := netip.AddrFromSlice(metadata.SrcIP) srcPort, err := strconv.ParseUint(metadata.SrcPort, 10, 16) - if err == nil { - path, err := P.FindProcessName(metadata.NetWork.String(), metadata.SrcIP, int(srcPort)) + if ok && err == nil && metadata.OriginDst.IsValid() { + path, err := P.FindProcessPath(metadata.NetWork.String(), netip.AddrPortFrom(srcIP, uint16(srcPort)), metadata.OriginDst) if err != nil { log.Debugln("[Process] find process %s: %v", metadata.String(), err) } else {