refactor: proxymux

This commit rewrites proxymux package to provide following functions:

+ proxymux.ListenSOCKS(address string)
+ proxymux.ListenHTTP(address string)

both are drop-in replacements for net.Listen("tcp", address)

The above functions can be called with the same address to take
advantage of the mux feature.

Tests are not included, but we will have them very soon.

This commit should be in PR #1006, but I ended up with it in a separate
branch here. Please rebase if you want to merge it.
This commit is contained in:
Haruue 2024-04-11 20:53:28 +08:00
parent d34ff757c3
commit 34574e0339
No known key found for this signature in database
GPG Key ID: F6083B28CBCBC148
5 changed files with 301 additions and 145 deletions

View File

@ -15,13 +15,14 @@ import (
"strings" "strings"
"time" "time"
"github.com/apernet/hysteria/app/internal/proxymux"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/apernet/hysteria/app/internal/forwarding" "github.com/apernet/hysteria/app/internal/forwarding"
"github.com/apernet/hysteria/app/internal/http" "github.com/apernet/hysteria/app/internal/http"
"github.com/apernet/hysteria/app/internal/proxymux"
"github.com/apernet/hysteria/app/internal/redirect" "github.com/apernet/hysteria/app/internal/redirect"
"github.com/apernet/hysteria/app/internal/socks5" "github.com/apernet/hysteria/app/internal/socks5"
"github.com/apernet/hysteria/app/internal/tproxy" "github.com/apernet/hysteria/app/internal/tproxy"
@ -64,7 +65,6 @@ type clientConfig struct {
Bandwidth clientConfigBandwidth `mapstructure:"bandwidth"` Bandwidth clientConfigBandwidth `mapstructure:"bandwidth"`
FastOpen bool `mapstructure:"fastOpen"` FastOpen bool `mapstructure:"fastOpen"`
Lazy bool `mapstructure:"lazy"` Lazy bool `mapstructure:"lazy"`
Mixed *mixedConfig `mapstructure:"mixed"`
SOCKS5 *socks5Config `mapstructure:"socks5"` SOCKS5 *socks5Config `mapstructure:"socks5"`
HTTP *httpConfig `mapstructure:"http"` HTTP *httpConfig `mapstructure:"http"`
TCPForwarding []tcpForwardingEntry `mapstructure:"tcpForwarding"` TCPForwarding []tcpForwardingEntry `mapstructure:"tcpForwarding"`
@ -115,14 +115,6 @@ type clientConfigBandwidth struct {
Down string `mapstructure:"down"` Down string `mapstructure:"down"`
} }
type mixedConfig struct {
Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
DisableUDP bool `mapstructure:"disableUDP"`
Realm string `mapstructure:"realm"`
}
type socks5Config struct { type socks5Config struct {
Listen string `mapstructure:"listen"` Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"` Username string `mapstructure:"username"`
@ -457,11 +449,6 @@ func runClient(cmd *cobra.Command, args []string) {
// Register modes // Register modes
var runner clientModeRunner var runner clientModeRunner
if config.Mixed != nil {
runner.Add("Mixed server", func() error {
return clientMixed(*config.Mixed, c)
})
}
if config.SOCKS5 != nil { if config.SOCKS5 != nil {
runner.Add("SOCKS5 server", func() error { runner.Add("SOCKS5 server", func() error {
return clientSOCKS5(*config.SOCKS5, c) return clientSOCKS5(*config.SOCKS5, c)
@ -542,50 +529,11 @@ func (r *clientModeRunner) Run() {
} }
} }
func clientMixed(config mixedConfig, c client.Client) error {
if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")}
}
l, err := correctnet.Listen("tcp", config.Listen)
if err != nil {
return configError{Field: "listen", Err: err}
}
var authFunc func(username, password string) bool
username, password := config.Username, config.Password
if username != "" && password != "" {
authFunc = func(u, p string) bool {
return u == username && p == password
}
}
s := socks5.Server{
HyClient: c,
AuthFunc: authFunc,
DisableUDP: config.DisableUDP,
EventLogger: &socks5Logger{},
}
logger.Info("SOCKS5 server listening", zap.String("addr", config.Listen))
h := http.Server{
HyClient: c,
AuthFunc: authFunc,
AuthRealm: config.Realm,
EventLogger: &httpLogger{},
}
logger.Info("HTTP proxy server listening", zap.String("addr", config.Listen))
socks5Ln, httpLn := proxymux.SplitSOCKSAndHTTP(l)
go func() {
if err = h.Serve(httpLn); err != nil {
logger.Fatal(err.Error())
}
}()
return s.Serve(socks5Ln)
}
func clientSOCKS5(config socks5Config, c client.Client) error { func clientSOCKS5(config socks5Config, c client.Client) error {
if config.Listen == "" { if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")} return configError{Field: "listen", Err: errors.New("listen address is empty")}
} }
l, err := correctnet.Listen("tcp", config.Listen) l, err := proxymux.ListenSOCKS(config.Listen)
if err != nil { if err != nil {
return configError{Field: "listen", Err: err} return configError{Field: "listen", Err: err}
} }
@ -610,7 +558,7 @@ func clientHTTP(config httpConfig, c client.Client) error {
if config.Listen == "" { if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")} return configError{Field: "listen", Err: errors.New("listen address is empty")}
} }
l, err := correctnet.Listen("tcp", config.Listen) l, err := proxymux.ListenHTTP(config.Listen)
if err != nil { if err != nil {
return configError{Field: "listen", Err: err} return configError{Field: "listen", Err: err}
} }

View File

@ -53,13 +53,6 @@ func TestClientConfig(t *testing.T) {
}, },
FastOpen: true, FastOpen: true,
Lazy: true, Lazy: true,
Mixed: &mixedConfig{
Listen: "127.0.0.1:1080",
Username: "anon",
Password: "bro",
DisableUDP: true,
Realm: "martian",
},
SOCKS5: &socks5Config{ SOCKS5: &socks5Config{
Listen: "127.0.0.1:1080", Listen: "127.0.0.1:1080",
Username: "anon", Username: "anon",

View File

@ -35,13 +35,6 @@ fastOpen: true
lazy: true lazy: true
mixed:
listen: 127.0.0.1:1080
username: anon
password: bro
disableUDP: true
realm: martian
socks5: socks5:
listen: 127.0.0.1:1080 listen: 127.0.0.1:1080
username: anon username: anon

View File

@ -0,0 +1,72 @@
package proxymux
import (
"net"
"sync"
"github.com/apernet/hysteria/extras/correctnet"
)
type muxManager struct {
listeners map[string]*muxListener
lock sync.Mutex
}
var globalMuxManager *muxManager
func init() {
globalMuxManager = &muxManager{
listeners: make(map[string]*muxListener),
}
}
func (m *muxManager) GetOrCreate(address string) (*muxListener, error) {
key, err := m.canonicalizeAddrPort(address)
if err != nil {
return nil, err
}
m.lock.Lock()
defer m.lock.Unlock()
if ml, ok := m.listeners[key]; ok {
return ml, nil
}
listener, err := correctnet.Listen("tcp", key)
if err != nil {
return nil, err
}
ml := newMuxListener(listener, func() {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.listeners, key)
})
m.listeners[key] = ml
return ml, nil
}
func (m *muxManager) canonicalizeAddrPort(address string) (string, error) {
taddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return "", err
}
return taddr.String(), nil
}
func ListenHTTP(address string) (net.Listener, error) {
ml, err := globalMuxManager.GetOrCreate(address)
if err != nil {
return nil, err
}
return ml.ListenHTTP()
}
func ListenSOCKS(address string) (net.Listener, error) {
ml, err := globalMuxManager.GetOrCreate(address)
if err != nil {
return nil, err
}
return ml.ListenSOCKS()
}

View File

@ -1,124 +1,257 @@
// Package proxymux splits a net.Listener in two, routing SOCKS5
// connections to one and HTTP requests to the other.
//
// It allows for hosting both a SOCKS5 proxy and an HTTP proxy on the
// same listener.
package proxymux package proxymux
import ( import (
"errors"
"fmt"
"io" "io"
"net" "net"
"sync" "sync"
"time"
) )
// SplitSOCKSAndHTTP accepts connections on ln and passes connections func newMuxListener(listener net.Listener, deleteFunc func()) *muxListener {
// through to either socksListener or httpListener, depending the l := &muxListener{
// first byte sent by the client. base: listener,
func SplitSOCKSAndHTTP(ln net.Listener) (socksListener, httpListener net.Listener) { acceptChan: make(chan net.Conn),
sl := &listener{ closeChan: make(chan struct{}),
addr: ln.Addr(), deleteFunc: deleteFunc,
c: make(chan net.Conn),
closed: make(chan struct{}),
} }
hl := &listener{ go l.acceptLoop()
addr: ln.Addr(), go l.mainLoop()
c: make(chan net.Conn), return l
closed: make(chan struct{}),
}
go splitSOCKSAndHTTPListener(ln, sl, hl)
return sl, hl
} }
func splitSOCKSAndHTTPListener(ln net.Listener, sl, hl *listener) { type muxListener struct {
lock sync.Mutex
base net.Listener
acceptErr error
acceptChan chan net.Conn
closeChan chan struct{}
socksListener *subListener
httpListener *subListener
deleteFunc func()
}
func (l *muxListener) acceptLoop() {
defer close(l.acceptChan)
for { for {
conn, err := ln.Accept() conn, err := l.base.Accept()
if err != nil { if err != nil {
sl.Close() l.lock.Lock()
hl.Close() l.acceptErr = err
l.lock.Unlock()
return return
} }
go routeConn(conn, sl, hl) select {
case <-l.closeChan:
return
case l.acceptChan <- conn:
}
} }
} }
func routeConn(c net.Conn, socksListener, httpListener *listener) { func (l *muxListener) mainLoop() {
if err := c.SetReadDeadline(time.Now().Add(15 * time.Second)); err != nil { defer func() {
c.Close() l.deleteFunc()
return l.base.Close()
}
close(l.closeChan)
l.lock.Lock()
defer l.lock.Unlock()
if sl := l.httpListener; sl != nil {
close(sl.acceptChan)
l.httpListener = nil
}
if sl := l.socksListener; sl != nil {
close(sl.acceptChan)
l.socksListener = nil
}
}()
for {
var socksCloseChan, httpCloseChan chan struct{}
if l.httpListener != nil {
httpCloseChan = l.httpListener.closeChan
}
if l.socksListener != nil {
socksCloseChan = l.socksListener.closeChan
}
select {
case <-l.closeChan:
return
case conn, ok := <-l.acceptChan:
if !ok {
return
}
go l.dispatch(conn)
case <-socksCloseChan:
l.lock.Lock()
l.socksListener = nil
l.lock.Unlock()
if l.checkIdle() {
return
}
case <-httpCloseChan:
l.lock.Lock()
l.httpListener = nil
l.lock.Unlock()
if l.checkIdle() {
return
}
}
}
}
func (l *muxListener) dispatch(conn net.Conn) {
var b [1]byte var b [1]byte
if _, err := io.ReadFull(c, b[:]); err != nil { if _, err := io.ReadFull(conn, b[:]); err != nil {
c.Close() conn.Close()
return return
} }
if err := c.SetReadDeadline(time.Time{}); err != nil { l.lock.Lock()
c.Close() var target *subListener
return
}
conn := &connWithOneByte{
Conn: c,
b: b[0],
}
// First byte of a SOCKS5 session is a version byte set to 5.
var ln *listener
if b[0] == 5 { if b[0] == 5 {
ln = socksListener target = l.socksListener
} else { } else {
ln = httpListener target = l.httpListener
} }
l.lock.Unlock()
if target == nil {
conn.Close()
return
}
wconn := &connWithOneByte{Conn: conn, b: b[0]}
select { select {
case ln.c <- conn: case <-target.closeChan:
case <-ln.closed: case target.acceptChan <- wconn:
c.Close()
} }
} }
type listener struct { func (l *muxListener) checkIdle() bool {
addr net.Addr l.lock.Lock()
c chan net.Conn defer l.lock.Unlock()
mu sync.Mutex // serializes close() on closed. It's okay to receive on closed without locking.
closed chan struct{} return l.httpListener == nil && l.socksListener == nil
} }
func (ln *listener) Accept() (net.Conn, error) { func (l *muxListener) getAndClearAcceptError() error {
// Once closed, reliably stay closed, don't race with attempts at l.lock.Lock()
// further connections. defer l.lock.Unlock()
if l.acceptErr == nil {
return nil
}
err := l.acceptErr
l.acceptErr = nil
return err
}
func (l *muxListener) ListenHTTP() (net.Listener, error) {
l.lock.Lock()
defer l.lock.Unlock()
if l.httpListener != nil {
return nil, OpErr{
Addr: l.base.Addr(),
Protocol: "http",
Op: "bind-protocol",
Err: ErrProtocolInUse,
}
}
select { select {
case <-ln.closed: case <-l.closeChan:
return nil, net.ErrClosed return nil, net.ErrClosed
default: default:
} }
sl := newSubListener(l.getAndClearAcceptError, l.base.Addr)
l.httpListener = sl
return sl, nil
}
func (l *muxListener) ListenSOCKS() (net.Listener, error) {
l.lock.Lock()
defer l.lock.Unlock()
if l.socksListener != nil {
return nil, OpErr{
Addr: l.base.Addr(),
Protocol: "socks",
Op: "bind-protocol",
Err: ErrProtocolInUse,
}
}
select { select {
case ret := <-ln.c: case <-l.closeChan:
return ret, nil
case <-ln.closed:
return nil, net.ErrClosed return nil, net.ErrClosed
default:
}
sl := newSubListener(l.getAndClearAcceptError, l.base.Addr)
l.socksListener = sl
return sl, nil
}
func newSubListener(acceptErrorFunc func() error, addrFunc func() net.Addr) *subListener {
return &subListener{
acceptChan: make(chan net.Conn),
acceptErrorFunc: acceptErrorFunc,
closeChan: make(chan struct{}),
addrFunc: addrFunc,
} }
} }
func (ln *listener) Close() error { type subListener struct {
ln.mu.Lock() // receive connections or closure from upstream
defer ln.mu.Unlock() acceptChan chan net.Conn
// get an error of Accept() from upstream
acceptErrorFunc func() error
// notify upstream that we are closed
closeChan chan struct{}
// Listener.Addr() implementation of base listener
addrFunc func() net.Addr
}
func (l *subListener) Accept() (net.Conn, error) {
select { select {
case <-ln.closed: case <-l.closeChan:
// Already closed // closed by ourselves
default: return nil, net.ErrClosed
close(ln.closed) case conn, ok := <-l.acceptChan:
if !ok {
// closed by upstream
if acceptErr := l.acceptErrorFunc(); acceptErr != nil {
return nil, acceptErr
}
return nil, net.ErrClosed
}
return conn, nil
} }
}
func (l *subListener) Addr() net.Addr {
return l.addrFunc()
}
// Close implements net.Listener.Close.
// Upstream should use close(l.acceptChan) instead.
func (l *subListener) Close() error {
close(l.closeChan)
return nil return nil
} }
func (ln *listener) Addr() net.Addr {
return ln.addr
}
// connWithOneByte is a net.Conn that returns b for the first read // connWithOneByte is a net.Conn that returns b for the first read
// request, then forwards everything else to Conn. // request, then forwards everything else to Conn.
type connWithOneByte struct { type connWithOneByte struct {
@ -139,3 +272,20 @@ func (c *connWithOneByte) Read(bs []byte) (int, error) {
bs[0] = c.b bs[0] = c.b
return 1, nil return 1, nil
} }
type OpErr struct {
Addr net.Addr
Protocol string
Op string
Err error
}
func (m OpErr) Error() string {
return fmt.Sprintf("mux-listen: %s[%s]: %s: %v", m.Addr, m.Protocol, m.Op, m.Err)
}
func (m OpErr) Unwrap() error {
return m.Err
}
var ErrProtocolInUse = errors.New("protocol already in use")