From 9e7813776821cd6661514014dd0f47b70bdb5746 Mon Sep 17 00:00:00 2001 From: fuyun Date: Thu, 3 Aug 2023 22:30:08 +0800 Subject: [PATCH] Feature: add `inbounds` for flexible binding inbound (#2818) --- adapter/inbound/http.go | 2 + adapter/inbound/https.go | 2 + adapter/inbound/socket.go | 2 + component/process/process_windows.go | 4 +- config/config.go | 41 +-- constant/listener.go | 85 ++++++ constant/metadata.go | 1 + constant/rule.go | 3 + hub/executor/executor.go | 51 ++-- hub/route/configs.go | 51 ++-- hub/route/inbounds.go | 39 +++ hub/route/server.go | 1 + listener/http/server.go | 4 +- listener/listener.go | 420 +++++++++++---------------- listener/mixed/mixed.go | 2 +- listener/redir/tcp.go | 2 +- listener/socks/tcp.go | 2 +- listener/socks/udp.go | 2 +- listener/tproxy/tcp.go | 2 +- listener/tproxy/udp.go | 2 +- rule/parser.go | 6 +- rule/port.go | 35 ++- test/clash_test.go | 43 +-- test/listener_test.go | 78 +++++ test/rule_test.go | 33 +++ 25 files changed, 552 insertions(+), 361 deletions(-) create mode 100644 hub/route/inbounds.go create mode 100644 test/listener_test.go create mode 100644 test/rule_test.go diff --git a/adapter/inbound/http.go b/adapter/inbound/http.go index a7a5552..5242c8e 100644 --- a/adapter/inbound/http.go +++ b/adapter/inbound/http.go @@ -3,6 +3,7 @@ package inbound import ( "net" "net/netip" + "strconv" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -21,6 +22,7 @@ func NewHTTP(target socks5.Addr, source net.Addr, originTarget net.Addr, conn ne if originTarget != nil { if addrPort, err := netip.ParseAddrPort(originTarget.String()); err == nil { metadata.OriginDst = addrPort + metadata.InboundPort = strconv.Itoa(int(addrPort.Port())) } } return context.NewConnContext(conn, metadata) diff --git a/adapter/inbound/https.go b/adapter/inbound/https.go index 5c2b1a7..6d39468 100644 --- a/adapter/inbound/https.go +++ b/adapter/inbound/https.go @@ -4,6 +4,7 @@ import ( "net" "net/http" "net/netip" + "strconv" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -19,6 +20,7 @@ func NewHTTPS(request *http.Request, conn net.Conn) *context.ConnContext { } if addrPort, err := netip.ParseAddrPort(conn.LocalAddr().String()); err == nil { metadata.OriginDst = addrPort + metadata.InboundPort = strconv.Itoa(int(addrPort.Port())) } return context.NewConnContext(conn, metadata) } diff --git a/adapter/inbound/socket.go b/adapter/inbound/socket.go index 31bde9a..efe9e5a 100644 --- a/adapter/inbound/socket.go +++ b/adapter/inbound/socket.go @@ -3,6 +3,7 @@ package inbound import ( "net" "net/netip" + "strconv" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" @@ -20,6 +21,7 @@ func NewSocket(target socks5.Addr, conn net.Conn, source C.Type) *context.ConnCo } if addrPort, err := netip.ParseAddrPort(conn.LocalAddr().String()); err == nil { metadata.OriginDst = addrPort + metadata.InboundPort = strconv.Itoa(int(addrPort.Port())) } return context.NewConnContext(conn, metadata) } diff --git a/component/process/process_windows.go b/component/process/process_windows.go index 62a02ea..47ede05 100644 --- a/component/process/process_windows.go +++ b/component/process/process_windows.go @@ -6,9 +6,9 @@ import ( "net/netip" "unsafe" - "golang.org/x/sys/windows" - "github.com/Dreamacro/clash/common/pool" + + "golang.org/x/sys/windows" ) var ( diff --git a/config/config.go b/config/config.go index 09183e6..48b53b0 100644 --- a/config/config.go +++ b/config/config.go @@ -28,25 +28,14 @@ import ( // General config type General struct { - Inbound + LagecyInbound Controller - Mode T.TunnelMode `json:"mode"` - LogLevel log.LogLevel `json:"log-level"` - IPv6 bool `json:"ipv6"` - Interface string `json:"-"` - RoutingMark int `json:"-"` -} - -// Inbound -type Inbound struct { - Port int `json:"port"` - SocksPort int `json:"socks-port"` - RedirPort int `json:"redir-port"` - TProxyPort int `json:"tproxy-port"` - MixedPort int `json:"mixed-port"` - Authentication []string `json:"authentication"` - AllowLan bool `json:"allow-lan"` - BindAddress string `json:"bind-address"` + Authentication []string `json:"authentication"` + Mode T.TunnelMode `json:"mode"` + LogLevel log.LogLevel `json:"log-level"` + IPv6 bool `json:"ipv6"` + Interface string `json:"-"` + RoutingMark int `json:"-"` } // Controller @@ -56,6 +45,16 @@ type Controller struct { Secret string `json:"-"` } +type LagecyInbound struct { + Port int `json:"port"` + SocksPort int `json:"socks-port"` + RedirPort int `json:"redir-port"` + TProxyPort int `json:"tproxy-port"` + MixedPort int `json:"mixed-port"` + AllowLan bool `json:"allow-lan"` + BindAddress string `json:"bind-address"` +} + // DNS config type DNS struct { Enable bool `yaml:"enable"` @@ -98,6 +97,7 @@ type Config struct { Experimental *Experimental Hosts *trie.DomainTrie Profile *Profile + Inbounds []C.Inbound Rules []C.Rule Users []auth.AuthUser Proxies map[string]C.Proxy @@ -207,6 +207,7 @@ type RawConfig struct { ProxyProvider map[string]map[string]any `yaml:"proxy-providers"` Hosts map[string]string `yaml:"hosts"` + Inbounds []C.Inbound `yaml:"inbounds"` DNS RawDNS `yaml:"dns"` Experimental Experimental `yaml:"experimental"` Profile Profile `yaml:"profile"` @@ -275,6 +276,8 @@ func ParseRawConfig(rawCfg *RawConfig) (*Config, error) { } config.General = general + config.Inbounds = rawCfg.Inbounds + proxies, providers, err := parseProxies(rawCfg) if err != nil { return nil, err @@ -326,7 +329,7 @@ func parseGeneral(cfg *RawConfig) (*General, error) { } return &General{ - Inbound: Inbound{ + LagecyInbound: LagecyInbound{ Port: cfg.Port, SocksPort: cfg.SocksPort, RedirPort: cfg.RedirPort, diff --git a/constant/listener.go b/constant/listener.go index 07782a9..afd498a 100644 --- a/constant/listener.go +++ b/constant/listener.go @@ -1,7 +1,92 @@ package constant +import ( + "fmt" + "net" + "net/url" + "strconv" +) + type Listener interface { RawAddress() string Address() string Close() error } + +type InboundType string + +const ( + InboundTypeSocks InboundType = "socks" + InboundTypeRedir InboundType = "redir" + InboundTypeTproxy InboundType = "tproxy" + InboundTypeHTTP InboundType = "http" + InboundTypeMixed InboundType = "mixed" +) + +var supportInboundTypes = map[InboundType]bool{ + InboundTypeSocks: true, + InboundTypeRedir: true, + InboundTypeTproxy: true, + InboundTypeHTTP: true, + InboundTypeMixed: true, +} + +type inbound struct { + Type InboundType `json:"type" yaml:"type"` + BindAddress string `json:"bind-address" yaml:"bind-address"` + IsFromPortCfg bool `json:"-" yaml:"-"` +} + +// Inbound +type Inbound inbound + +// UnmarshalYAML implements yaml.Unmarshaler +func (i *Inbound) UnmarshalYAML(unmarshal func(any) error) error { + var tp string + if err := unmarshal(&tp); err != nil { + var inner inbound + if err := unmarshal(&inner); err != nil { + return err + } + + *i = Inbound(inner) + return nil + } + + inner, err := parseInbound(tp) + if err != nil { + return err + } + *i = Inbound(*inner) + if !supportInboundTypes[i.Type] { + return fmt.Errorf("not support inbound type: %s", i.Type) + } + _, portStr, err := net.SplitHostPort(i.BindAddress) + if err != nil { + return fmt.Errorf("bind address parse error. addr:%s, err:%v", i.BindAddress, err) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return fmt.Errorf("port not a number. addr:%s", i.BindAddress) + } + if port == 0 { + return fmt.Errorf("invalid bind port. addr:%s", i.BindAddress) + } + return nil +} + +func parseInbound(alias string) (*inbound, error) { + u, err := url.Parse(alias) + if err != nil { + return nil, err + } + listenerType := InboundType(u.Scheme) + return &inbound{ + Type: listenerType, + BindAddress: u.Host, + }, nil +} + +func (i *Inbound) ToAlias() string { + return string(i.Type) + "://" + i.BindAddress +} diff --git a/constant/metadata.go b/constant/metadata.go index b92cac7..d8d9eb7 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -69,6 +69,7 @@ type Metadata struct { DstIP net.IP `json:"destinationIP"` SrcPort string `json:"sourcePort"` DstPort string `json:"destinationPort"` + InboundPort string `json:"inboundPort"` Host string `json:"host"` DNSMode DNSMode `json:"dnsMode"` ProcessPath string `json:"processPath"` diff --git a/constant/rule.go b/constant/rule.go index 5cbf8d9..a7ab7e2 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -10,6 +10,7 @@ const ( SrcIPCIDR SrcPort DstPort + InboundPort Process ProcessPath IPSet @@ -36,6 +37,8 @@ func (rt RuleType) String() string { return "SrcPort" case DstPort: return "DstPort" + case InboundPort: + return "InboundPort" case Process: return "Process" case ProcessPath: diff --git a/hub/executor/executor.go b/hub/executor/executor.go index aaf5ba6..3f0d487 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -73,6 +73,7 @@ func ApplyConfig(cfg *config.Config, force bool) { updateHosts(cfg.Hosts) updateProfile(cfg) updateGeneral(cfg.General, force) + updateInbounds(cfg.Inbounds, force) updateDNS(cfg.DNS) updateExperimental(cfg) updateTunnels(cfg.Tunnels) @@ -86,19 +87,19 @@ func GetGeneral() *config.General { } general := &config.General{ - Inbound: config.Inbound{ - Port: ports.Port, - SocksPort: ports.SocksPort, - RedirPort: ports.RedirPort, - TProxyPort: ports.TProxyPort, - MixedPort: ports.MixedPort, - Authentication: authenticator, - AllowLan: listener.AllowLan(), - BindAddress: listener.BindAddress(), + LagecyInbound: config.LagecyInbound{ + Port: ports.Port, + SocksPort: ports.SocksPort, + RedirPort: ports.RedirPort, + TProxyPort: ports.TProxyPort, + MixedPort: ports.MixedPort, + AllowLan: listener.AllowLan(), + BindAddress: listener.BindAddress(), }, - Mode: tunnel.Mode(), - LogLevel: log.Level(), - IPv6: !resolver.DisableIPv6, + Authentication: authenticator, + Mode: tunnel.Mode(), + LogLevel: log.Level(), + IPv6: !resolver.DisableIPv6, } return general @@ -164,6 +165,16 @@ func updateTunnels(tunnels []config.Tunnel) { listener.PatchTunnel(tunnels, tunnel.TCPIn(), tunnel.UDPIn()) } +func updateInbounds(inbounds []C.Inbound, force bool) { + if !force { + return + } + tcpIn := tunnel.TCPIn() + udpIn := tunnel.UDPIn() + + listener.ReCreateListeners(inbounds, tcpIn, udpIn) +} + func updateGeneral(general *config.General, force bool) { log.SetLevel(general.LogLevel) tunnel.SetMode(general.Mode) @@ -184,14 +195,14 @@ func updateGeneral(general *config.General, force bool) { bindAddress := general.BindAddress listener.SetBindAddress(bindAddress) - tcpIn := tunnel.TCPIn() - udpIn := tunnel.UDPIn() - - listener.ReCreateHTTP(general.Port, tcpIn) - listener.ReCreateSocks(general.SocksPort, tcpIn, udpIn) - listener.ReCreateRedir(general.RedirPort, tcpIn, udpIn) - listener.ReCreateTProxy(general.TProxyPort, tcpIn, udpIn) - listener.ReCreateMixed(general.MixedPort, tcpIn, udpIn) + ports := listener.Ports{ + Port: general.Port, + SocksPort: general.SocksPort, + RedirPort: general.RedirPort, + TProxyPort: general.TProxyPort, + MixedPort: general.MixedPort, + } + listener.ReCreatePortsListeners(ports, tunnel.TCPIn(), tunnel.UDPIn()) } func updateUsers(users []auth.AuthUser) { diff --git a/hub/route/configs.go b/hub/route/configs.go index 526e7bb..3841ec4 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -6,14 +6,15 @@ import ( "github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/config" - "github.com/Dreamacro/clash/constant" + C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/hub/executor" - P "github.com/Dreamacro/clash/listener" + "github.com/Dreamacro/clash/listener" "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/tunnel" "github.com/go-chi/chi/v5" "github.com/go-chi/render" + "github.com/samber/lo" ) func configRouter() http.Handler { @@ -29,14 +30,6 @@ func getConfigs(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, general) } -func pointerOrDefault(p *int, def int) int { - if p != nil { - return *p - } - - return def -} - func patchConfigs(w http.ResponseWriter, r *http.Request) { general := struct { Port *int `json:"port"` @@ -56,25 +49,6 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) { return } - if general.AllowLan != nil { - P.SetAllowLan(*general.AllowLan) - } - - if general.BindAddress != nil { - P.SetBindAddress(*general.BindAddress) - } - - ports := P.GetPorts() - - tcpIn := tunnel.TCPIn() - udpIn := tunnel.UDPIn() - - P.ReCreateHTTP(pointerOrDefault(general.Port, ports.Port), tcpIn) - P.ReCreateSocks(pointerOrDefault(general.SocksPort, ports.SocksPort), tcpIn, udpIn) - P.ReCreateRedir(pointerOrDefault(general.RedirPort, ports.RedirPort), tcpIn, udpIn) - P.ReCreateTProxy(pointerOrDefault(general.TProxyPort, ports.TProxyPort), tcpIn, udpIn) - P.ReCreateMixed(pointerOrDefault(general.MixedPort, ports.MixedPort), tcpIn, udpIn) - if general.Mode != nil { tunnel.SetMode(*general.Mode) } @@ -87,6 +61,23 @@ func patchConfigs(w http.ResponseWriter, r *http.Request) { resolver.DisableIPv6 = !*general.IPv6 } + if general.AllowLan != nil { + listener.SetAllowLan(*general.AllowLan) + } + + if general.BindAddress != nil { + listener.SetBindAddress(*general.BindAddress) + } + + ports := listener.GetPorts() + ports.Port = lo.FromPtrOr(general.Port, ports.Port) + ports.SocksPort = lo.FromPtrOr(general.SocksPort, ports.SocksPort) + ports.RedirPort = lo.FromPtrOr(general.RedirPort, ports.RedirPort) + ports.TProxyPort = lo.FromPtrOr(general.TProxyPort, ports.TProxyPort) + ports.MixedPort = lo.FromPtrOr(general.MixedPort, ports.MixedPort) + + listener.ReCreatePortsListeners(*ports, tunnel.TCPIn(), tunnel.UDPIn()) + render.NoContent(w, r) } @@ -114,7 +105,7 @@ func updateConfigs(w http.ResponseWriter, r *http.Request) { } } else { if req.Path == "" { - req.Path = constant.Path.Config() + req.Path = C.Path.Config() } if !filepath.IsAbs(req.Path) { render.Status(r, http.StatusBadRequest) diff --git a/hub/route/inbounds.go b/hub/route/inbounds.go new file mode 100644 index 0000000..8e53bd1 --- /dev/null +++ b/hub/route/inbounds.go @@ -0,0 +1,39 @@ +package route + +import ( + "net/http" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/listener" + "github.com/Dreamacro/clash/tunnel" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/render" +) + +func inboundRouter() http.Handler { + r := chi.NewRouter() + r.Get("/", getInbounds) + r.Put("/", updateInbounds) + return r +} + +func getInbounds(w http.ResponseWriter, r *http.Request) { + inbounds := listener.GetInbounds() + render.JSON(w, r, render.M{ + "inbounds": inbounds, + }) +} + +func updateInbounds(w http.ResponseWriter, r *http.Request) { + var req []C.Inbound + if err := render.DecodeJSON(r.Body, &req); err != nil { + render.Status(r, http.StatusBadRequest) + render.JSON(w, r, ErrBadRequest) + return + } + tcpIn := tunnel.TCPIn() + udpIn := tunnel.UDPIn() + listener.ReCreateListeners(req, tcpIn, udpIn) + render.NoContent(w, r) +} diff --git a/hub/route/server.go b/hub/route/server.go index 66a5907..0cd3ef5 100644 --- a/hub/route/server.go +++ b/hub/route/server.go @@ -69,6 +69,7 @@ func Start(addr string, secret string) { r.Get("/traffic", traffic) r.Get("/version", version) r.Mount("/configs", configRouter()) + r.Mount("/inbounds", inboundRouter()) r.Mount("/proxies", proxyRouter()) r.Mount("/rules", ruleRouter()) r.Mount("/connections", connectionRouter()) diff --git a/listener/http/server.go b/listener/http/server.go index 1c21cd6..85f6e36 100644 --- a/listener/http/server.go +++ b/listener/http/server.go @@ -29,11 +29,11 @@ func (l *Listener) Close() error { return l.listener.Close() } -func New(addr string, in chan<- C.ConnContext) (*Listener, error) { +func New(addr string, in chan<- C.ConnContext) (C.Listener, error) { return NewWithAuthenticate(addr, in, true) } -func NewWithAuthenticate(addr string, in chan<- C.ConnContext, authenticate bool) (*Listener, error) { +func NewWithAuthenticate(addr string, in chan<- C.ConnContext, authenticate bool) (C.Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err diff --git a/listener/listener.go b/listener/listener.go index 972ca2d..f6b1963 100644 --- a/listener/listener.go +++ b/listener/listener.go @@ -25,25 +25,15 @@ var ( allowLan = false bindAddress = "*" - socksListener *socks.Listener - socksUDPListener *socks.UDPListener - httpListener *http.Listener - redirListener *redir.Listener - redirUDPListener *tproxy.UDPListener - tproxyListener *tproxy.Listener - tproxyUDPListener *tproxy.UDPListener - mixedListener *mixed.Listener - mixedUDPLister *socks.UDPListener + tcpListeners = map[C.Inbound]C.Listener{} + udpListeners = map[C.Inbound]C.Listener{} + tunnelTCPListeners = map[string]*tunnel.Listener{} tunnelUDPListeners = map[string]*tunnel.PacketConn{} // lock for recreate function - socksMux sync.Mutex - httpMux sync.Mutex - redirMux sync.Mutex - tproxyMux sync.Mutex - mixedMux sync.Mutex - tunnelMux sync.Mutex + recreateMux sync.Mutex + tunnelMux sync.Mutex ) type Ports struct { @@ -54,6 +44,26 @@ type Ports struct { MixedPort int `json:"mixed-port"` } +var tcpListenerCreators = map[C.InboundType]tcpListenerCreator{ + C.InboundTypeHTTP: http.New, + C.InboundTypeSocks: socks.New, + C.InboundTypeRedir: redir.New, + C.InboundTypeTproxy: tproxy.New, + C.InboundTypeMixed: mixed.New, +} + +var udpListenerCreators = map[C.InboundType]udpListenerCreator{ + C.InboundTypeSocks: socks.NewUDP, + C.InboundTypeRedir: tproxy.NewUDP, + C.InboundTypeTproxy: tproxy.NewUDP, + C.InboundTypeMixed: socks.NewUDP, +} + +type ( + tcpListenerCreator func(addr string, tcpIn chan<- C.ConnContext) (C.Listener, error) + udpListenerCreator func(addr string, udpIn chan<- *inbound.PacketAdapter) (C.Listener, error) +) + func AllowLan() bool { return allowLan } @@ -70,243 +80,119 @@ func SetBindAddress(host string) { bindAddress = host } -func ReCreateHTTP(port int, tcpIn chan<- C.ConnContext) { - httpMux.Lock() - defer httpMux.Unlock() - - var err error - defer func() { - if err != nil { - log.Errorln("Start HTTP server error: %s", err.Error()) - } - }() - - addr := genAddr(bindAddress, port, allowLan) - - if httpListener != nil { - if httpListener.RawAddress() == addr { - return - } - httpListener.Close() - httpListener = nil - } - +func createListener(inbound C.Inbound, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { + addr := inbound.BindAddress if portIsZero(addr) { return } - - httpListener, err = http.New(addr, tcpIn) - if err != nil { + tcpCreator := tcpListenerCreators[inbound.Type] + udpCreator := udpListenerCreators[inbound.Type] + if tcpCreator == nil && udpCreator == nil { + log.Errorln("inbound type %s not support.", inbound.Type) return } - - log.Infoln("HTTP proxy listening at: %s", httpListener.Address()) + if tcpCreator != nil { + tcpListener, err := tcpCreator(addr, tcpIn) + if err != nil { + log.Errorln("create addr %s tcp listener error. err:%v", addr, err) + return + } + tcpListeners[inbound] = tcpListener + } + if udpCreator != nil { + udpListener, err := udpCreator(addr, udpIn) + if err != nil { + log.Errorln("create addr %s udp listener error. err:%v", addr, err) + return + } + udpListeners[inbound] = udpListener + } + log.Infoln("inbound %s create success.", inbound.ToAlias()) } -func ReCreateSocks(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { - socksMux.Lock() - defer socksMux.Unlock() - - var err error - defer func() { - if err != nil { - log.Errorln("Start SOCKS server error: %s", err.Error()) +func closeListener(inbound C.Inbound) { + listener := tcpListeners[inbound] + if listener != nil { + if err := listener.Close(); err != nil { + log.Errorln("close tcp address `%s` error. err:%s", inbound.ToAlias(), err.Error()) } - }() - - addr := genAddr(bindAddress, port, allowLan) - - shouldTCPIgnore := false - shouldUDPIgnore := false - - if socksListener != nil { - if socksListener.RawAddress() != addr { - socksListener.Close() - socksListener = nil - } else { - shouldTCPIgnore = true + delete(tcpListeners, inbound) + } + listener = udpListeners[inbound] + if listener != nil { + if err := listener.Close(); err != nil { + log.Errorln("close udp address `%s` error. err:%s", inbound.ToAlias(), err.Error()) } + delete(udpListeners, inbound) } - - if socksUDPListener != nil { - if socksUDPListener.RawAddress() != addr { - socksUDPListener.Close() - socksUDPListener = nil - } else { - shouldUDPIgnore = true - } - } - - if shouldTCPIgnore && shouldUDPIgnore { - return - } - - if portIsZero(addr) { - return - } - - tcpListener, err := socks.New(addr, tcpIn) - if err != nil { - return - } - - udpListener, err := socks.NewUDP(addr, udpIn) - if err != nil { - tcpListener.Close() - return - } - - socksListener = tcpListener - socksUDPListener = udpListener - - log.Infoln("SOCKS proxy listening at: %s", socksListener.Address()) } -func ReCreateRedir(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { - redirMux.Lock() - defer redirMux.Unlock() +func getNeedCloseAndCreateInbound(originInbounds []C.Inbound, newInbounds []C.Inbound) ([]C.Inbound, []C.Inbound) { + needCloseMap := map[C.Inbound]bool{} + needClose := []C.Inbound{} + needCreate := []C.Inbound{} - var err error - defer func() { - if err != nil { - log.Errorln("Start Redir server error: %s", err.Error()) + for _, inbound := range originInbounds { + needCloseMap[inbound] = true + } + for _, inbound := range newInbounds { + if needCloseMap[inbound] { + delete(needCloseMap, inbound) + } else { + needCreate = append(needCreate, inbound) } - }() - - addr := genAddr(bindAddress, port, allowLan) - - if redirListener != nil { - if redirListener.RawAddress() == addr { - return - } - redirListener.Close() - redirListener = nil } - - if redirUDPListener != nil { - if redirUDPListener.RawAddress() == addr { - return - } - redirUDPListener.Close() - redirUDPListener = nil + for inbound := range needCloseMap { + needClose = append(needClose, inbound) } - - if portIsZero(addr) { - return - } - - redirListener, err = redir.New(addr, tcpIn) - if err != nil { - return - } - - redirUDPListener, err = tproxy.NewUDP(addr, udpIn) - if err != nil { - log.Warnln("Failed to start Redir UDP Listener: %s", err) - } - - log.Infoln("Redirect proxy listening at: %s", redirListener.Address()) + return needClose, needCreate } -func ReCreateTProxy(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { - tproxyMux.Lock() - defer tproxyMux.Unlock() - - var err error - defer func() { - if err != nil { - log.Errorln("Start TProxy server error: %s", err.Error()) +// only recreate inbound config listener +func ReCreateListeners(inbounds []C.Inbound, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { + newInbounds := []C.Inbound{} + newInbounds = append(newInbounds, inbounds...) + for _, inbound := range getInbounds() { + if inbound.IsFromPortCfg { + newInbounds = append(newInbounds, inbound) } - }() - - addr := genAddr(bindAddress, port, allowLan) - - if tproxyListener != nil { - if tproxyListener.RawAddress() == addr { - return - } - tproxyListener.Close() - tproxyListener = nil } - - if tproxyUDPListener != nil { - if tproxyUDPListener.RawAddress() == addr { - return - } - tproxyUDPListener.Close() - tproxyUDPListener = nil - } - - if portIsZero(addr) { - return - } - - tproxyListener, err = tproxy.New(addr, tcpIn) - if err != nil { - return - } - - tproxyUDPListener, err = tproxy.NewUDP(addr, udpIn) - if err != nil { - log.Warnln("Failed to start TProxy UDP Listener: %s", err) - } - - log.Infoln("TProxy server listening at: %s", tproxyListener.Address()) + reCreateListeners(newInbounds, tcpIn, udpIn) } -func ReCreateMixed(port int, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { - mixedMux.Lock() - defer mixedMux.Unlock() +// only recreate ports config listener +func ReCreatePortsListeners(ports Ports, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { + newInbounds := []C.Inbound{} + newInbounds = append(newInbounds, GetInbounds()...) + newInbounds = addPortInbound(newInbounds, C.InboundTypeHTTP, ports.Port) + newInbounds = addPortInbound(newInbounds, C.InboundTypeSocks, ports.SocksPort) + newInbounds = addPortInbound(newInbounds, C.InboundTypeRedir, ports.RedirPort) + newInbounds = addPortInbound(newInbounds, C.InboundTypeTproxy, ports.TProxyPort) + newInbounds = addPortInbound(newInbounds, C.InboundTypeMixed, ports.MixedPort) + reCreateListeners(newInbounds, tcpIn, udpIn) +} - var err error - defer func() { - if err != nil { - log.Errorln("Start Mixed(http+socks) server error: %s", err.Error()) - } - }() - - addr := genAddr(bindAddress, port, allowLan) - - shouldTCPIgnore := false - shouldUDPIgnore := false - - if mixedListener != nil { - if mixedListener.RawAddress() != addr { - mixedListener.Close() - mixedListener = nil - } else { - shouldTCPIgnore = true - } - } - if mixedUDPLister != nil { - if mixedUDPLister.RawAddress() != addr { - mixedUDPLister.Close() - mixedUDPLister = nil - } else { - shouldUDPIgnore = true - } +func addPortInbound(inbounds []C.Inbound, inboundType C.InboundType, port int) []C.Inbound { + if port != 0 { + inbounds = append(inbounds, C.Inbound{ + Type: inboundType, + BindAddress: genAddr(bindAddress, port, allowLan), + IsFromPortCfg: true, + }) } + return inbounds +} - if shouldTCPIgnore && shouldUDPIgnore { - return +func reCreateListeners(inbounds []C.Inbound, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { + recreateMux.Lock() + defer recreateMux.Unlock() + needClose, needCreate := getNeedCloseAndCreateInbound(getInbounds(), inbounds) + for _, inbound := range needClose { + closeListener(inbound) } - - if portIsZero(addr) { - return + for _, inbound := range needCreate { + createListener(inbound, tcpIn, udpIn) } - - mixedListener, err = mixed.New(addr, tcpIn) - if err != nil { - return - } - - mixedUDPLister, err = socks.NewUDP(addr, udpIn) - if err != nil { - mixedListener.Close() - return - } - - log.Infoln("Mixed(http+socks) proxy listening at: %s", mixedListener.Address()) } func PatchTunnel(tunnels []config.Tunnel, tcpIn chan<- C.ConnContext, udpIn chan<- *inbound.PacketAdapter) { @@ -398,43 +284,55 @@ func PatchTunnel(tunnels []config.Tunnel, tcpIn chan<- C.ConnContext, udpIn chan } } +func GetInbounds() []C.Inbound { + return lo.Filter(getInbounds(), func(inbound C.Inbound, idx int) bool { + return !inbound.IsFromPortCfg + }) +} + +// GetInbounds return the inbounds of proxy servers +func getInbounds() []C.Inbound { + var inbounds []C.Inbound + for inbound := range tcpListeners { + inbounds = append(inbounds, inbound) + } + for inbound := range udpListeners { + if _, ok := tcpListeners[inbound]; !ok { + inbounds = append(inbounds, inbound) + } + } + return inbounds +} + // GetPorts return the ports of proxy servers func GetPorts() *Ports { ports := &Ports{} - - if httpListener != nil { - _, portStr, _ := net.SplitHostPort(httpListener.Address()) - port, _ := strconv.Atoi(portStr) - ports.Port = port + for _, inbound := range getInbounds() { + fillPort(inbound, ports) } - - if socksListener != nil { - _, portStr, _ := net.SplitHostPort(socksListener.Address()) - port, _ := strconv.Atoi(portStr) - ports.SocksPort = port - } - - if redirListener != nil { - _, portStr, _ := net.SplitHostPort(redirListener.Address()) - port, _ := strconv.Atoi(portStr) - ports.RedirPort = port - } - - if tproxyListener != nil { - _, portStr, _ := net.SplitHostPort(tproxyListener.Address()) - port, _ := strconv.Atoi(portStr) - ports.TProxyPort = port - } - - if mixedListener != nil { - _, portStr, _ := net.SplitHostPort(mixedListener.Address()) - port, _ := strconv.Atoi(portStr) - ports.MixedPort = port - } - return ports } +func fillPort(inbound C.Inbound, ports *Ports) { + if inbound.IsFromPortCfg { + port := getPort(inbound.BindAddress) + switch inbound.Type { + case C.InboundTypeHTTP: + ports.Port = port + case C.InboundTypeSocks: + ports.SocksPort = port + case C.InboundTypeTproxy: + ports.TProxyPort = port + case C.InboundTypeRedir: + ports.RedirPort = port + case C.InboundTypeMixed: + ports.MixedPort = port + default: + // do nothing + } + } +} + func portIsZero(addr string) bool { _, port, err := net.SplitHostPort(addr) if port == "0" || port == "" || err != nil { @@ -453,3 +351,15 @@ func genAddr(host string, port int, allowLan bool) string { return fmt.Sprintf("127.0.0.1:%d", port) } + +func getPort(addr string) int { + _, portStr, err := net.SplitHostPort(addr) + if err != nil { + return 0 + } + port, err := strconv.Atoi(portStr) + if err != nil { + return 0 + } + return port +} diff --git a/listener/mixed/mixed.go b/listener/mixed/mixed.go index f720e49..c8ff9e9 100644 --- a/listener/mixed/mixed.go +++ b/listener/mixed/mixed.go @@ -35,7 +35,7 @@ func (l *Listener) Close() error { return l.listener.Close() } -func New(addr string, in chan<- C.ConnContext) (*Listener, error) { +func New(addr string, in chan<- C.ConnContext) (C.Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err diff --git a/listener/redir/tcp.go b/listener/redir/tcp.go index 15c98a8..5f4b903 100644 --- a/listener/redir/tcp.go +++ b/listener/redir/tcp.go @@ -29,7 +29,7 @@ func (l *Listener) Close() error { return l.listener.Close() } -func New(addr string, in chan<- C.ConnContext) (*Listener, error) { +func New(addr string, in chan<- C.ConnContext) (C.Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err diff --git a/listener/socks/tcp.go b/listener/socks/tcp.go index 7cce32e..bdf7501 100644 --- a/listener/socks/tcp.go +++ b/listener/socks/tcp.go @@ -34,7 +34,7 @@ func (l *Listener) Close() error { return l.listener.Close() } -func New(addr string, in chan<- C.ConnContext) (*Listener, error) { +func New(addr string, in chan<- C.ConnContext) (C.Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err diff --git a/listener/socks/udp.go b/listener/socks/udp.go index 5ef4216..dd230f8 100644 --- a/listener/socks/udp.go +++ b/listener/socks/udp.go @@ -33,7 +33,7 @@ func (l *UDPListener) Close() error { return l.packetConn.Close() } -func NewUDP(addr string, in chan<- *inbound.PacketAdapter) (*UDPListener, error) { +func NewUDP(addr string, in chan<- *inbound.PacketAdapter) (C.Listener, error) { l, err := net.ListenPacket("udp", addr) if err != nil { return nil, err diff --git a/listener/tproxy/tcp.go b/listener/tproxy/tcp.go index 1a09f36..c6365e6 100644 --- a/listener/tproxy/tcp.go +++ b/listener/tproxy/tcp.go @@ -36,7 +36,7 @@ func (l *Listener) handleTProxy(conn net.Conn, in chan<- C.ConnContext) { in <- inbound.NewSocket(target, conn, C.TPROXY) } -func New(addr string, in chan<- C.ConnContext) (*Listener, error) { +func New(addr string, in chan<- C.ConnContext) (C.Listener, error) { l, err := net.Listen("tcp", addr) if err != nil { return nil, err diff --git a/listener/tproxy/udp.go b/listener/tproxy/udp.go index 4d8f6fc..3c40360 100644 --- a/listener/tproxy/udp.go +++ b/listener/tproxy/udp.go @@ -32,7 +32,7 @@ func (l *UDPListener) Close() error { return l.packetConn.Close() } -func NewUDP(addr string, in chan<- *inbound.PacketAdapter) (*UDPListener, error) { +func NewUDP(addr string, in chan<- *inbound.PacketAdapter) (C.Listener, error) { l, err := net.ListenPacket("udp", addr) if err != nil { return nil, err diff --git a/rule/parser.go b/rule/parser.go index a26ef0a..90c54d0 100644 --- a/rule/parser.go +++ b/rule/parser.go @@ -28,9 +28,11 @@ func ParseRule(tp, payload, target string, params []string) (C.Rule, error) { case "SRC-IP-CIDR": parsed, parseErr = NewIPCIDR(payload, target, WithIPCIDRSourceIP(true), WithIPCIDRNoResolve(true)) case "SRC-PORT": - parsed, parseErr = NewPort(payload, target, true) + parsed, parseErr = NewPort(payload, target, PortTypeSrc) case "DST-PORT": - parsed, parseErr = NewPort(payload, target, false) + parsed, parseErr = NewPort(payload, target, PortTypeDest) + case "INBOUND-PORT": + parsed, parseErr = NewPort(payload, target, PortTypeInbound) case "PROCESS-NAME": parsed, parseErr = NewProcess(payload, target, true) case "PROCESS-PATH": diff --git a/rule/port.go b/rule/port.go index 2cc7a7a..ea58720 100644 --- a/rule/port.go +++ b/rule/port.go @@ -1,29 +1,50 @@ package rules import ( + "fmt" "strconv" C "github.com/Dreamacro/clash/constant" ) +type PortType int + +const ( + PortTypeSrc PortType = iota + PortTypeDest + PortTypeInbound +) + type Port struct { adapter string port string - isSource bool + portType PortType } func (p *Port) RuleType() C.RuleType { - if p.isSource { + switch p.portType { + case PortTypeSrc: return C.SrcPort + case PortTypeDest: + return C.DstPort + case PortTypeInbound: + return C.InboundPort + default: + panic(fmt.Errorf("unknown port type: %v", p.portType)) } - return C.DstPort } func (p *Port) Match(metadata *C.Metadata) bool { - if p.isSource { + switch p.portType { + case PortTypeSrc: return metadata.SrcPort == p.port + case PortTypeDest: + return metadata.DstPort == p.port + case PortTypeInbound: + return metadata.InboundPort == p.port + default: + panic(fmt.Errorf("unknown port type: %v", p.portType)) } - return metadata.DstPort == p.port } func (p *Port) Adapter() string { @@ -42,7 +63,7 @@ func (p *Port) ShouldFindProcess() bool { return false } -func NewPort(port string, adapter string, isSource bool) (*Port, error) { +func NewPort(port string, adapter string, portType PortType) (*Port, error) { _, err := strconv.ParseUint(port, 10, 16) if err != nil { return nil, errPayload @@ -50,6 +71,6 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) { return &Port{ adapter: adapter, port: port, - isSource: isSource, + portType: portType, }, nil } diff --git a/test/clash_test.go b/test/clash_test.go index 84606ab..e08d17e 100644 --- a/test/clash_test.go +++ b/test/clash_test.go @@ -209,23 +209,27 @@ func newLargeDataPair() (chan hashPair, chan hashPair, func(t *testing.T) error) return pingCh, pongCh, test } -func testPingPongWithSocksPort(t *testing.T, port int) { +func testPingPongWithSocksPort(t *testing.T, port int) error { + l, err := Listen("tcp", ":10001") + require.NoError(t, err) + defer l.Close() + pingCh, pongCh, test := newPingPongPair() go func() { - l, err := Listen("tcp", ":10001") - require.NoError(t, err) - defer l.Close() - c, err := l.Accept() - require.NoError(t, err) + if err != nil { + return + } buf := make([]byte, 4) - _, err = io.ReadFull(c, buf) - require.NoError(t, err) + if _, err = io.ReadFull(c, buf); err != nil { + return + } pingCh <- buf - _, err = c.Write([]byte("pong")) - require.NoError(t, err) + if _, err = c.Write([]byte("pong")); err != nil { + return + } }() go func() { @@ -233,20 +237,23 @@ func testPingPongWithSocksPort(t *testing.T, port int) { require.NoError(t, err) defer c.Close() - _, err = socks5.ClientHandshake(c, socks5.ParseAddr("127.0.0.1:10001"), socks5.CmdConnect, nil) - require.NoError(t, err) + if _, err = socks5.ClientHandshake(c, socks5.ParseAddr("127.0.0.1:10001"), socks5.CmdConnect, nil); err != nil { + return + } - _, err = c.Write([]byte("ping")) - require.NoError(t, err) + if _, err = c.Write([]byte("ping")); err != nil { + return + } buf := make([]byte, 4) - _, err = io.ReadFull(c, buf) - require.NoError(t, err) + if _, err = io.ReadFull(c, buf); err != nil { + return + } pongCh <- buf }() - test(t) + return test(t) } func testPingPongWithConn(t *testing.T, c net.Conn) error { @@ -655,7 +662,7 @@ log-level: silent defer cleanup() require.True(t, TCPing(net.JoinHostPort("127.0.0.1", "10000"))) - testPingPongWithSocksPort(t, 10000) + require.NoError(t, testPingPongWithSocksPort(t, 10000)) } func Benchmark_Direct(b *testing.B) { diff --git a/test/listener_test.go b/test/listener_test.go new file mode 100644 index 0000000..8332a36 --- /dev/null +++ b/test/listener_test.go @@ -0,0 +1,78 @@ +package main + +import ( + "net" + "strconv" + "testing" + "time" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/listener" + "github.com/Dreamacro/clash/tunnel" + + "github.com/stretchr/testify/require" +) + +func TestClash_Listener(t *testing.T) { + basic := ` +log-level: silent +port: 7890 +socks-port: 7891 +redir-port: 7892 +tproxy-port: 7893 +mixed-port: 7894 +` + + err := parseAndApply(basic) + require.NoError(t, err) + defer cleanup() + + time.Sleep(waitTime) + + for i := 7890; i <= 7894; i++ { + require.True(t, TCPing(net.JoinHostPort("127.0.0.1", strconv.Itoa(i))), "tcp port %d", i) + } +} + +func TestClash_ListenerCreate(t *testing.T) { + basic := ` +log-level: silent +` + err := parseAndApply(basic) + require.NoError(t, err) + defer cleanup() + + time.Sleep(waitTime) + tcpIn := tunnel.TCPIn() + udpIn := tunnel.UDPIn() + + ports := listener.Ports{ + Port: 7890, + } + listener.ReCreatePortsListeners(ports, tcpIn, udpIn) + require.True(t, TCPing("127.0.0.1:7890")) + require.Equal(t, ports, *listener.GetPorts()) + + inbounds := []C.Inbound{ + { + Type: C.InboundTypeHTTP, + BindAddress: "127.0.0.1:7891", + }, + } + listener.ReCreateListeners(inbounds, tcpIn, udpIn) + require.True(t, TCPing("127.0.0.1:7890")) + require.Equal(t, ports, *listener.GetPorts()) + + require.True(t, TCPing("127.0.0.1:7891")) + require.Equal(t, len(inbounds), len(listener.GetInbounds())) + + ports.Port = 0 + ports.SocksPort = 7892 + listener.ReCreatePortsListeners(ports, tcpIn, udpIn) + require.False(t, TCPing("127.0.0.1:7890")) + require.True(t, TCPing("127.0.0.1:7892")) + require.Equal(t, ports, *listener.GetPorts()) + + require.True(t, TCPing("127.0.0.1:7891")) + require.Equal(t, len(inbounds), len(listener.GetInbounds())) +} diff --git a/test/rule_test.go b/test/rule_test.go new file mode 100644 index 0000000..3b108d0 --- /dev/null +++ b/test/rule_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClash_RuleInbound(t *testing.T) { + basic := ` +socks-port: 7890 +inbounds: + - socks://127.0.0.1:7891 + - type: socks + bind-address: 127.0.0.1:7892 +rules: + - INBOUND-PORT,7891,REJECT +log-level: silent +` + + err := parseAndApply(basic) + require.NoError(t, err) + defer cleanup() + + require.True(t, TCPing(net.JoinHostPort("127.0.0.1", "7890"))) + require.True(t, TCPing(net.JoinHostPort("127.0.0.1", "7891"))) + require.True(t, TCPing(net.JoinHostPort("127.0.0.1", "7892"))) + + require.Error(t, testPingPongWithSocksPort(t, 7891)) + require.NoError(t, testPingPongWithSocksPort(t, 7890)) + require.NoError(t, testPingPongWithSocksPort(t, 7892)) +}