Merge pull request #1006 from xmapst/master

实现HTTP/SOCKS5混合端口
This commit is contained in:
Toby
2024-04-12 19:36:07 -07:00
committed by GitHub
8 changed files with 1283 additions and 2 deletions

View File

@@ -21,6 +21,7 @@ import (
"github.com/apernet/hysteria/app/internal/forwarding"
"github.com/apernet/hysteria/app/internal/http"
"github.com/apernet/hysteria/app/internal/proxymux"
"github.com/apernet/hysteria/app/internal/redirect"
"github.com/apernet/hysteria/app/internal/socks5"
"github.com/apernet/hysteria/app/internal/tproxy"
@@ -531,7 +532,7 @@ func clientSOCKS5(config socks5Config, c client.Client) error {
if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")}
}
l, err := correctnet.Listen("tcp", config.Listen)
l, err := proxymux.ListenSOCKS(config.Listen)
if err != nil {
return configError{Field: "listen", Err: err}
}
@@ -556,7 +557,7 @@ func clientHTTP(config httpConfig, c client.Client) error {
if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")}
}
l, err := correctnet.Listen("tcp", config.Listen)
l, err := proxymux.ListenHTTP(config.Listen)
if err != nil {
return configError{Field: "listen", Err: err}
}

View File

@@ -0,0 +1,12 @@
with-expecter: true
dir: internal/mocks
outpkg: mocks
packages:
net:
interfaces:
Listener:
config:
mockname: MockListener
Conn:
config:
mockname: MockConn

View File

@@ -0,0 +1,427 @@
// Code generated by mockery v2.42.2. DO NOT EDIT.
package mocks
import (
net "net"
mock "github.com/stretchr/testify/mock"
time "time"
)
// MockConn is an autogenerated mock type for the Conn type
type MockConn struct {
mock.Mock
}
type MockConn_Expecter struct {
mock *mock.Mock
}
func (_m *MockConn) EXPECT() *MockConn_Expecter {
return &MockConn_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with given fields:
func (_m *MockConn) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockConn_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockConn_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockConn_Expecter) Close() *MockConn_Close_Call {
return &MockConn_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockConn_Close_Call) Run(run func()) *MockConn_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockConn_Close_Call) Return(_a0 error) *MockConn_Close_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockConn_Close_Call) RunAndReturn(run func() error) *MockConn_Close_Call {
_c.Call.Return(run)
return _c
}
// LocalAddr provides a mock function with given fields:
func (_m *MockConn) LocalAddr() net.Addr {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for LocalAddr")
}
var r0 net.Addr
if rf, ok := ret.Get(0).(func() net.Addr); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(net.Addr)
}
}
return r0
}
// MockConn_LocalAddr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LocalAddr'
type MockConn_LocalAddr_Call struct {
*mock.Call
}
// LocalAddr is a helper method to define mock.On call
func (_e *MockConn_Expecter) LocalAddr() *MockConn_LocalAddr_Call {
return &MockConn_LocalAddr_Call{Call: _e.mock.On("LocalAddr")}
}
func (_c *MockConn_LocalAddr_Call) Run(run func()) *MockConn_LocalAddr_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockConn_LocalAddr_Call) Return(_a0 net.Addr) *MockConn_LocalAddr_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockConn_LocalAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_LocalAddr_Call {
_c.Call.Return(run)
return _c
}
// Read provides a mock function with given fields: b
func (_m *MockConn) Read(b []byte) (int, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for Read")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok {
return rf(b)
}
if rf, ok := ret.Get(0).(func([]byte) int); ok {
r0 = rf(b)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func([]byte) error); ok {
r1 = rf(b)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockConn_Read_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Read'
type MockConn_Read_Call struct {
*mock.Call
}
// Read is a helper method to define mock.On call
// - b []byte
func (_e *MockConn_Expecter) Read(b interface{}) *MockConn_Read_Call {
return &MockConn_Read_Call{Call: _e.mock.On("Read", b)}
}
func (_c *MockConn_Read_Call) Run(run func(b []byte)) *MockConn_Read_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte))
})
return _c
}
func (_c *MockConn_Read_Call) Return(n int, err error) *MockConn_Read_Call {
_c.Call.Return(n, err)
return _c
}
func (_c *MockConn_Read_Call) RunAndReturn(run func([]byte) (int, error)) *MockConn_Read_Call {
_c.Call.Return(run)
return _c
}
// RemoteAddr provides a mock function with given fields:
func (_m *MockConn) RemoteAddr() net.Addr {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for RemoteAddr")
}
var r0 net.Addr
if rf, ok := ret.Get(0).(func() net.Addr); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(net.Addr)
}
}
return r0
}
// MockConn_RemoteAddr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoteAddr'
type MockConn_RemoteAddr_Call struct {
*mock.Call
}
// RemoteAddr is a helper method to define mock.On call
func (_e *MockConn_Expecter) RemoteAddr() *MockConn_RemoteAddr_Call {
return &MockConn_RemoteAddr_Call{Call: _e.mock.On("RemoteAddr")}
}
func (_c *MockConn_RemoteAddr_Call) Run(run func()) *MockConn_RemoteAddr_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockConn_RemoteAddr_Call) Return(_a0 net.Addr) *MockConn_RemoteAddr_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockConn_RemoteAddr_Call) RunAndReturn(run func() net.Addr) *MockConn_RemoteAddr_Call {
_c.Call.Return(run)
return _c
}
// SetDeadline provides a mock function with given fields: t
func (_m *MockConn) SetDeadline(t time.Time) error {
ret := _m.Called(t)
if len(ret) == 0 {
panic("no return value specified for SetDeadline")
}
var r0 error
if rf, ok := ret.Get(0).(func(time.Time) error); ok {
r0 = rf(t)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockConn_SetDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetDeadline'
type MockConn_SetDeadline_Call struct {
*mock.Call
}
// SetDeadline is a helper method to define mock.On call
// - t time.Time
func (_e *MockConn_Expecter) SetDeadline(t interface{}) *MockConn_SetDeadline_Call {
return &MockConn_SetDeadline_Call{Call: _e.mock.On("SetDeadline", t)}
}
func (_c *MockConn_SetDeadline_Call) Run(run func(t time.Time)) *MockConn_SetDeadline_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(time.Time))
})
return _c
}
func (_c *MockConn_SetDeadline_Call) Return(_a0 error) *MockConn_SetDeadline_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockConn_SetDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetDeadline_Call {
_c.Call.Return(run)
return _c
}
// SetReadDeadline provides a mock function with given fields: t
func (_m *MockConn) SetReadDeadline(t time.Time) error {
ret := _m.Called(t)
if len(ret) == 0 {
panic("no return value specified for SetReadDeadline")
}
var r0 error
if rf, ok := ret.Get(0).(func(time.Time) error); ok {
r0 = rf(t)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockConn_SetReadDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetReadDeadline'
type MockConn_SetReadDeadline_Call struct {
*mock.Call
}
// SetReadDeadline is a helper method to define mock.On call
// - t time.Time
func (_e *MockConn_Expecter) SetReadDeadline(t interface{}) *MockConn_SetReadDeadline_Call {
return &MockConn_SetReadDeadline_Call{Call: _e.mock.On("SetReadDeadline", t)}
}
func (_c *MockConn_SetReadDeadline_Call) Run(run func(t time.Time)) *MockConn_SetReadDeadline_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(time.Time))
})
return _c
}
func (_c *MockConn_SetReadDeadline_Call) Return(_a0 error) *MockConn_SetReadDeadline_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockConn_SetReadDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetReadDeadline_Call {
_c.Call.Return(run)
return _c
}
// SetWriteDeadline provides a mock function with given fields: t
func (_m *MockConn) SetWriteDeadline(t time.Time) error {
ret := _m.Called(t)
if len(ret) == 0 {
panic("no return value specified for SetWriteDeadline")
}
var r0 error
if rf, ok := ret.Get(0).(func(time.Time) error); ok {
r0 = rf(t)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockConn_SetWriteDeadline_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetWriteDeadline'
type MockConn_SetWriteDeadline_Call struct {
*mock.Call
}
// SetWriteDeadline is a helper method to define mock.On call
// - t time.Time
func (_e *MockConn_Expecter) SetWriteDeadline(t interface{}) *MockConn_SetWriteDeadline_Call {
return &MockConn_SetWriteDeadline_Call{Call: _e.mock.On("SetWriteDeadline", t)}
}
func (_c *MockConn_SetWriteDeadline_Call) Run(run func(t time.Time)) *MockConn_SetWriteDeadline_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(time.Time))
})
return _c
}
func (_c *MockConn_SetWriteDeadline_Call) Return(_a0 error) *MockConn_SetWriteDeadline_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockConn_SetWriteDeadline_Call) RunAndReturn(run func(time.Time) error) *MockConn_SetWriteDeadline_Call {
_c.Call.Return(run)
return _c
}
// Write provides a mock function with given fields: b
func (_m *MockConn) Write(b []byte) (int, error) {
ret := _m.Called(b)
if len(ret) == 0 {
panic("no return value specified for Write")
}
var r0 int
var r1 error
if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok {
return rf(b)
}
if rf, ok := ret.Get(0).(func([]byte) int); ok {
r0 = rf(b)
} else {
r0 = ret.Get(0).(int)
}
if rf, ok := ret.Get(1).(func([]byte) error); ok {
r1 = rf(b)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockConn_Write_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Write'
type MockConn_Write_Call struct {
*mock.Call
}
// Write is a helper method to define mock.On call
// - b []byte
func (_e *MockConn_Expecter) Write(b interface{}) *MockConn_Write_Call {
return &MockConn_Write_Call{Call: _e.mock.On("Write", b)}
}
func (_c *MockConn_Write_Call) Run(run func(b []byte)) *MockConn_Write_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]byte))
})
return _c
}
func (_c *MockConn_Write_Call) Return(n int, err error) *MockConn_Write_Call {
_c.Call.Return(n, err)
return _c
}
func (_c *MockConn_Write_Call) RunAndReturn(run func([]byte) (int, error)) *MockConn_Write_Call {
_c.Call.Return(run)
return _c
}
// NewMockConn creates a new instance of MockConn. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockConn(t interface {
mock.TestingT
Cleanup(func())
}) *MockConn {
mock := &MockConn{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@@ -0,0 +1,185 @@
// Code generated by mockery v2.42.2. DO NOT EDIT.
package mocks
import (
net "net"
mock "github.com/stretchr/testify/mock"
)
// MockListener is an autogenerated mock type for the Listener type
type MockListener struct {
mock.Mock
}
type MockListener_Expecter struct {
mock *mock.Mock
}
func (_m *MockListener) EXPECT() *MockListener_Expecter {
return &MockListener_Expecter{mock: &_m.Mock}
}
// Accept provides a mock function with given fields:
func (_m *MockListener) Accept() (net.Conn, error) {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Accept")
}
var r0 net.Conn
var r1 error
if rf, ok := ret.Get(0).(func() (net.Conn, error)); ok {
return rf()
}
if rf, ok := ret.Get(0).(func() net.Conn); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(net.Conn)
}
}
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockListener_Accept_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Accept'
type MockListener_Accept_Call struct {
*mock.Call
}
// Accept is a helper method to define mock.On call
func (_e *MockListener_Expecter) Accept() *MockListener_Accept_Call {
return &MockListener_Accept_Call{Call: _e.mock.On("Accept")}
}
func (_c *MockListener_Accept_Call) Run(run func()) *MockListener_Accept_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockListener_Accept_Call) Return(_a0 net.Conn, _a1 error) *MockListener_Accept_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockListener_Accept_Call) RunAndReturn(run func() (net.Conn, error)) *MockListener_Accept_Call {
_c.Call.Return(run)
return _c
}
// Addr provides a mock function with given fields:
func (_m *MockListener) Addr() net.Addr {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Addr")
}
var r0 net.Addr
if rf, ok := ret.Get(0).(func() net.Addr); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(net.Addr)
}
}
return r0
}
// MockListener_Addr_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Addr'
type MockListener_Addr_Call struct {
*mock.Call
}
// Addr is a helper method to define mock.On call
func (_e *MockListener_Expecter) Addr() *MockListener_Addr_Call {
return &MockListener_Addr_Call{Call: _e.mock.On("Addr")}
}
func (_c *MockListener_Addr_Call) Run(run func()) *MockListener_Addr_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockListener_Addr_Call) Return(_a0 net.Addr) *MockListener_Addr_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockListener_Addr_Call) RunAndReturn(run func() net.Addr) *MockListener_Addr_Call {
_c.Call.Return(run)
return _c
}
// Close provides a mock function with given fields:
func (_m *MockListener) Close() error {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Close")
}
var r0 error
if rf, ok := ret.Get(0).(func() error); ok {
r0 = rf()
} else {
r0 = ret.Error(0)
}
return r0
}
// MockListener_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockListener_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockListener_Expecter) Close() *MockListener_Close_Call {
return &MockListener_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockListener_Close_Call) Run(run func()) *MockListener_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockListener_Close_Call) Return(_a0 error) *MockListener_Close_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockListener_Close_Call) RunAndReturn(run func() error) *MockListener_Close_Call {
_c.Call.Return(run)
return _c
}
// NewMockListener creates a new instance of MockListener. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockListener(t interface {
mock.TestingT
Cleanup(func())
}) *MockListener {
mock := &MockListener{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@@ -0,0 +1,72 @@
package proxymux
import (
"net"
"sync"
"github.com/apernet/hysteria/extras/correctnet"
)
type muxManager struct {
listeners map[string]*muxListener
lock sync.Mutex
}
var globalMuxManager *muxManager
func init() {
globalMuxManager = &muxManager{
listeners: make(map[string]*muxListener),
}
}
func (m *muxManager) GetOrCreate(address string) (*muxListener, error) {
key, err := m.canonicalizeAddrPort(address)
if err != nil {
return nil, err
}
m.lock.Lock()
defer m.lock.Unlock()
if ml, ok := m.listeners[key]; ok {
return ml, nil
}
listener, err := correctnet.Listen("tcp", key)
if err != nil {
return nil, err
}
ml := newMuxListener(listener, func() {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.listeners, key)
})
m.listeners[key] = ml
return ml, nil
}
func (m *muxManager) canonicalizeAddrPort(address string) (string, error) {
taddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return "", err
}
return taddr.String(), nil
}
func ListenHTTP(address string) (net.Listener, error) {
ml, err := globalMuxManager.GetOrCreate(address)
if err != nil {
return nil, err
}
return ml.ListenHTTP()
}
func ListenSOCKS(address string) (net.Listener, error) {
ml, err := globalMuxManager.GetOrCreate(address)
if err != nil {
return nil, err
}
return ml.ListenSOCKS()
}

