From 5212aaf445ecb7aa0482bdded8cb29a8f01b4152 Mon Sep 17 00:00:00 2001 From: Kr328 Date: Sun, 25 Jun 2023 09:19:06 +0800 Subject: [PATCH] Fix: process resolving for udp (#2806) --- component/process/process_freebsd.go | 28 +++++----- component/process/process_linux.go | 76 +++++++++++++++++++--------- component/process/process_test.go | 15 +++++- component/process/process_windows.go | 4 +- 4 files changed, 80 insertions(+), 43 deletions(-) diff --git a/component/process/process_freebsd.go b/component/process/process_freebsd.go index 9491397..769c684 100644 --- a/component/process/process_freebsd.go +++ b/component/process/process_freebsd.go @@ -23,32 +23,32 @@ type InEndpoints12 struct { type XTcpcb12 struct { Len uint32 // offset 0 - Padding1 [20]byte // offset 4 + _ [20]byte // offset 4 SocketAddr uint64 // offset 24 - Padding2 [84]byte // offset 32 + _ [84]byte // offset 32 Family uint32 // offset 116 - Padding3 [140]byte // offset 120 + _ [140]byte // offset 120 InEndpoints InEndpoints12 // offset 260 - Padding4 [444]byte // offset 300 + _ [444]byte // offset 300 } // size 744 type XInpcb12 struct { Len uint32 // offset 0 - Padding1 [12]byte // offset 4 + _ [12]byte // offset 4 SocketAddr uint64 // offset 16 - Padding2 [84]byte // offset 24 + _ [84]byte // offset 24 Family uint32 // offset 108 - Padding3 [140]byte // offset 112 + _ [140]byte // offset 112 InEndpoints InEndpoints12 // offset 252 - Padding4 [108]byte // offset 292 + _ [108]byte // offset 292 } // size 400 type XFile12 struct { Size uint64 // offset 0 Pid uint32 // offset 8 - Padding1 [44]byte // offset 12 + _ [44]byte // offset 12 DataAddr uint64 // offset 56 - Padding2 [64]byte // offset 64 + _ [64]byte // offset 64 } // size 128 var majorVersion = func() int { @@ -144,7 +144,6 @@ func findProcessPath12(network string, from netip.AddrPort, to netip.AddrPort) ( data = data[icb.Len:] var connFromAddr netip.Addr - if icb.Family == unix.AF_INET { connFromAddr = netip.AddrFrom4([4]byte(icb.InEndpoints.LAddr[12:16])) } else if icb.Family == unix.AF_INET6 { @@ -153,9 +152,9 @@ func findProcessPath12(network string, from netip.AddrPort, to netip.AddrPort) ( continue } - connFrom := netip.AddrPortFrom(connFromAddr, binary.BigEndian.Uint16(icb.InEndpoints.LPort[:])) + connFromPort := binary.BigEndian.Uint16(icb.InEndpoints.LPort[:]) - if connFrom == from { + if (connFromAddr == from.Addr() || connFromAddr.IsUnspecified()) && connFromPort == from.Port() { pid, err := findPidBySocketAddr12(icb.SocketAddr) if err != nil { return "", err @@ -208,7 +207,8 @@ func findExecutableByPid(pid uint32) (string, error) { uintptr(unsafe.Pointer(&buf[0])), uintptr(unsafe.Pointer(&size)), 0, - 0) + 0, + ) if errno != 0 || size == 0 { return "", fmt.Errorf("sysctl: get proc name: %w", errno) } diff --git a/component/process/process_linux.go b/component/process/process_linux.go index 78821d7..5d41248 100644 --- a/component/process/process_linux.go +++ b/component/process/process_linux.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "net" "net/netip" "os" "unsafe" @@ -59,40 +60,65 @@ func findProcessPath(network string, from netip.AddrPort, to netip.AddrPort) (st } func resolveSocketByNetlink(network string, from netip.AddrPort, to netip.AddrPort) (inode uint32, uid uint32, err error) { - request := &inetDiagRequest{ - States: 0xffffffff, - Cookie: [2]uint32{0xffffffff, 0xffffffff}, - } - - if from.Addr().Is4() { - request.Family = unix.AF_INET + var families []byte + if from.Addr().Unmap().Is4() { + families = []byte{unix.AF_INET, unix.AF_INET6} } else { - request.Family = unix.AF_INET6 + families = []byte{unix.AF_INET6, unix.AF_INET} } - // Swap src & dst for udp - // See also https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html + var protocol byte switch network { case TCP: - request.Protocol = unix.IPPROTO_TCP - - copy(request.Src[:], from.Addr().AsSlice()) - copy(request.Dst[:], to.Addr().AsSlice()) - - binary.BigEndian.PutUint16(request.SrcPort[:], from.Port()) - binary.BigEndian.PutUint16(request.DstPort[:], to.Port()) + protocol = unix.IPPROTO_TCP case UDP: - request.Protocol = unix.IPPROTO_UDP - - copy(request.Dst[:], from.Addr().AsSlice()) - copy(request.Src[:], to.Addr().AsSlice()) - - binary.BigEndian.PutUint16(request.DstPort[:], from.Port()) - binary.BigEndian.PutUint16(request.SrcPort[:], to.Port()) + protocol = unix.IPPROTO_UDP default: return 0, 0, ErrInvalidNetwork } + if protocol == unix.IPPROTO_UDP { + // Swap from & to for udp + // See also https://www.mail-archive.com/netdev@vger.kernel.org/msg248638.html + from, to = to, from + } + + for _, family := range families { + inode, uid, err = resolveSocketByNetlinkExact(family, protocol, from, to, netlink.Request) + if err == nil { + return inode, uid, err + } + } + + return 0, 0, ErrNotFound +} + +func resolveSocketByNetlinkExact(family byte, protocol byte, from netip.AddrPort, to netip.AddrPort, flags netlink.HeaderFlags) (inode uint32, uid uint32, err error) { + request := &inetDiagRequest{ + Family: family, + Protocol: protocol, + States: 0xffffffff, + Cookie: [2]uint32{0xffffffff, 0xffffffff}, + } + + var ( + fromAddr []byte + toAddr []byte + ) + if family == unix.AF_INET { + fromAddr = net.IP(from.Addr().AsSlice()).To4() + toAddr = net.IP(to.Addr().AsSlice()).To4() + } else { + fromAddr = net.IP(from.Addr().AsSlice()).To16() + toAddr = net.IP(to.Addr().AsSlice()).To16() + } + + copy(request.Src[:], fromAddr) + copy(request.Dst[:], toAddr) + + binary.BigEndian.PutUint16(request.SrcPort[:], from.Port()) + binary.BigEndian.PutUint16(request.DstPort[:], to.Port()) + conn, err := netlink.Dial(unix.NETLINK_INET_DIAG, nil) if err != nil { return 0, 0, err @@ -102,7 +128,7 @@ func resolveSocketByNetlink(network string, from netip.AddrPort, to netip.AddrPo message := netlink.Message{ Header: netlink.Header{ Type: 20, // SOCK_DIAG_BY_FAMILY - Flags: netlink.Request, + Flags: flags, }, Data: (*(*[unsafe.Sizeof(*request)]byte)(unsafe.Pointer(request)))[:], } diff --git a/component/process/process_test.go b/component/process/process_test.go index e46754d..2a71f29 100644 --- a/component/process/process_test.go +++ b/component/process/process_test.go @@ -28,7 +28,7 @@ func testConn(t *testing.T, network, address string) { } defer rConn.Close() - path, err := FindProcessPath(TCP, netip.MustParseAddrPort(conn.LocalAddr().String()), netip.MustParseAddrPort(conn.RemoteAddr().String())) + path, err := FindProcessPath(TCP, conn.LocalAddr().(*net.TCPAddr).AddrPort(), conn.RemoteAddr().(*net.TCPAddr).AddrPort()) if err != nil { assert.FailNow(t, "Find process path failed", err) } @@ -68,7 +68,12 @@ func testPacketConn(t *testing.T, network, lAddress, rAddress string) { assert.FailNow(t, "Send message failed", err) } - path, err := FindProcessPath(UDP, netip.MustParseAddrPort(lConn.LocalAddr().String()), netip.MustParseAddrPort(rConn.LocalAddr().String())) + _, lAddr, err := rConn.ReadFrom([]byte{0}) + if err != nil { + assert.FailNow(t, "Receive message failed", err) + } + + path, err := FindProcessPath(UDP, lAddr.(*net.UDPAddr).AddrPort(), rConn.LocalAddr().(*net.UDPAddr).AddrPort()) if err != nil { assert.FailNow(t, "Find process path", err) } @@ -88,6 +93,12 @@ func TestFindProcessPathUDP(t *testing.T) { t.Run("v6", func(t *testing.T) { testPacketConn(t, "udp6", "[::1]:0", "[::1]:0") }) + t.Run("v4AnyLocal", func(t *testing.T) { + testPacketConn(t, "udp4", "0.0.0.0:0", "127.0.0.1:0") + }) + t.Run("v6AnyLocal", func(t *testing.T) { + testPacketConn(t, "udp6", "[::]:0", "[::1]:0") + }) } func BenchmarkFindProcessName(b *testing.B) { diff --git a/component/process/process_windows.go b/component/process/process_windows.go index 806adba..62a02ea 100644 --- a/component/process/process_windows.go +++ b/component/process/process_windows.go @@ -168,7 +168,7 @@ loop: localAddr := netip.AddrFrom4(entry.LocalAddr) localPort := windows.Ntohs(uint16(entry.LocalPort)) - if localAddr == from.Addr() && localPort == from.Port() { + if (localAddr == from.Addr() || localAddr.IsUnspecified()) && localPort == from.Port() { return entry.OwningPid, nil } } @@ -189,7 +189,7 @@ loop: localAddr := netip.AddrFrom16(entry.LocalAddr) localPort := windows.Ntohs(uint16(entry.LocalPort)) - if localAddr == from.Addr() && localPort == from.Port() { + if (localAddr == from.Addr() || localAddr.IsUnspecified()) && localPort == from.Port() { return entry.OwningPid, nil } }