feat: ignoreClientBandwidth

This commit is contained in:
Toby 2023-08-07 16:34:35 -07:00
parent f95a31120d
commit cc0d0181e1
8 changed files with 156 additions and 84 deletions

View File

@ -39,6 +39,7 @@ type serverConfig struct {
ACME *serverConfigACME `mapstructure:"acme"` ACME *serverConfigACME `mapstructure:"acme"`
QUIC serverConfigQUIC `mapstructure:"quic"` QUIC serverConfigQUIC `mapstructure:"quic"`
Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"` Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"`
IgnoreClientBandwidth bool `mapstructure:"ignoreClientBandwidth"`
DisableUDP bool `mapstructure:"disableUDP"` DisableUDP bool `mapstructure:"disableUDP"`
UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"` UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"`
Auth serverConfigAuth `mapstructure:"auth"` Auth serverConfigAuth `mapstructure:"auth"`
@ -360,6 +361,11 @@ func (c *serverConfig) fillBandwidthConfig(hyConfig *server.Config) error {
return nil return nil
} }
func (c *serverConfig) fillIgnoreClientBandwidth(hyConfig *server.Config) error {
hyConfig.IgnoreClientBandwidth = c.IgnoreClientBandwidth
return nil
}
func (c *serverConfig) fillDisableUDP(hyConfig *server.Config) error { func (c *serverConfig) fillDisableUDP(hyConfig *server.Config) error {
hyConfig.DisableUDP = c.DisableUDP hyConfig.DisableUDP = c.DisableUDP
return nil return nil
@ -445,6 +451,7 @@ func (c *serverConfig) Config() (*server.Config, error) {
c.fillQUICConfig, c.fillQUICConfig,
c.fillOutboundConfig, c.fillOutboundConfig,
c.fillBandwidthConfig, c.fillBandwidthConfig,
c.fillIgnoreClientBandwidth,
c.fillDisableUDP, c.fillDisableUDP,
c.fillUDPIdleTimeout, c.fillUDPIdleTimeout,
c.fillAuthenticator, c.fillAuthenticator,

View File

@ -55,6 +55,7 @@ func TestServerConfig(t *testing.T) {
Up: "500 mbps", Up: "500 mbps",
Down: "100 mbps", Down: "100 mbps",
}, },
IgnoreClientBandwidth: true,
DisableUDP: true, DisableUDP: true,
UDPIdleTimeout: 120 * time.Second, UDPIdleTimeout: 120 * time.Second,
Auth: serverConfigAuth{ Auth: serverConfigAuth{

View File

@ -34,6 +34,8 @@ bandwidth:
up: 500 mbps up: 500 mbps
down: 100 mbps down: 100 mbps
ignoreClientBandwidth: true
disableUDP: true disableUDP: true
udpIdleTimeout: 120s udpIdleTimeout: 120s

View File

@ -9,9 +9,7 @@ import (
"time" "time"
coreErrs "github.com/apernet/hysteria/core/errors" coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/congestion/bbr" "github.com/apernet/hysteria/core/internal/congestion"
"github.com/apernet/hysteria/core/internal/congestion/brutal"
"github.com/apernet/hysteria/core/internal/congestion/common"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils" "github.com/apernet/hysteria/core/internal/utils"
@ -104,7 +102,10 @@ func (c *clientImpl) connect() error {
}, },
Header: make(http.Header), Header: make(http.Header),
} }
protocol.AuthRequestDataToHeader(req.Header, c.config.Auth, c.config.BandwidthConfig.MaxRx) protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{
Auth: c.config.Auth,
Rx: c.config.BandwidthConfig.MaxRx,
})
resp, err := rt.RoundTrip(req) resp, err := rt.RoundTrip(req)
if err != nil { if err != nil {
if conn != nil { if conn != nil {
@ -119,28 +120,30 @@ func (c *clientImpl) connect() error {
return coreErrs.AuthError{StatusCode: resp.StatusCode} return coreErrs.AuthError{StatusCode: resp.StatusCode}
} }
// Auth OK // Auth OK
udpEnabled, serverRx := protocol.AuthResponseDataFromHeader(resp.Header) authResp := protocol.AuthResponseFromHeader(resp.Header)
if authResp.RxAuto {
// Server asks client to use bandwidth detection,
// ignore local bandwidth config and use BBR
congestion.UseBBR(conn)
} else {
// actualTx = min(serverRx, clientTx) // actualTx = min(serverRx, clientTx)
actualTx := serverRx actualTx := authResp.Rx
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
// Server doesn't have a limit, or our clientTx is smaller than serverRx
actualTx = c.config.BandwidthConfig.MaxTx actualTx = c.config.BandwidthConfig.MaxTx
} }
// Use Brutal CC if actualTx > 0, otherwise use BBR
if actualTx > 0 { if actualTx > 0 {
conn.SetCongestionControl(brutal.NewBrutalSender(actualTx)) congestion.UseBrutal(conn, actualTx)
} else { } else {
conn.SetCongestionControl(bbr.NewBBRSender( // We don't know our own bandwidth either, use BBR
bbr.DefaultClock{}, congestion.UseBBR(conn)
bbr.GetInitialPacketSize(conn.RemoteAddr()), }
bbr.InitialCongestionWindow*common.InitMaxDatagramSize,
bbr.DefaultBBRMaxCongestionWindow*common.InitMaxDatagramSize,
))
} }
_ = resp.Body.Close() _ = resp.Body.Close()
c.pktConn = pktConn c.pktConn = pktConn
c.conn = conn c.conn = conn
if udpEnabled { if authResp.UDPEnabled {
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
} }
return nil return nil

View File

@ -0,0 +1,21 @@
package congestion
import (
"github.com/apernet/hysteria/core/internal/congestion/bbr"
"github.com/apernet/hysteria/core/internal/congestion/brutal"
"github.com/apernet/hysteria/core/internal/congestion/common"
"github.com/apernet/quic-go"
)
func UseBBR(conn quic.Connection) {
conn.SetCongestionControl(bbr.NewBBRSender(
bbr.DefaultClock{},
bbr.GetInitialPacketSize(conn.RemoteAddr()),
bbr.InitialCongestionWindow*common.InitMaxDatagramSize,
bbr.DefaultBBRMaxCongestionWindow*common.InitMaxDatagramSize,
))
}
func UseBrutal(conn quic.Connection, tx uint64) {
conn.SetCongestionControl(brutal.NewBrutalSender(tx))
}

View File

@ -17,26 +17,52 @@ const (
StatusAuthOK = 233 StatusAuthOK = 233
) )
func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) { // AuthRequest is what client sends to server for authentication.
auth = h.Get(RequestHeaderAuth) type AuthRequest struct {
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) Auth string
return Rx uint64 // 0 = unknown, client asks server to use bandwidth detection
} }
func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) { // AuthResponse is what server sends to client when authentication is passed.
h.Set(RequestHeaderAuth, auth) type AuthResponse struct {
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10)) UDPEnabled bool
Rx uint64 // 0 = unlimited
RxAuto bool // true = server asks client to use bandwidth detection
}
func AuthRequestFromHeader(h http.Header) AuthRequest {
rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64)
return AuthRequest{
Auth: h.Get(RequestHeaderAuth),
Rx: rx,
}
}
func AuthRequestToHeader(h http.Header, req AuthRequest) {
h.Set(RequestHeaderAuth, req.Auth)
h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10))
h.Set(CommonHeaderPadding, authRequestPadding.String()) h.Set(CommonHeaderPadding, authRequestPadding.String())
} }
func AuthResponseDataFromHeader(h http.Header) (udp bool, rx uint64) { func AuthResponseFromHeader(h http.Header) AuthResponse {
udp, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) resp := AuthResponse{}
rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled))
return rxStr := h.Get(CommonHeaderCCRX)
if rxStr == "auto" {
// Special case for server requesting client to use bandwidth detection
resp.RxAuto = true
} else {
resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64)
}
return resp
} }
func AuthResponseDataToHeader(h http.Header, udp bool, rx uint64) { func AuthResponseToHeader(h http.Header, resp AuthResponse) {
h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(udp)) h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled))
h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10)) if resp.RxAuto {
h.Set(CommonHeaderCCRX, "auto")
} else {
h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10))
}
h.Set(CommonHeaderPadding, authResponsePadding.String()) h.Set(CommonHeaderPadding, authResponsePadding.String())
} }