View File

@@ -0,0 +1,110 @@
package proxymux
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestListenSOCKS(t *testing.T) {
address := "127.2.39.129:11081"
sl, err := ListenSOCKS(address)
if !assert.NoError(t, err) {
return
}
defer func() {
sl.Close()
}()
hl, err := ListenHTTP(address)
if !assert.NoError(t, err) {
return
}
defer hl.Close()
_, err = ListenSOCKS(address)
if !assert.ErrorIs(t, err, ErrProtocolInUse) {
return
}
sl.Close()
sl, err = ListenSOCKS(address)
if !assert.NoError(t, err) {
return
}
}
func TestListenHTTP(t *testing.T) {
address := "127.2.39.129:11082"
hl, err := ListenHTTP(address)
if !assert.NoError(t, err) {
return
}
defer func() {
hl.Close()
}()
sl, err := ListenSOCKS(address)
if !assert.NoError(t, err) {
return
}
defer sl.Close()
_, err = ListenHTTP(address)
if !assert.ErrorIs(t, err, ErrProtocolInUse) {
return
}
hl.Close()
hl, err = ListenHTTP(address)
if !assert.NoError(t, err) {
return
}
}
func TestRelease(t *testing.T) {
address := "127.2.39.129:11083"
hl, err := ListenHTTP(address)
if !assert.NoError(t, err) {
return
}
sl, err := ListenSOCKS(address)
if !assert.NoError(t, err) {
return
}
if !assert.True(t, globalMuxManager.testAddressExists(address)) {
return
}
_, err = net.Listen("tcp", address)
if !assert.Error(t, err) {
return
}
hl.Close()
sl.Close()
// Wait for muxListener released
time.Sleep(time.Second)
if !assert.False(t, globalMuxManager.testAddressExists(address)) {
return
}
lis, err := net.Listen("tcp", address)
if !assert.NoError(t, err) {
return
}
defer lis.Close()
}
func (m *muxManager) testAddressExists(address string) bool {
m.lock.Lock()
defer m.lock.Unlock()
_, ok := m.listeners[address]
return ok
}

