From ca3de154bac2196e9747fb0ddab90b587340dbbf Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 21 Oct 2022 15:48:00 -0700 Subject: [PATCH] chore: remove congestion factory --- cmd/client.go | 8 ++------ cmd/server.go | 7 +------ pkg/congestion/brutal.go | 4 ++-- pkg/core/client.go | 27 ++++++++++++-------------- pkg/core/server.go | 42 ++++++++++++++++++++-------------------- 5 files changed, 38 insertions(+), 50 deletions(-) diff --git a/cmd/client.go b/cmd/client.go index d3a76e3..f36e8c5 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -17,7 +17,6 @@ import ( "github.com/yosuke-furukawa/json5/encoding/json5" "github.com/HyNetwork/hysteria/pkg/acl" - hyCongestion "github.com/HyNetwork/hysteria/pkg/congestion" "github.com/HyNetwork/hysteria/pkg/core" hyHTTP "github.com/HyNetwork/hysteria/pkg/http" "github.com/HyNetwork/hysteria/pkg/obfs" @@ -26,7 +25,6 @@ import ( "github.com/HyNetwork/hysteria/pkg/tproxy" "github.com/HyNetwork/hysteria/pkg/transport" "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/congestion" "github.com/sirupsen/logrus" ) @@ -145,10 +143,8 @@ func client(config *clientConfig) { for { try += 1 c, err := core.NewClient(config.Server, config.Protocol, auth, tlsConfig, quicConfig, - transport.DefaultClientTransport, up, down, - func(refBPS uint64) congestion.CongestionControl { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, obfuscator, func(err error) { + transport.DefaultClientTransport, up, down, obfuscator, + func(err error) { if config.QuitOnDisconnect { logrus.WithFields(logrus.Fields{ "addr": config.Server, diff --git a/cmd/server.go b/cmd/server.go index 6412751..47041e5 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -10,14 +10,12 @@ import ( "github.com/HyNetwork/hysteria/cmd/auth" "github.com/HyNetwork/hysteria/pkg/acl" - hyCongestion "github.com/HyNetwork/hysteria/pkg/congestion" "github.com/HyNetwork/hysteria/pkg/core" "github.com/HyNetwork/hysteria/pkg/obfs" "github.com/HyNetwork/hysteria/pkg/pmtud_fix" "github.com/HyNetwork/hysteria/pkg/sockopt" "github.com/HyNetwork/hysteria/pkg/transport" "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/congestion" "github.com/oschwald/geoip2-golang" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -224,10 +222,7 @@ func server(config *serverConfig) { } up, down, _ := config.Speed() server, err := core.NewServer(config.Listen, config.Protocol, tlsConfig, quicConfig, transport.DefaultServerTransport, - up, down, - func(refBPS uint64) congestion.CongestionControl { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, config.DisableUDP, aclEngine, obfuscator, connectFunc, disconnectFunc, + up, down, config.DisableUDP, aclEngine, obfuscator, connectFunc, disconnectFunc, tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc, promReg) if err != nil { logrus.WithField("error", err).Fatal("Failed to initialize server") diff --git a/pkg/congestion/brutal.go b/pkg/congestion/brutal.go index 02f12d4..740cb99 100644 --- a/pkg/congestion/brutal.go +++ b/pkg/congestion/brutal.go @@ -30,9 +30,9 @@ type pktInfo struct { LossCount uint64 } -func NewBrutalSender(bps congestion.ByteCount) *BrutalSender { +func NewBrutalSender(bps uint64) *BrutalSender { bs := &BrutalSender{ - bps: bps, + bps: congestion.ByteCount(bps), maxDatagramSize: initMaxDatagramSize, ackRate: 1, } diff --git a/pkg/core/client.go b/pkg/core/client.go index c2a1bf3..3a2fa86 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -12,27 +12,25 @@ import ( "sync" "time" + "github.com/HyNetwork/hysteria/pkg/congestion" + "github.com/HyNetwork/hysteria/pkg/obfs" "github.com/HyNetwork/hysteria/pkg/pmtud_fix" "github.com/HyNetwork/hysteria/pkg/transport" "github.com/HyNetwork/hysteria/pkg/utils" "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/congestion" "github.com/lunixbochs/struc" ) var ErrClosed = errors.New("closed") -type CongestionFactory func(refBPS uint64) congestion.CongestionControl - type Client struct { - transport *transport.ClientTransport - serverAddr string - protocol string - sendBPS, recvBPS uint64 - auth []byte - congestionFactory CongestionFactory - obfuscator obfs.Obfuscator + transport *transport.ClientTransport + serverAddr string + protocol string + sendBPS, recvBPS uint64 + auth []byte + obfuscator obfs.Obfuscator tlsConfig *tls.Config quicConfig *quic.Config @@ -49,8 +47,8 @@ type Client struct { } func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, - transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, - obfuscator obfs.Obfuscator, quicReconnectFunc func(err error), + transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, obfuscator obfs.Obfuscator, + quicReconnectFunc func(err error), ) (*Client, error) { quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery c := &Client{ @@ -60,7 +58,6 @@ func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.C sendBPS: sendBPS, recvBPS: recvBPS, auth: auth, - congestionFactory: congestionFactory, obfuscator: obfuscator, tlsConfig: tlsConfig, quicConfig: quicConfig, @@ -125,8 +122,8 @@ func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bo return false, "", err } // Set the congestion accordingly - if sh.OK && c.congestionFactory != nil { - qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS)) + if sh.OK { + qs.SetCongestionControl(congestion.NewBrutalSender(sh.Rate.RecvBPS)) } return sh.OK, sh.Message, nil } diff --git a/pkg/core/server.go b/pkg/core/server.go index b0c4156..043d21f 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -7,6 +7,8 @@ import ( "fmt" "net" + "github.com/HyNetwork/hysteria/pkg/congestion" + "github.com/HyNetwork/hysteria/pkg/acl" "github.com/HyNetwork/hysteria/pkg/obfs" "github.com/HyNetwork/hysteria/pkg/pmtud_fix" @@ -26,11 +28,10 @@ type ( ) type Server struct { - transport *transport.ServerTransport - sendBPS, recvBPS uint64 - congestionFactory CongestionFactory - disableUDP bool - aclEngine *acl.Engine + transport *transport.ServerTransport + sendBPS, recvBPS uint64 + disableUDP bool + aclEngine *acl.Engine connectFunc ConnectFunc disconnectFunc DisconnectFunc @@ -46,7 +47,7 @@ type Server struct { } func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig *quic.Config, transport *transport.ServerTransport, - sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, + sendBPS uint64, recvBPS uint64, disableUDP bool, aclEngine *acl.Engine, obfuscator obfs.Obfuscator, connectFunc ConnectFunc, disconnectFunc DisconnectFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry, @@ -57,19 +58,18 @@ func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig * return nil, err } s := &Server{ - listener: listener, - transport: transport, - sendBPS: sendBPS, - recvBPS: recvBPS, - congestionFactory: congestionFactory, - disableUDP: disableUDP, - aclEngine: aclEngine, - connectFunc: connectFunc, - disconnectFunc: disconnectFunc, - tcpRequestFunc: tcpRequestFunc, - tcpErrorFunc: tcpErrorFunc, - udpRequestFunc: udpRequestFunc, - udpErrorFunc: udpErrorFunc, + listener: listener, + transport: transport, + sendBPS: sendBPS, + recvBPS: recvBPS, + disableUDP: disableUDP, + aclEngine: aclEngine, + connectFunc: connectFunc, + disconnectFunc: disconnectFunc, + tcpRequestFunc: tcpRequestFunc, + tcpErrorFunc: tcpErrorFunc, + udpRequestFunc: udpRequestFunc, + udpErrorFunc: udpErrorFunc, } if promRegistry != nil { s.upCounterVec = prometheus.NewCounterVec(prometheus.CounterOpts{ @@ -172,8 +172,8 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([] return nil, false, false, err } // Set the congestion accordingly - if ok && s.congestionFactory != nil { - cs.SetCongestionControl(s.congestionFactory(serverSendBPS)) + if ok { + cs.SetCongestionControl(congestion.NewBrutalSender(serverSendBPS)) } return ch.Auth, ok, vb[0] == protocolVersionV2, nil }