diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index 7922991..8d61796 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -115,8 +115,9 @@ func dialTimeout(network, address string, timeout time.Duration) (net.Conn, erro type dialResult struct { net.Conn error - ipv6 bool - done bool + resolved bool + ipv6 bool + done bool } results := make(chan dialResult) var primary, fallback dialResult @@ -142,6 +143,7 @@ func dialTimeout(network, address string, timeout time.Duration) (net.Conn, erro if result.error != nil { return } + result.resolved = true if ipv6 { result.Conn, result.error = dialer.DialContext(ctx, "tcp6", net.JoinHostPort(ip.String(), port)) @@ -160,14 +162,20 @@ func dialTimeout(network, address string, timeout time.Duration) (net.Conn, erro return res.Conn, nil } - if res.ipv6 { + if !res.ipv6 { primary = res } else { fallback = res } if primary.done && fallback.done { - return nil, primary.error + if primary.resolved { + return nil, primary.error + } else if fallback.resolved { + return nil, fallback.error + } else { + return nil, primary.error + } } } } diff --git a/dns/iputil.go b/dns/iputil.go index 66dd0c0..7967b44 100644 --- a/dns/iputil.go +++ b/dns/iputil.go @@ -3,10 +3,12 @@ package dns import ( "errors" "net" + "strings" ) var ( errIPNotFound = errors.New("cannot found ip") + errIPVersion = errors.New("ip version error") ) // ResolveIPv4 with a host, return ipv4 @@ -18,8 +20,11 @@ func ResolveIPv4(host string) (net.IP, error) { } ip := net.ParseIP(host) - if ip4 := ip.To4(); ip4 != nil { - return ip4, nil + if ip != nil { + if !strings.Contains(host, ":") { + return ip, nil + } + return nil, errIPVersion } if DefaultResolver != nil { @@ -32,8 +37,8 @@ func ResolveIPv4(host string) (net.IP, error) { } for _, ip := range ipAddrs { - if ip4 := ip.To4(); ip4 != nil { - return ip4, nil + if len(ip) == net.IPv4len { + return ip, nil } } @@ -49,8 +54,11 @@ func ResolveIPv6(host string) (net.IP, error) { } ip := net.ParseIP(host) - if ip6 := ip.To16(); ip6 != nil { - return ip6, nil + if ip != nil { + if strings.Contains(host, ":") { + return ip, nil + } + return nil, errIPVersion } if DefaultResolver != nil { @@ -63,8 +71,8 @@ func ResolveIPv6(host string) (net.IP, error) { } for _, ip := range ipAddrs { - if ip6 := ip.To16(); ip6 != nil { - return ip6, nil + if len(ip) == net.IPv6len { + return ip, nil } } diff --git a/dns/resolver.go b/dns/resolver.go index a9cd283..0ba2342 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -221,13 +221,11 @@ func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) resolveIP(host string, dnsType uint16) (ip net.IP, err error) { ip = net.ParseIP(host) - if dnsType == D.TypeAAAA { - if ip6 := ip.To16(); ip6 != nil { - return ip6, nil - } - } else { - if ip4 := ip.To4(); ip4 != nil { - return ip4, nil + if ip != nil { + if dnsType == D.TypeAAAA && len(ip) == net.IPv6len { + return ip, nil + } else if dnsType == D.TypeA && len(ip) == net.IPv4len { + return ip, nil } }