View File

@@ -0,0 +1,320 @@
package proxymux
import (
"errors"
"fmt"
"io"
"net"
"sync"
)
func newMuxListener(listener net.Listener, deleteFunc func()) *muxListener {
l := &muxListener{
base: listener,
acceptChan: make(chan net.Conn),
closeChan: make(chan struct{}),
deleteFunc: deleteFunc,
}
go l.acceptLoop()
go l.mainLoop()
return l
}
type muxListener struct {
lock sync.Mutex
base net.Listener
acceptErr error
acceptChan chan net.Conn
closeChan chan struct{}
socksListener *subListener
httpListener *subListener
deleteFunc func()
}
func (l *muxListener) acceptLoop() {
defer close(l.acceptChan)
for {
conn, err := l.base.Accept()
if err != nil {
l.lock.Lock()
l.acceptErr = err
l.lock.Unlock()
return
}
select {
case <-l.closeChan:
return
case l.acceptChan <- conn:
}
}
}
func (l *muxListener) mainLoop() {
defer func() {
l.deleteFunc()
l.base.Close()
close(l.closeChan)
l.lock.Lock()
defer l.lock.Unlock()
if sl := l.httpListener; sl != nil {
close(sl.acceptChan)
l.httpListener = nil
}
if sl := l.socksListener; sl != nil {
close(sl.acceptChan)
l.socksListener = nil
}
}()
for {
var socksCloseChan, httpCloseChan chan struct{}
if l.httpListener != nil {
httpCloseChan = l.httpListener.closeChan
}
if l.socksListener != nil {
socksCloseChan = l.socksListener.closeChan
}
select {
case <-l.closeChan:
return
case conn, ok := <-l.acceptChan:
if !ok {
return
}
go l.dispatch(conn)
case <-socksCloseChan:
l.lock.Lock()
if socksCloseChan == l.socksListener.closeChan {
// not replaced by another ListenSOCKS()
l.socksListener = nil
}
l.lock.Unlock()
if l.checkIdle() {
return
}
case <-httpCloseChan:
l.lock.Lock()
if httpCloseChan == l.httpListener.closeChan {
// not replaced by another ListenHTTP()
l.httpListener = nil
}
l.lock.Unlock()
if l.checkIdle() {
return
}
}
}
}
func (l *muxListener) dispatch(conn net.Conn) {
var b [1]byte
if _, err := io.ReadFull(conn, b[:]); err != nil {
conn.Close()
return
}
l.lock.Lock()
var target *subListener
if b[0] == 5 {
target = l.socksListener
} else {
target = l.httpListener
}
l.lock.Unlock()
if target == nil {
conn.Close()
return
}
wconn := &connWithOneByte{Conn: conn, b: b[0]}
select {
case <-target.closeChan:
case target.acceptChan <- wconn:
}
}
func (l *muxListener) checkIdle() bool {
l.lock.Lock()
defer l.lock.Unlock()
return l.httpListener == nil && l.socksListener == nil
}
func (l *muxListener) getAndClearAcceptError() error {
l.lock.Lock()
defer l.lock.Unlock()
if l.acceptErr == nil {
return nil
}
err := l.acceptErr
l.acceptErr = nil
return err
}
func (l *muxListener) ListenHTTP() (net.Listener, error) {
l.lock.Lock()
defer l.lock.Unlock()
if l.httpListener != nil {
subListenerPendingClosed := false
select {
case <-l.httpListener.closeChan:
subListenerPendingClosed = true
default:
}
if !subListenerPendingClosed {
return nil, OpErr{
Addr: l.base.Addr(),
Protocol: "http",
Op: "bind-protocol",
Err: ErrProtocolInUse,
}
}
l.httpListener = nil
}
select {
case <-l.closeChan:
return nil, net.ErrClosed
default:
}
sl := newSubListener(l.getAndClearAcceptError, l.base.Addr)
l.httpListener = sl
return sl, nil
}
func (l *muxListener) ListenSOCKS() (net.Listener, error) {
l.lock.Lock()
defer l.lock.Unlock()
if l.socksListener != nil {
subListenerPendingClosed := false
select {
case <-l.socksListener.closeChan:
subListenerPendingClosed = true
default:
}
if !subListenerPendingClosed {
return nil, OpErr{
Addr: l.base.Addr(),
Protocol: "socks",
Op: "bind-protocol",
Err: ErrProtocolInUse,
}
}
l.socksListener = nil
}
select {
case <-l.closeChan:
return nil, net.ErrClosed
default:
}
sl := newSubListener(l.getAndClearAcceptError, l.base.Addr)
l.socksListener = sl
return sl, nil
}
func newSubListener(acceptErrorFunc func() error, addrFunc func() net.Addr) *subListener {
return &subListener{
acceptChan: make(chan net.Conn),
acceptErrorFunc: acceptErrorFunc,
closeChan: make(chan struct{}),
addrFunc: addrFunc,
}
}
type subListener struct {
// receive connections or closure from upstream
acceptChan chan net.Conn
// get an error of Accept() from upstream
acceptErrorFunc func() error
// notify upstream that we are closed
closeChan chan struct{}
// Listener.Addr() implementation of base listener
addrFunc func() net.Addr
}
func (l *subListener) Accept() (net.Conn, error) {
select {
case <-l.closeChan:
// closed by ourselves
return nil, net.ErrClosed
case conn, ok := <-l.acceptChan:
if !ok {
// closed by upstream
if acceptErr := l.acceptErrorFunc(); acceptErr != nil {
return nil, acceptErr
}
return nil, net.ErrClosed
}
return conn, nil
}
}
func (l *subListener) Addr() net.Addr {
return l.addrFunc()
}
// Close implements net.Listener.Close.
// Upstream should use close(l.acceptChan) instead.
func (l *subListener) Close() error {
select {
case <-l.closeChan:
return nil
default:
}
close(l.closeChan)
return nil
}
// connWithOneByte is a net.Conn that returns b for the first read
// request, then forwards everything else to Conn.
type connWithOneByte struct {
net.Conn
b byte
bRead bool
}
func (c *connWithOneByte) Read(bs []byte) (int, error) {
if c.bRead {
return c.Conn.Read(bs)
}
if len(bs) == 0 {
return 0, nil
}
c.bRead = true
bs[0] = c.b
return 1, nil
}
type OpErr struct {
Addr net.Addr
Protocol string
Op string
Err error
}
func (m OpErr) Error() string {
return fmt.Sprintf("mux-listen: %s[%s]: %s: %v", m.Addr, m.Protocol, m.Op, m.Err)
}
func (m OpErr) Unwrap() error {
return m.Err
}
var ErrProtocolInUse = errors.New("protocol already in use")

