1
0

Fix: dial tcp with context to avoid margin of error

This commit is contained in:
Dreamacro 2019-10-12 23:55:39 +08:00
parent 0cdc40beb3
commit 7c4a359a2b
14 changed files with 47 additions and 30 deletions

View File

@ -91,7 +91,13 @@ func (p *Proxy) Alive() bool {
}
func (p *Proxy) Dial(metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.Dial(metadata)
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
return p.DialContext(ctx, metadata)
}
func (p *Proxy) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
conn, err := p.ProxyAdapter.DialContext(ctx, metadata)
if err != nil {
p.alive = false
}
@ -157,7 +163,7 @@ func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
}
start := time.Now()
instance, err := p.Dial(&addr)
instance, err := p.DialContext(ctx, &addr)
if err != nil {
return
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"net"
C "github.com/Dreamacro/clash/constant"
@ -10,13 +11,13 @@ type Direct struct {
*Base
}
func (d *Direct) Dial(metadata *C.Metadata) (C.Conn, error) {
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
address := net.JoinHostPort(metadata.Host, metadata.DstPort)
if metadata.DstIP != nil {
address = net.JoinHostPort(metadata.DstIP.String(), metadata.DstPort)
}
c, err := dialTimeout("tcp", address, tcpTimeout)
c, err := dialContext(ctx, "tcp", address)
if err != nil {
return nil, err
}

View File

@ -31,9 +31,9 @@ func (f *Fallback) Now() string {
return proxy.Name()
}
func (f *Fallback) Dial(metadata *C.Metadata) (C.Conn, error) {
func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
proxy := f.findAliveProxy()
c, err := proxy.Dial(metadata)
c, err := proxy.DialContext(ctx, metadata)
if err == nil {
c.AppendToChains(f)
}

View File

@ -3,6 +3,7 @@ package adapters
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"encoding/base64"
"errors"
@ -35,8 +36,8 @@ type HttpOption struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}
func (h *Http) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", h.addr, tcpTimeout)
func (h *Http) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", h.addr)
if err == nil && h.tls {
cc := tls.Client(c, h.tlsConfig)
err = cc.Handshake()

View File

@ -54,7 +54,7 @@ func jumpHash(key uint64, buckets int32) int32 {
return int32(b)
}
func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) {
func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
defer func() {
if err == nil {
c.AppendToChains(lb)
@ -67,11 +67,11 @@ func (lb *LoadBalance) Dial(metadata *C.Metadata) (c C.Conn, err error) {
idx := jumpHash(key, buckets)
proxy := lb.proxies[idx]
if proxy.Alive() {
c, err = proxy.Dial(metadata)
c, err = proxy.DialContext(ctx, metadata)
return
}
}
c, err = lb.proxies[0].Dial(metadata)
c, err = lb.proxies[0].DialContext(ctx, metadata)
return
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"io"
"net"
"time"
@ -12,7 +13,7 @@ type Reject struct {
*Base
}
func (r *Reject) Dial(metadata *C.Metadata) (C.Conn, error) {
func (r *Reject) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
return newConn(&NopConn{}, r), nil
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"encoding/json"
"errors"
"net"
@ -20,8 +21,8 @@ type SelectorOption struct {
Proxies []string `proxy:"proxies"`
}
func (s *Selector) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := s.selected.Dial(metadata)
func (s *Selector) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := s.selected.DialContext(ctx, metadata)
if err == nil {
c.AppendToChains(s)
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
@ -57,8 +58,8 @@ type v2rayObfsOption struct {
Mux bool `obfs:"mux,omitempty"`
}
func (ss *ShadowSocks) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", ss.server, tcpTimeout)
func (ss *ShadowSocks) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", ss.server)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", ss.server, err.Error())
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"fmt"
"net"
"strconv"
@ -26,8 +27,8 @@ type SnellOption struct {
ObfsOpts map[string]interface{} `proxy:"obfs-opts,omitempty"`
}
func (s *Snell) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", s.server, tcpTimeout)
func (s *Snell) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", s.server)
if err != nil {
return nil, fmt.Errorf("%s connect error: %s", s.server, err.Error())
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"crypto/tls"
"fmt"
"io"
@ -33,8 +34,8 @@ type Socks5Option struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}
func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", ss.addr, tcpTimeout)
func (ss *Socks5) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", ss.addr)
if err == nil && ss.tls {
cc := tls.Client(c, ss.tlsConfig)
@ -60,7 +61,9 @@ func (ss *Socks5) Dial(metadata *C.Metadata) (C.Conn, error) {
}
func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err error) {
c, err := dialTimeout("tcp", ss.addr, tcpTimeout)
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
c, err := dialContext(ctx, "tcp", ss.addr)
if err != nil {
err = fmt.Errorf("%s connect error", ss.addr)
return

View File

@ -33,9 +33,9 @@ func (u *URLTest) Now() string {
return u.fast.Name()
}
func (u *URLTest) Dial(metadata *C.Metadata) (c C.Conn, err error) {
func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata) (c C.Conn, err error) {
for i := 0; i < 3; i++ {
c, err = u.fast.Dial(metadata)
c, err = u.fast.DialContext(ctx, metadata)
if err == nil {
c.AppendToChains(u)
return

View File

@ -86,15 +86,13 @@ func serializesSocksAddr(metadata *C.Metadata) []byte {
return bytes.Join(buf, nil)
}
func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
func dialContext(ctx context.Context, network, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
dialer := net.Dialer{}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
returned := make(chan struct{})
defer close(returned)

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"fmt"
"net"
"strconv"
@ -31,8 +32,8 @@ type VmessOption struct {
SkipCertVerify bool `proxy:"skip-cert-verify,omitempty"`
}
func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) {
c, err := dialTimeout("tcp", v.server, tcpTimeout)
func (v *Vmess) DialContext(ctx context.Context, metadata *C.Metadata) (C.Conn, error) {
c, err := dialContext(ctx, "tcp", v.server)
if err != nil {
return nil, fmt.Errorf("%s connect error", v.server)
}
@ -42,7 +43,9 @@ func (v *Vmess) Dial(metadata *C.Metadata) (C.Conn, error) {
}
func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) {
c, err := dialTimeout("tcp", v.server, tcpTimeout)
ctx, cancel := context.WithTimeout(context.Background(), tcpTimeout)
defer cancel()
c, err := dialContext(ctx, "tcp", v.server)
if err != nil {
return nil, nil, fmt.Errorf("%s connect error", v.server)
}

View File

@ -58,7 +58,7 @@ type PacketConn interface {
type ProxyAdapter interface {
Name() string
Type() AdapterType
Dial(metadata *Metadata) (Conn, error)
DialContext(ctx context.Context, metadata *Metadata) (Conn, error)
DialUDP(metadata *Metadata) (PacketConn, net.Addr, error)
SupportUDP() bool
Destroy()
@ -74,6 +74,7 @@ type Proxy interface {
ProxyAdapter
Alive() bool
DelayHistory() []DelayHistory
Dial(metadata *Metadata) (Conn, error)
LastDelay() uint16
URLTest(ctx context.Context, url string) (uint16, error)
}