diff --git a/adapters/outbound/base.go b/adapters/outbound/base.go index b560fc7..505489b 100644 --- a/adapters/outbound/base.go +++ b/adapters/outbound/base.go @@ -91,7 +91,13 @@ func (p *Proxy) Alive() bool { } func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) { - conn, err := p.ProxyAdapter.Dial(metadata) + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) + defer cancel() + return p.DialContext(ctx, metadata) +} + +func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + conn, err := p.ProxyAdapter.DialContext(ctx, metadata) if err != nil { p.alive = false } @@ -157,7 +163,7 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) { } start := time.Now() - instance, err := p.Dial(&addr) + instance, err := p.DialContext(ctx, &addr) if err != nil { return } diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 2b0a2a4..22a4171 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "net" C "github.com/Dreamacro/clash/constant" @@ -10,13 +11,13 @@ type Direct struct { *Base } -func (d *Direct) Dial(metadata *C.Metadata) (C.Conn, error) { +func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { address := net.JoinHostPort(metadata.Host, metadata.DstPort) if metadata.DstIP != nil { address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort) } - c, err := dialTimeout("tcp", address, tcpTimeout) + c, err := dialContext(ctx, "tcp", address) if err != nil { return nil, err } diff --git a/adapters/outbound/fallback.go b/adapters/outbound/fallback.go index 9b44ede..3e43e63 100644 --- a/adapters/outbound/fallback.go +++ b/adapters/outbound/fallback.go @@ -31,9 +31,9 @@ func (f *Fallback) Now() string { return proxy.Name() } -func (f *Fallback) Dial(metadata *C.Metadata) (C.Conn, error) { +func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { proxy := f.findAliveProxy() - c, err := proxy.Dial(metadata) + c, err := proxy.DialContext(ctx, metadata) if err == nil { c.AppendToChains(f) } diff --git a/adapters/outbound/http.go b/adapters/outbound/http.go index 357ed5d..77b835b 100644 --- a/adapters/outbound/http.go +++ b/adapters/outbound/http.go @@ -3,6 +3,7 @@ package adapters import ( "bufio" "bytes" + "context" "crypto/tls" "encoding/base64" "errors" @@ -35,8 +36,8 @@ type HttpOption struct { SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } -func (h *Http) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", h.addr, tcpTimeout) +func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", h.addr) if err == nil && h.tls { cc := tls.Client(c, h.tlsConfig) err = cc.Handshake() diff --git a/adapters/outbound/loadbalance.go b/adapters/outbound/loadbalance.go index c719e8b..d27beec 100644 --- a/adapters/outbound/loadbalance.go +++ b/adapters/outbound/loadbalance.go @@ -54,7 +54,7 @@ func jumpHash(key uint64, buckets int32) int32 { return int32(b) } -func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) { +func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { defer func() { if err == nil { c.AppendToChains(lb) @@ -67,11 +67,11 @@ func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) { idx := jumpHash(key, buckets) proxy := lb.proxies[idx] if proxy.Alive() { - c, err = proxy.Dial(metadata) + c, err = proxy.DialContext(ctx, metadata) return } } - c, err = lb.proxies[0].Dial(metadata) + c, err = lb.proxies[0].DialContext(ctx, metadata) return } diff --git a/adapters/outbound/reject.go b/adapters/outbound/reject.go index de395d5..65ab119 100644 --- a/adapters/outbound/reject.go +++ b/adapters/outbound/reject.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "io" "net" "time" @@ -12,7 +13,7 @@ type Reject struct { *Base } -func (r *Reject) Dial(metadata *C.Metadata) (C.Conn, error) { +func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { return newConn(&NopConn{}, r), nil } diff --git a/adapters/outbound/selector.go b/adapters/outbound/selector.go index 31d5a0a..b7ed661 100644 --- a/adapters/outbound/selector.go +++ b/adapters/outbound/selector.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "encoding/json" "errors" "net" @@ -20,8 +21,8 @@ type SelectorOption struct { Proxies []string `proxy:"proxies"` } -func (s *Selector) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := s.selected.Dial(metadata) +func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := s.selected.DialContext(ctx, metadata) if err == nil { c.AppendToChains(s) } diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index c46f8fc..22d160b 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -57,8 +58,8 @@ type v2rayObfsOption struct { Mux bool `obfs:"mux,omitempty"` } -func (ss *ShadowSocks) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", ss.server, tcpTimeout) +func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", ss.server) if err != nil { return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error()) } diff --git a/adapters/outbound/snell.go b/adapters/outbound/snell.go index b413119..6b95aac 100644 --- a/adapters/outbound/snell.go +++ b/adapters/outbound/snell.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "fmt" "net" "strconv" @@ -26,8 +27,8 @@ type SnellOption struct { ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"` } -func (s *Snell) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", s.server, tcpTimeout) +func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", s.server) if err != nil { return nil, fmt.Errorf("%s connect error: %s", s.server, err.Error()) } diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index d99daa9..9355cff 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "crypto/tls" "fmt" "io" @@ -33,8 +34,8 @@ type Socks5Option struct { SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } -func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", ss.addr, tcpTimeout) +func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", ss.addr) if err == nil && ss.tls { cc := tls.Client(c, ss.tlsConfig) @@ -60,7 +61,9 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) { } func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) { - c, err := dialTimeout("tcp", ss.addr, tcpTimeout) + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) + defer cancel() + c, err := dialContext(ctx, "tcp", ss.addr) if err != nil { err = fmt.Errorf("%s connect error", ss.addr) return diff --git a/adapters/outbound/urltest.go b/adapters/outbound/urltest.go index 60b4bed..2bdb872 100644 --- a/adapters/outbound/urltest.go +++ b/adapters/outbound/urltest.go @@ -33,9 +33,9 @@ func (u *URLTest) Now() string { return u.fast.Name() } -func (u *URLTest) Dial(metadata *C.Metadata) (c C.Conn, err error) { +func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) { for i := 0; i < 3; i++ { - c, err = u.fast.Dial(metadata) + c, err = u.fast.DialContext(ctx, metadata) if err == nil { c.AppendToChains(u) return diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index 46c4581..22b2d95 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -86,15 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { return bytes.Join(buf, nil) } -func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { +func dialContext(ctx context.Context, network, address string) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { return nil, err } dialer := net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() returned := make(chan struct{}) defer close(returned) diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index 5b5337a..d61172e 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -1,6 +1,7 @@ package adapters import ( + "context" "fmt" "net" "strconv" @@ -31,8 +32,8 @@ type VmessOption struct { SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"` } -func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) { - c, err := dialTimeout("tcp", v.server, tcpTimeout) +func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) { + c, err := dialContext(ctx, "tcp", v.server) if err != nil { return nil, fmt.Errorf("%s connect error", v.server) } @@ -42,7 +43,9 @@ func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) { } func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { - c, err := dialTimeout("tcp", v.server, tcpTimeout) + ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout) + defer cancel() + c, err := dialContext(ctx, "tcp", v.server) if err != nil { return nil, nil, fmt.Errorf("%s connect error", v.server) } diff --git a/constant/adapters.go b/constant/adapters.go index 2e155ac..97d65a5 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -58,7 +58,7 @@ type PacketConn interface { type ProxyAdapter interface { Name() string Type() AdapterType - Dial(metadata *Metadata) (Conn, error) + DialContext(ctx context.Context, metadata *Metadata) (Conn, error) DialUDP(metadata *Metadata) (PacketConn, net.Addr, error) SupportUDP() bool Destroy() @@ -74,6 +74,7 @@ type Proxy interface { ProxyAdapter Alive() bool DelayHistory() []DelayHistory + Dial(metadata *Metadata) (Conn, error) LastDelay() uint16 URLTest(ctx context.Context, url string) (uint16, error) }