Prevent the congestion window from falling below the size of a single packet in scenarios with extremely low RTT, which previously led to transmission stalls.
package brutal

import (
	"fmt"
	"os"
	"strconv"
	"time"

	"github.com/apernet/hysteria/core/internal/congestion/common"

	"github.com/apernet/quic-go/congestion"
)

const (
	pktInfoSlotCount           = 5 // slot index is based on seconds, so this is basically how many seconds we sample
	minSampleCount             = 50
	minAckRate                 = 0.8
	congestionWindowMultiplier = 2

	debugEnv           = "HYSTERIA_BRUTAL_DEBUG"
	debugPrintInterval = 2
)

var _ congestion.CongestionControl = &BrutalSender{}
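
// BrutalSender is a fixed-rate congestion controller: rather than probing for
// available bandwidth, it always paces at the configured rate, dividing that
// rate by the observed ACK rate to compensate for packet loss.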
type BrutalSender struct {
	rttStats        congestion.RTTStatsProvider
	bps             congestion.ByteCount
	maxDatagramSize congestion.ByteCount
	pacer           *common.Pacer

	pktInfoSlots [pktInfoSlotCount]pktInfo
	ackRate      float64

	debug                 bool
	lastAckPrintTimestamp int64
}
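
// pktInfo holds the ACK and loss counts collected during a single one-second
// sample slot.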
type pktInfo struct {
	Timestamp int64
	AckCount  uint64
	LossCount uint64
}
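
// NewBrutalSender creates a BrutalSender that targets bps bytes per second.
// Setting the HYSTERIA_BRUTAL_DEBUG environment variable to a value that
// strconv.ParseBool accepts enables debug logging.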
func NewBrutalSender(bps uint64) *BrutalSender {
	debug, _ := strconv.ParseBool(os.Getenv(debugEnv))
	bs := &BrutalSender{
		bps:             congestion.ByteCount(bps),
		maxDatagramSize: congestion.InitialPacketSizeIPv4,
		ackRate:         1,
		debug:           debug,
	}
	bs.pacer = common.NewPacer(func() congestion.ByteCount {
		return congestion.ByteCount(float64(bs.bps) / bs.ackRate)
	})
	return bs
}

func (b *BrutalSender) SetRTTStatsProvider(rttStats congestion.RTTStatsProvider) {
	b.rttStats = rttStats
}

func (b *BrutalSender) TimeUntilSend(bytesInFlight congestion.ByteCount) time.Time {
	return b.pacer.TimeUntilSend()
}

func (b *BrutalSender) HasPacingBudget(now time.Time) bool {
	return b.pacer.Budget(now) >= b.maxDatagramSize
}

func (b *BrutalSender) CanSend(bytesInFlight congestion.ByteCount) bool {
	return bytesInFlight <= b.GetCongestionWindow()
}
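
// GetCongestionWindow sizes the window as rate x smoothed RTT, scaled by
// congestionWindowMultiplier and by the inverse of the ACK rate. The result
// is clamped to no less than one datagram so that an extremely low RTT cannot
// shrink the window below a single packet and stall transmission.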
func (b *BrutalSender) GetCongestionWindow() congestion.ByteCount {
	rtt := b.rttStats.SmoothedRTT()
	if rtt <= 0 {
		// No RTT sample yet; fall back to a fixed 10 KiB window.
		return 10240
	}
	cwnd := congestion.ByteCount(float64(b.bps) * rtt.Seconds() * congestionWindowMultiplier / b.ackRate)
	if cwnd < b.maxDatagramSize {
		cwnd = b.maxDatagramSize
	}
	return cwnd
}

func (b *BrutalSender) OnPacketSent(sentTime time.Time, bytesInFlight congestion.ByteCount,
	packetNumber congestion.PacketNumber, bytes congestion.ByteCount, isRetransmittable bool,
) {
	b.pacer.SentPacket(sentTime, bytes)
}

func (b *BrutalSender) OnPacketAcked(number congestion.PacketNumber, ackedBytes congestion.ByteCount,
	priorInFlight congestion.ByteCount, eventTime time.Time,
) {
	// Stub
}

func (b *BrutalSender) OnCongestionEvent(number congestion.PacketNumber, lostBytes congestion.ByteCount,
	priorInFlight congestion.ByteCount,
) {
	// Stub
}
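
// OnCongestionEventEx tallies ACKed and lost packets into a ring of
// per-second slots, indexed by Unix timestamp modulo pktInfoSlotCount, then
// recomputes the ACK rate over the sampling window.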
func (b *BrutalSender) OnCongestionEventEx(priorInFlight congestion.ByteCount, eventTime time.Time, ackedPackets []congestion.AckedPacketInfo, lostPackets []congestion.LostPacketInfo) {
	currentTimestamp := eventTime.Unix()
	slot := currentTimestamp % pktInfoSlotCount
	if b.pktInfoSlots[slot].Timestamp == currentTimestamp {
		b.pktInfoSlots[slot].LossCount += uint64(len(lostPackets))
		b.pktInfoSlots[slot].AckCount += uint64(len(ackedPackets))
	} else {
		// uninitialized slot or too old, reset
		b.pktInfoSlots[slot].Timestamp = currentTimestamp
		b.pktInfoSlots[slot].AckCount = uint64(len(ackedPackets))
		b.pktInfoSlots[slot].LossCount = uint64(len(lostPackets))
	}
	b.updateAckRate(currentTimestamp)
}

func (b *BrutalSender) SetMaxDatagramSize(size congestion.ByteCount) {
	b.maxDatagramSize = size
	b.pacer.SetMaxDatagramSize(size)
	if b.debug {
		b.debugPrint("SetMaxDatagramSize: %d", size)
	}
}
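
// updateAckRate recomputes the fraction of sampled packets that were ACKed.
// The rate stays at 1 until at least minSampleCount samples have accumulated,
// and is never allowed below minAckRate, which bounds how far the sender will
// overshoot the configured rate.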
func (b *BrutalSender) updateAckRate(currentTimestamp int64) {
	minTimestamp := currentTimestamp - pktInfoSlotCount
	var ackCount, lossCount uint64
	for _, info := range b.pktInfoSlots {
		if info.Timestamp < minTimestamp {
			continue
		}
		ackCount += info.AckCount
		lossCount += info.LossCount
	}
	if ackCount+lossCount < minSampleCount {
		b.ackRate = 1
		if b.canPrintAckRate(currentTimestamp) {
			b.lastAckPrintTimestamp = currentTimestamp
			b.debugPrint("Not enough samples (total=%d, ack=%d, loss=%d, rtt=%d)",
				ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
		}
		return
	}
	rate := float64(ackCount) / float64(ackCount+lossCount)
	if rate < minAckRate {
		b.ackRate = minAckRate
		if b.canPrintAckRate(currentTimestamp) {
			b.lastAckPrintTimestamp = currentTimestamp
			b.debugPrint("ACK rate too low: %.2f, clamped to %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
				rate, minAckRate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
		}
		return
	}
	b.ackRate = rate
	if b.canPrintAckRate(currentTimestamp) {
		b.lastAckPrintTimestamp = currentTimestamp
		b.debugPrint("ACK rate: %.2f (total=%d, ack=%d, loss=%d, rtt=%d)",
			rate, ackCount+lossCount, ackCount, lossCount, b.rttStats.SmoothedRTT().Milliseconds())
	}
}

func (b *BrutalSender) InSlowStart() bool {
	return false
}

func (b *BrutalSender) InRecovery() bool {
	return false
}

func (b *BrutalSender) MaybeExitSlowStart() {}

func (b *BrutalSender) OnRetransmissionTimeout(packetsRetransmitted bool) {}

func (b *BrutalSender) canPrintAckRate(currentTimestamp int64) bool {
	return b.debug && currentTimestamp-b.lastAckPrintTimestamp >= debugPrintInterval
}

func (b *BrutalSender) debugPrint(format string, a ...any) {
	fmt.Printf("[BrutalSender] [%s] %s\n",
		time.Now().Format("15:04:05"),
		fmt.Sprintf(format, a...))
}
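
// A minimal usage sketch. This assumes the apernet/quic-go fork exposes a
// SetCongestionControl method on its Connection type for plugging in a custom
// congestion.CongestionControl; the rate is in bytes per second:
//
//	sender := brutal.NewBrutalSender(12_500_000) // 12.5 MB/s, roughly 100 Mbps
//	conn.SetCongestionControl(sender)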