View File

@ -24,6 +24,7 @@ type Config struct {
Conn net.PacketConn Conn net.PacketConn
Outbound Outbound Outbound Outbound
BandwidthConfig BandwidthConfig BandwidthConfig BandwidthConfig
IgnoreClientBandwidth bool
DisableUDP bool DisableUDP bool
UDPIdleTimeout time.Duration UDPIdleTimeout time.Duration
Authenticator Authenticator Authenticator Authenticator

View File

@ -9,9 +9,7 @@ import (
"github.com/apernet/quic-go" "github.com/apernet/quic-go"
"github.com/apernet/quic-go/http3" "github.com/apernet/quic-go/http3"
"github.com/apernet/hysteria/core/internal/congestion/bbr" "github.com/apernet/hysteria/core/internal/congestion"
"github.com/apernet/hysteria/core/internal/congestion/brutal"
"github.com/apernet/hysteria/core/internal/congestion/common"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils" "github.com/apernet/hysteria/core/internal/utils"
) )
@ -96,9 +94,9 @@ type h3sHandler struct {
conn quic.Connection conn quic.Connection
authenticated bool authenticated bool
authMutex sync.Mutex
authID string authID string
udpOnce sync.Once
udpSM *udpSessionManager // Only set after authentication udpSM *udpSessionManager // Only set after authentication
} }
@ -111,36 +109,49 @@ func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler {
func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath {
h.authMutex.Lock()
defer h.authMutex.Unlock()
if h.authenticated { if h.authenticated {
// Already authenticated // Already authenticated
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx) protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
UDPEnabled: !h.config.DisableUDP,
Rx: h.config.BandwidthConfig.MaxRx,
RxAuto: h.config.IgnoreClientBandwidth,
})
w.WriteHeader(protocol.StatusAuthOK) w.WriteHeader(protocol.StatusAuthOK)
return return
} }
auth, clientRx := protocol.AuthRequestDataFromHeader(r.Header) authReq := protocol.AuthRequestFromHeader(r.Header)
// actualTx = min(serverTx, clientRx) actualTx := authReq.Rx
actualTx := clientRx ok, id := h.config.Authenticator.Authenticate(h.conn.RemoteAddr(), authReq.Auth, actualTx)
if h.config.BandwidthConfig.MaxTx > 0 && actualTx > h.config.BandwidthConfig.MaxTx {
actualTx = h.config.BandwidthConfig.MaxTx
}
ok, id := h.config.Authenticator.Authenticate(h.conn.RemoteAddr(), auth, actualTx)
if ok { if ok {
// Set authenticated flag // Set authenticated flag
h.authenticated = true h.authenticated = true
h.authID = id h.authID = id
// Use Brutal CC if actualTx > 0, otherwise use BBR if h.config.IgnoreClientBandwidth {
if actualTx > 0 { // Ignore client bandwidth, always use BBR
h.conn.SetCongestionControl(brutal.NewBrutalSender(actualTx)) congestion.UseBBR(h.conn)
actualTx = 0
} else { } else {
h.conn.SetCongestionControl(bbr.NewBBRSender( // actualTx = min(serverTx, clientRx)
bbr.DefaultClock{}, if h.config.BandwidthConfig.MaxTx > 0 && actualTx > h.config.BandwidthConfig.MaxTx {
bbr.GetInitialPacketSize(h.conn.RemoteAddr()), // We have a maxTx limit and the client is asking for more than that,
bbr.InitialCongestionWindow*common.InitMaxDatagramSize, // return and use the limit instead
bbr.DefaultBBRMaxCongestionWindow*common.InitMaxDatagramSize, actualTx = h.config.BandwidthConfig.MaxTx
)) }
if actualTx > 0 {
congestion.UseBrutal(h.conn, actualTx)
} else {
// Client doesn't know its own bandwidth, use BBR
congestion.UseBBR(h.conn)
}
} }
// Auth OK, send response // Auth OK, send response
protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx) protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{
UDPEnabled: !h.config.DisableUDP,
Rx: h.config.BandwidthConfig.MaxRx,
RxAuto: h.config.IgnoreClientBandwidth,
})
w.WriteHeader(protocol.StatusAuthOK) w.WriteHeader(protocol.StatusAuthOK)
// Call event logger // Call event logger
if h.config.EventLogger != nil { if h.config.EventLogger != nil {
@ -150,14 +161,14 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// We use sync.Once to make sure that only one goroutine is started, // We use sync.Once to make sure that only one goroutine is started,
// as ServeHTTP may be called by multiple goroutines simultaneously // as ServeHTTP may be called by multiple goroutines simultaneously
if !h.config.DisableUDP { if !h.config.DisableUDP {
h.udpOnce.Do(func() { go func() {
sm := newUDPSessionManager( sm := newUDPSessionManager(
&udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound}, &udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound},
&udpEventLoggerImpl{h.conn, id, h.config.EventLogger}, &udpEventLoggerImpl{h.conn, id, h.config.EventLogger},
h.config.UDPIdleTimeout) h.config.UDPIdleTimeout)
h.udpSM = sm h.udpSM = sm
go sm.Run() go sm.Run()
}) }()
} }
} else { } else {
// Auth failed, pretend to be a normal HTTP server // Auth failed, pretend to be a normal HTTP server