diff --git a/app/cmd/server.go b/app/cmd/server.go index 44bc016..4ef7b0f 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -33,18 +33,19 @@ func init() { } type serverConfig struct { - Listen string `mapstructure:"listen"` - Obfs serverConfigObfs `mapstructure:"obfs"` - TLS *serverConfigTLS `mapstructure:"tls"` - ACME *serverConfigACME `mapstructure:"acme"` - QUIC serverConfigQUIC `mapstructure:"quic"` - Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"` - DisableUDP bool `mapstructure:"disableUDP"` - UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"` - Auth serverConfigAuth `mapstructure:"auth"` - Resolver serverConfigResolver `mapstructure:"resolver"` - Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"` - Masquerade serverConfigMasquerade `mapstructure:"masquerade"` + Listen string `mapstructure:"listen"` + Obfs serverConfigObfs `mapstructure:"obfs"` + TLS *serverConfigTLS `mapstructure:"tls"` + ACME *serverConfigACME `mapstructure:"acme"` + QUIC serverConfigQUIC `mapstructure:"quic"` + Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"` + IgnoreClientBandwidth bool `mapstructure:"ignoreClientBandwidth"` + DisableUDP bool `mapstructure:"disableUDP"` + UDPIdleTimeout time.Duration `mapstructure:"udpIdleTimeout"` + Auth serverConfigAuth `mapstructure:"auth"` + Resolver serverConfigResolver `mapstructure:"resolver"` + Outbounds []serverConfigOutboundEntry `mapstructure:"outbounds"` + Masquerade serverConfigMasquerade `mapstructure:"masquerade"` } type serverConfigObfsSalamander struct { @@ -360,6 +361,11 @@ func (c *serverConfig) fillBandwidthConfig(hyConfig *server.Config) error { return nil } +func (c *serverConfig) fillIgnoreClientBandwidth(hyConfig *server.Config) error { + hyConfig.IgnoreClientBandwidth = c.IgnoreClientBandwidth + return nil +} + func (c *serverConfig) fillDisableUDP(hyConfig *server.Config) error { hyConfig.DisableUDP = c.DisableUDP return nil @@ -445,6 +451,7 @@ func (c *serverConfig) Config() (*server.Config, error) { c.fillQUICConfig, c.fillOutboundConfig, c.fillBandwidthConfig, + c.fillIgnoreClientBandwidth, c.fillDisableUDP, c.fillUDPIdleTimeout, c.fillAuthenticator, diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index 114338b..a3a9365 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -55,8 +55,9 @@ func TestServerConfig(t *testing.T) { Up: "500 mbps", Down: "100 mbps", }, - DisableUDP: true, - UDPIdleTimeout: 120 * time.Second, + IgnoreClientBandwidth: true, + DisableUDP: true, + UDPIdleTimeout: 120 * time.Second, Auth: serverConfigAuth{ Type: "password", Password: "goofy_ahh_password", diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index 66b45dc..c0b0f37 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -34,6 +34,8 @@ bandwidth: up: 500 mbps down: 100 mbps +ignoreClientBandwidth: true + disableUDP: true udpIdleTimeout: 120s diff --git a/core/client/client.go b/core/client/client.go index c3750c4..35da850 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -9,9 +9,7 @@ import ( "time" coreErrs "github.com/apernet/hysteria/core/errors" - "github.com/apernet/hysteria/core/internal/congestion/bbr" - "github.com/apernet/hysteria/core/internal/congestion/brutal" - "github.com/apernet/hysteria/core/internal/congestion/common" + "github.com/apernet/hysteria/core/internal/congestion" "github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/utils" @@ -104,7 +102,10 @@ func (c *clientImpl) connect() error { }, Header: make(http.Header), } - protocol.AuthRequestDataToHeader(req.Header, c.config.Auth, c.config.BandwidthConfig.MaxRx) + protocol.AuthRequestToHeader(req.Header, protocol.AuthRequest{ + Auth: c.config.Auth, + Rx: c.config.BandwidthConfig.MaxRx, + }) resp, err := rt.RoundTrip(req) if err != nil { if conn != nil { @@ -119,28 +120,30 @@ func (c *clientImpl) connect() error { return coreErrs.AuthError{StatusCode: resp.StatusCode} } // Auth OK - udpEnabled, serverRx := protocol.AuthResponseDataFromHeader(resp.Header) - // actualTx = min(serverRx, clientTx) - actualTx := serverRx - if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { - actualTx = c.config.BandwidthConfig.MaxTx - } - // Use Brutal CC if actualTx > 0, otherwise use BBR - if actualTx > 0 { - conn.SetCongestionControl(brutal.NewBrutalSender(actualTx)) + authResp := protocol.AuthResponseFromHeader(resp.Header) + if authResp.RxAuto { + // Server asks client to use bandwidth detection, + // ignore local bandwidth config and use BBR + congestion.UseBBR(conn) } else { - conn.SetCongestionControl(bbr.NewBBRSender( - bbr.DefaultClock{}, - bbr.GetInitialPacketSize(conn.RemoteAddr()), - bbr.InitialCongestionWindow*common.InitMaxDatagramSize, - bbr.DefaultBBRMaxCongestionWindow*common.InitMaxDatagramSize, - )) + // actualTx = min(serverRx, clientTx) + actualTx := authResp.Rx + if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { + // Server doesn't have a limit, or our clientTx is smaller than serverRx + actualTx = c.config.BandwidthConfig.MaxTx + } + if actualTx > 0 { + congestion.UseBrutal(conn, actualTx) + } else { + // We don't know our own bandwidth either, use BBR + congestion.UseBBR(conn) + } } _ = resp.Body.Close() c.pktConn = pktConn c.conn = conn - if udpEnabled { + if authResp.UDPEnabled { c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) } return nil diff --git a/core/internal/congestion/utils.go b/core/internal/congestion/utils.go new file mode 100644 index 0000000..6a71da1 --- /dev/null +++ b/core/internal/congestion/utils.go @@ -0,0 +1,21 @@ +package congestion + +import ( + "github.com/apernet/hysteria/core/internal/congestion/bbr" + "github.com/apernet/hysteria/core/internal/congestion/brutal" + "github.com/apernet/hysteria/core/internal/congestion/common" + "github.com/apernet/quic-go" +) + +func UseBBR(conn quic.Connection) { + conn.SetCongestionControl(bbr.NewBBRSender( + bbr.DefaultClock{}, + bbr.GetInitialPacketSize(conn.RemoteAddr()), + bbr.InitialCongestionWindow*common.InitMaxDatagramSize, + bbr.DefaultBBRMaxCongestionWindow*common.InitMaxDatagramSize, + )) +} + +func UseBrutal(conn quic.Connection, tx uint64) { + conn.SetCongestionControl(brutal.NewBrutalSender(tx)) +} diff --git a/core/internal/protocol/http.go b/core/internal/protocol/http.go index f4ccf82..abcc1a4 100644 --- a/core/internal/protocol/http.go +++ b/core/internal/protocol/http.go @@ -17,26 +17,52 @@ const ( StatusAuthOK = 233 ) -func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) { - auth = h.Get(RequestHeaderAuth) - rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) - return +// AuthRequest is what client sends to server for authentication. +type AuthRequest struct { + Auth string + Rx uint64 // 0 = unknown, client asks server to use bandwidth detection } -func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) { - h.Set(RequestHeaderAuth, auth) - h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10)) +// AuthResponse is what server sends to client when authentication is passed. +type AuthResponse struct { + UDPEnabled bool + Rx uint64 // 0 = unlimited + RxAuto bool // true = server asks client to use bandwidth detection +} + +func AuthRequestFromHeader(h http.Header) AuthRequest { + rx, _ := strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) + return AuthRequest{ + Auth: h.Get(RequestHeaderAuth), + Rx: rx, + } +} + +func AuthRequestToHeader(h http.Header, req AuthRequest) { + h.Set(RequestHeaderAuth, req.Auth) + h.Set(CommonHeaderCCRX, strconv.FormatUint(req.Rx, 10)) h.Set(CommonHeaderPadding, authRequestPadding.String()) } -func AuthResponseDataFromHeader(h http.Header) (udp bool, rx uint64) { - udp, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) - rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) - return +func AuthResponseFromHeader(h http.Header) AuthResponse { + resp := AuthResponse{} + resp.UDPEnabled, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) + rxStr := h.Get(CommonHeaderCCRX) + if rxStr == "auto" { + // Special case for server requesting client to use bandwidth detection + resp.RxAuto = true + } else { + resp.Rx, _ = strconv.ParseUint(rxStr, 10, 64) + } + return resp } -func AuthResponseDataToHeader(h http.Header, udp bool, rx uint64) { - h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(udp)) - h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10)) +func AuthResponseToHeader(h http.Header, resp AuthResponse) { + h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(resp.UDPEnabled)) + if resp.RxAuto { + h.Set(CommonHeaderCCRX, "auto") + } else { + h.Set(CommonHeaderCCRX, strconv.FormatUint(resp.Rx, 10)) + } h.Set(CommonHeaderPadding, authResponsePadding.String()) } diff --git a/core/server/config.go b/core/server/config.go index 368da24..f647f0d 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -19,17 +19,18 @@ const ( ) type Config struct { - TLSConfig TLSConfig - QUICConfig QUICConfig - Conn net.PacketConn - Outbound Outbound - BandwidthConfig BandwidthConfig - DisableUDP bool - UDPIdleTimeout time.Duration - Authenticator Authenticator - EventLogger EventLogger - TrafficLogger TrafficLogger - MasqHandler http.Handler + TLSConfig TLSConfig + QUICConfig QUICConfig + Conn net.PacketConn + Outbound Outbound + BandwidthConfig BandwidthConfig + IgnoreClientBandwidth bool + DisableUDP bool + UDPIdleTimeout time.Duration + Authenticator Authenticator + EventLogger EventLogger + TrafficLogger TrafficLogger + MasqHandler http.Handler } // fill fills the fields that are not set by the user with default values when possible, diff --git a/core/server/server.go b/core/server/server.go index 5781851..1848ec0 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -9,9 +9,7 @@ import ( "github.com/apernet/quic-go" "github.com/apernet/quic-go/http3" - "github.com/apernet/hysteria/core/internal/congestion/bbr" - "github.com/apernet/hysteria/core/internal/congestion/brutal" - "github.com/apernet/hysteria/core/internal/congestion/common" + "github.com/apernet/hysteria/core/internal/congestion" "github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/utils" ) @@ -96,10 +94,10 @@ type h3sHandler struct { conn quic.Connection authenticated bool + authMutex sync.Mutex authID string - udpOnce sync.Once - udpSM *udpSessionManager // Only set after authentication + udpSM *udpSessionManager // Only set after authentication } func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler { @@ -111,36 +109,49 @@ func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler { func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { + h.authMutex.Lock() + defer h.authMutex.Unlock() if h.authenticated { // Already authenticated - protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx) + protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ + UDPEnabled: !h.config.DisableUDP, + Rx: h.config.BandwidthConfig.MaxRx, + RxAuto: h.config.IgnoreClientBandwidth, + }) w.WriteHeader(protocol.StatusAuthOK) return } - auth, clientRx := protocol.AuthRequestDataFromHeader(r.Header) - // actualTx = min(serverTx, clientRx) - actualTx := clientRx - if h.config.BandwidthConfig.MaxTx > 0 && actualTx > h.config.BandwidthConfig.MaxTx { - actualTx = h.config.BandwidthConfig.MaxTx - } - ok, id := h.config.Authenticator.Authenticate(h.conn.RemoteAddr(), auth, actualTx) + authReq := protocol.AuthRequestFromHeader(r.Header) + actualTx := authReq.Rx + ok, id := h.config.Authenticator.Authenticate(h.conn.RemoteAddr(), authReq.Auth, actualTx) if ok { // Set authenticated flag h.authenticated = true h.authID = id - // Use Brutal CC if actualTx > 0, otherwise use BBR - if actualTx > 0 { - h.conn.SetCongestionControl(brutal.NewBrutalSender(actualTx)) + if h.config.IgnoreClientBandwidth { + // Ignore client bandwidth, always use BBR + congestion.UseBBR(h.conn) + actualTx = 0 } else { - h.conn.SetCongestionControl(bbr.NewBBRSender( - bbr.DefaultClock{}, - bbr.GetInitialPacketSize(h.conn.RemoteAddr()), - bbr.InitialCongestionWindow*common.InitMaxDatagramSize, - bbr.DefaultBBRMaxCongestionWindow*common.InitMaxDatagramSize, - )) + // actualTx = min(serverTx, clientRx) + if h.config.BandwidthConfig.MaxTx > 0 && actualTx > h.config.BandwidthConfig.MaxTx { + // We have a maxTx limit and the client is asking for more than that, + // return and use the limit instead + actualTx = h.config.BandwidthConfig.MaxTx + } + if actualTx > 0 { + congestion.UseBrutal(h.conn, actualTx) + } else { + // Client doesn't know its own bandwidth, use BBR + congestion.UseBBR(h.conn) + } } // Auth OK, send response - protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx) + protocol.AuthResponseToHeader(w.Header(), protocol.AuthResponse{ + UDPEnabled: !h.config.DisableUDP, + Rx: h.config.BandwidthConfig.MaxRx, + RxAuto: h.config.IgnoreClientBandwidth, + }) w.WriteHeader(protocol.StatusAuthOK) // Call event logger if h.config.EventLogger != nil { @@ -150,14 +161,14 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // We use sync.Once to make sure that only one goroutine is started, // as ServeHTTP may be called by multiple goroutines simultaneously if !h.config.DisableUDP { - h.udpOnce.Do(func() { + go func() { sm := newUDPSessionManager( &udpIOImpl{h.conn, id, h.config.TrafficLogger, h.config.Outbound}, &udpEventLoggerImpl{h.conn, id, h.config.EventLogger}, h.config.UDPIdleTimeout) h.udpSM = sm go sm.Run() - }) + }() } } else { // Auth failed, pretend to be a normal HTTP server