View File

@@ -0,0 +1,154 @@
package proxymux
import (
"bytes"
"net"
"sync"
"testing"
"time"
"github.com/apernet/hysteria/app/internal/proxymux/internal/mocks"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
//go:generate mockery
func testMockListener(t *testing.T, connChan <-chan net.Conn) net.Listener {
closedChan := make(chan struct{})
mockListener := mocks.NewMockListener(t)
mockListener.EXPECT().Accept().RunAndReturn(func() (net.Conn, error) {
select {
case <-closedChan:
return nil, net.ErrClosed
case conn, ok := <-connChan:
if !ok {
panic("unexpected closed channel (connChan)")
}
return conn, nil
}
})
mockListener.EXPECT().Close().RunAndReturn(func() error {
select {
case <-closedChan:
default:
close(closedChan)
}
return nil
})
return mockListener
}
func testMockConn(t *testing.T, b []byte) net.Conn {
buf := bytes.NewReader(b)
isClosed := false
mockConn := mocks.NewMockConn(t)
mockConn.EXPECT().Read(mock.Anything).RunAndReturn(func(b []byte) (int, error) {
if isClosed {
return 0, net.ErrClosed
}
return buf.Read(b)
})
mockConn.EXPECT().Close().RunAndReturn(func() error {
isClosed = true
return nil
})
return mockConn
}
func TestMuxHTTP(t *testing.T) {
connChan := make(chan net.Conn)
mockListener := testMockListener(t, connChan)
mockConn := testMockConn(t, []byte("CONNECT example.com:443 HTTP/1.1\r\n\r\n"))
mux := newMuxListener(mockListener, func() {})
hl, err := mux.ListenHTTP()
if !assert.NoError(t, err) {
return
}
sl, err := mux.ListenSOCKS()
if !assert.NoError(t, err) {
return
}
connChan <- mockConn
var socksConn, httpConn net.Conn
var socksErr, httpErr error
var wg sync.WaitGroup
wg.Add(2)
go func() {
socksConn, socksErr = sl.Accept()
wg.Done()
}()
go func() {
httpConn, httpErr = hl.Accept()
wg.Done()
}()
time.Sleep(time.Second)
sl.Close()
hl.Close()
wg.Wait()
assert.Nil(t, socksConn)
assert.ErrorIs(t, socksErr, net.ErrClosed)
assert.NotNil(t, httpConn)
httpConn.Close()
assert.NoError(t, httpErr)
// Wait for muxListener released
<-mux.acceptChan
}
func TestMuxSOCKS(t *testing.T) {
connChan := make(chan net.Conn)
mockListener := testMockListener(t, connChan)
mockConn := testMockConn(t, []byte{0x05, 0x02, 0x00, 0x01}) // SOCKS5 Connect Request: NOAUTH+GSSAPI
mux := newMuxListener(mockListener, func() {})
hl, err := mux.ListenHTTP()
if !assert.NoError(t, err) {
return
}
sl, err := mux.ListenSOCKS()
if !assert.NoError(t, err) {
return
}
connChan <- mockConn
var socksConn, httpConn net.Conn
var socksErr, httpErr error
var wg sync.WaitGroup
wg.Add(2)
go func() {
socksConn, socksErr = sl.Accept()
wg.Done()
}()
go func() {
httpConn, httpErr = hl.Accept()
wg.Done()
}()
time.Sleep(time.Second)
sl.Close()
hl.Close()
wg.Wait()
assert.NotNil(t, socksConn)
socksConn.Close()
assert.NoError(t, socksErr)
assert.Nil(t, httpConn)
assert.ErrorIs(t, httpErr, net.ErrClosed)
// Wait for muxListener released
<-mux.acceptChan
}