1
0

Fix: dial tcp with context to avoid margin of error

This commit is contained in:
Dreamacro 2019-10-12 23:55:39 +08:00
parent 0cdc40beb3
commit 7c4a359a2b
14 changed files with 47 additions and 30 deletions

View File

@ -91,7 +91,13 @@ func (p *Proxy) Alive() bool {
} }
func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.Dial(metadata) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
return p.DialContext(ctx, metadata)
}
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil { if err != nil {
p.alive = false p.alive = false
} }
@ -157,7 +163,7 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
} }
start := time.Now() start := time.Now()
instance, err := p.Dial(&addr) instance, err := p.DialContext(ctx, &addr)
if err != nil { if err != nil {
return return
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"net" "net"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
@ -10,13 +11,13 @@ type Direct struct {
*Base *Base
} }
func (d *Direct) Dial(metadata *C.Metadata) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
address := net.JoinHostPort(metadata.Host, metadata.DstPort) address := net.JoinHostPort(metadata.Host, metadata.DstPort)
if metadata.DstIP != nil { if metadata.DstIP != nil {
address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort)
} }
c, err := dialTimeout("tcp", address, tcpTimeout) c, err := dialContext(ctx, "tcp", address)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -31,9 +31,9 @@ func (f *Fallback) Now() string {
return proxy.Name() return proxy.Name()
} }
func (f *Fallback) Dial(metadata *C.Metadata) (C.Conn, error) { func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxy := f.findAliveProxy() proxy := f.findAliveProxy()
c, err := proxy.Dial(metadata) c, err := proxy.DialContext(ctx, metadata)
if err == nil { if err == nil {
c.AppendToChains(f) c.AppendToChains(f)
} }

View File

@ -3,6 +3,7 @@ package adapters
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
@ -35,8 +36,8 @@ type HttpOption struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
} }
func (h *Http) Dial(metadata *C.Metadata) (C.Conn, error) { func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", h.addr, tcpTimeout) c, err := dialContext(ctx, "tcp", h.addr)
if err == nil && h.tls { if err == nil && h.tls {
cc := tls.Client(c, h.tlsConfig) cc := tls.Client(c, h.tlsConfig)
err = cc.Handshake() err = cc.Handshake()

View File

@ -54,7 +54,7 @@ func jumpHash(key uint64, buckets int32) int32 {
return int32(b) return int32(b)
} }
func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) { func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
defer func() { defer func() {
if err == nil { if err == nil {
c.AppendToChains(lb) c.AppendToChains(lb)
@ -67,11 +67,11 @@ func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) {
idx := jumpHash(key, buckets) idx := jumpHash(key, buckets)
proxy := lb.proxies[idx] proxy := lb.proxies[idx]
if proxy.Alive() { if proxy.Alive() {
c, err = proxy.Dial(metadata) c, err = proxy.DialContext(ctx, metadata)
return return
} }
} }
c, err = lb.proxies[0].Dial(metadata) c, err = lb.proxies[0].DialContext(ctx, metadata)
return return
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"io" "io"
"net" "net"
"time" "time"
@ -12,7 +13,7 @@ type Reject struct {
*Base *Base
} }
func (r *Reject) Dial(metadata *C.Metadata) (C.Conn, error) { func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return newConn(&NopConn{}, r), nil return newConn(&NopConn{}, r), nil
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -20,8 +21,8 @@ type SelectorOption struct {
Proxies []string `proxy:"proxies"` Proxies []string `proxy:"proxies"`
} }
func (s *Selector) Dial(metadata *C.Metadata) (C.Conn, error) { func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := s.selected.Dial(metadata) c, err := s.selected.DialContext(ctx, metadata)
if err == nil { if err == nil {
c.AppendToChains(s) c.AppendToChains(s)
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -57,8 +58,8 @@ type v2rayObfsOption struct {
Mux bool `obfs:"mux,omitempty"` Mux bool `obfs:"mux,omitempty"`
} }
func (ss *ShadowSocks) Dial(metadata *C.Metadata) (C.Conn, error) { func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", ss.server, tcpTimeout) c, err := dialContext(ctx, "tcp", ss.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error()) return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error())
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -26,8 +27,8 @@ type SnellOption struct {
ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"` ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"`
} }
func (s *Snell) Dial(metadata *C.Metadata) (C.Conn, error) { func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", s.server, tcpTimeout) c, err := dialContext(ctx, "tcp", s.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error: %s", s.server, err.Error()) return nil, fmt.Errorf("%s connect error: %s", s.server, err.Error())
} }

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io" "io"
@ -33,8 +34,8 @@ type Socks5Option struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
} }
func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) { func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", ss.addr, tcpTimeout) c, err := dialContext(ctx, "tcp", ss.addr)
if err == nil && ss.tls { if err == nil && ss.tls {
cc := tls.Client(c, ss.tlsConfig) cc := tls.Client(c, ss.tlsConfig)
@ -60,7 +61,9 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) {
} }
func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) { func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) {
c, err := dialTimeout("tcp", ss.addr, tcpTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
c, err := dialContext(ctx, "tcp", ss.addr)
if err != nil { if err != nil {
err = fmt.Errorf("%s connect error", ss.addr) err = fmt.Errorf("%s connect error", ss.addr)
return return

View File

@ -33,9 +33,9 @@ func (u *URLTest) Now() string {
return u.fast.Name() return u.fast.Name()
} }
func (u *URLTest) Dial(metadata *C.Metadata) (c C.Conn, err error) { func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
c, err = u.fast.Dial(metadata) c, err = u.fast.DialContext(ctx, metadata)
if err == nil { if err == nil {
c.AppendToChains(u) c.AppendToChains(u)
return return

View File

@ -86,15 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte {
return bytes.Join(buf, nil) return bytes.Join(buf, nil)
} }
func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dialer := net.Dialer{} dialer := net.Dialer{}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
returned := make(chan struct{}) returned := make(chan struct{})
defer close(returned) defer close(returned)

View File

@ -1,6 +1,7 @@
package adapters package adapters
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@ -31,8 +32,8 @@ type VmessOption struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
} }
func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) { func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", v.server, tcpTimeout) c, err := dialContext(ctx, "tcp", v.server)
if err != nil { if err != nil {
return nil, fmt.Errorf("%s connect error", v.server) return nil, fmt.Errorf("%s connect error", v.server)
} }
@ -42,7 +43,9 @@ func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) {
} }
func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
c, err := dialTimeout("tcp", v.server, tcpTimeout) ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
c, err := dialContext(ctx, "tcp", v.server)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("%s connect error", v.server) return nil, nil, fmt.Errorf("%s connect error", v.server)
} }

View File

@ -58,7 +58,7 @@ type PacketConn interface {
type ProxyAdapter interface { type ProxyAdapter interface {
Name() string Name() string
Type() AdapterType Type() AdapterType
Dial(metadata *Metadata) (Conn, error) DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
DialUDP(metadata *Metadata) (PacketConn, net.Addr, error) DialUDP(metadata *Metadata) (PacketConn, net.Addr, error)
SupportUDP() bool SupportUDP() bool
Destroy() Destroy()
@ -74,6 +74,7 @@ type Proxy interface {
ProxyAdapter ProxyAdapter
Alive() bool Alive() bool
DelayHistory() []DelayHistory DelayHistory() []DelayHistory
Dial(metadata *Metadata) (Conn, error)
LastDelay() uint16 LastDelay() uint16
URLTest(ctx context.Context, url string) (uint16, error) URLTest(ctx context.Context, url string) (uint16, error)
} }