diff --git a/cmd/client.go b/cmd/client.go index d0fba20..ba25fac 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -131,10 +131,11 @@ func client(config *clientConfig) { // Client var client *core.Client try := 0 + up, down, _ := config.Speed() for { try += 1 c, err := core.NewClient(config.Server, config.Protocol, auth, tlsConfig, quicConfig, - transport.DefaultClientTransport, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, + transport.DefaultClientTransport, up, down, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) }, obfuscator) diff --git a/cmd/config.go b/cmd/config.go index 7e14c3c..edb448c 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -3,13 +3,15 @@ package main import ( "errors" "fmt" - "github.com/sirupsen/logrus" "github.com/yosuke-furukawa/json5/encoding/json5" + "regexp" + "strconv" ) const ( - mbpsToBps = 125000 + mbpsToBps = 125000 + minSpeedBPS = 16384 DefaultStreamReceiveWindow = 15728640 // 15 MB/s DefaultConnectionReceiveWindow = 67108864 // 64 MB/s @@ -20,6 +22,8 @@ const ( DefaultMMDBFilename = "GeoLite2-Country.mmdb" ) +var rateStringRegexp = regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`) + type serverConfig struct { Listen string `json:"listen"` Protocol string `json:"protocol"` @@ -34,7 +38,9 @@ type serverConfig struct { CertFile string `json:"cert"` KeyFile string `json:"key"` // Optional below + Up string `json:"up"` UpMbps int `json:"up_mbps"` + Down string `json:"down"` DownMbps int `json:"down_mbps"` DisableUDP bool `json:"disable_udp"` ACL string `json:"acl"` @@ -59,6 +65,27 @@ type serverConfig struct { } `json:"socks5_outbound"` } +func (c *serverConfig) Speed() (uint64, uint64, error) { + var up, down uint64 + if len(c.Up) > 0 { + up = stringToBps(c.Up) + if up == 0 { + return 0, 0, errors.New("invalid speed format") + } + } else { + up = uint64(c.UpMbps) * mbpsToBps + } + if len(c.Down) > 0 { + down = stringToBps(c.Down) + if down == 0 { + return 0, 0, errors.New("invalid speed format") + } + } else { + down = uint64(c.DownMbps) * mbpsToBps + } + return up, down, nil +} + func (c *serverConfig) Check() error { if len(c.Listen) == 0 { return errors.New("no listen address") @@ -66,7 +93,7 @@ func (c *serverConfig) Check() error { if len(c.ACME.Domains) == 0 && (len(c.CertFile) == 0 || len(c.KeyFile) == 0) { return errors.New("ACME domain or TLS cert not provided") } - if c.UpMbps < 0 || c.DownMbps < 0 { + if up, down, err := c.Speed(); err != nil || (up != 0 && up < minSpeedBPS) || (down != 0 && down < minSpeedBPS) { return errors.New("invalid speed") } if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || @@ -105,7 +132,9 @@ func (r *Relay) Check() error { type clientConfig struct { Server string `json:"server"` Protocol string `json:"protocol"` + Up string `json:"up"` UpMbps int `json:"up_mbps"` + Down string `json:"down"` DownMbps int `json:"down_mbps"` Retry int `json:"retry"` RetryInterval int `json:"retry_interval"` @@ -162,6 +191,27 @@ type clientConfig struct { ResolvePreference string `json:"resolve_preference"` } +func (c *clientConfig) Speed() (uint64, uint64, error) { + var up, down uint64 + if len(c.Up) > 0 { + up = stringToBps(c.Up) + if up == 0 { + return 0, 0, errors.New("invalid speed format") + } + } else { + up = uint64(c.UpMbps) * mbpsToBps + } + if len(c.Down) > 0 { + down = stringToBps(c.Down) + if down == 0 { + return 0, 0, errors.New("invalid speed format") + } + } else { + down = uint64(c.DownMbps) * mbpsToBps + } + return up, down, nil +} + func (c *clientConfig) Check() error { if len(c.SOCKS5.Listen) == 0 && len(c.HTTP.Listen) == 0 && len(c.TUN.Name) == 0 && len(c.TCPRelay.Listen) == 0 && len(c.UDPRelay.Listen) == 0 && @@ -209,7 +259,7 @@ func (c *clientConfig) Check() error { if len(c.Server) == 0 { return errors.New("no server address") } - if c.UpMbps <= 0 || c.DownMbps <= 0 { + if up, down, err := c.Speed(); err != nil || up < minSpeedBPS || down < minSpeedBPS { return errors.New("invalid speed") } if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || @@ -228,3 +278,33 @@ func (c *clientConfig) Check() error { func (c *clientConfig) String() string { return fmt.Sprintf("%+v", *c) } + +func stringToBps(s string) uint64 { + if s == "" { + return 0 + } + m := rateStringRegexp.FindStringSubmatch(s) + if m == nil { + return 0 + } + var n uint64 + switch m[2] { + case "K", "k": + n = 1 << 10 + case "M": + n = 1 << 20 + case "G": + n = 1 << 30 + case "T": + n = 1 << 40 + default: + n = 1 + } + v, _ := strconv.ParseUint(m[1], 10, 64) + n = v * n + if m[3] == "b" { + // Bits, need to convert to bytes + n = n >> 3 + } + return n +} diff --git a/cmd/config_test.go b/cmd/config_test.go new file mode 100644 index 0000000..11006ba --- /dev/null +++ b/cmd/config_test.go @@ -0,0 +1,34 @@ +package main + +import "testing" + +func Test_stringToBps(t *testing.T) { + tests := []struct { + name string + s string + want uint64 + }{ + {name: "bps 1", s: "8 bps", want: 1}, + {name: "bps 2", s: "3 bps", want: 0}, + {name: "Bps", s: "9991Bps", want: 9991}, + {name: "KBps", s: "10 KBps", want: 10240}, + {name: "Kbps", s: "10 Kbps", want: 1280}, + {name: "MBps", s: "10 MBps", want: 10485760}, + {name: "Mbps", s: "10 Mbps", want: 1310720}, + {name: "GBps", s: "10 GBps", want: 10737418240}, + {name: "Gbps", s: "10 Gbps", want: 1342177280}, + {name: "TBps", s: "10 TBps", want: 10995116277760}, + {name: "Tbps", s: "10 Tbps", want: 1374389534720}, + {name: "invalid 1", s: "6699E Kbps", want: 0}, + {name: "invalid 2", s: "400 Bsp", want: 0}, + {name: "invalid 3", s: "9 GBbps", want: 0}, + {name: "invalid 4", s: "Mbps", want: 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := stringToBps(tt.s); got != tt.want { + t.Errorf("stringToBps() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/server.go b/cmd/server.go index 6803ece..2d36772 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -192,8 +192,9 @@ func server(config *serverConfig) { logrus.WithField("error", err).Fatal("Prometheus HTTP server error") }() } + up, down, _ := config.Speed() server, err := core.NewServer(config.Listen, config.Protocol, tlsConfig, quicConfig, transport.DefaultServerTransport, - uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, + up, down, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) }, config.DisableUDP, aclEngine, obfuscator, connectFunc, disconnectFunc,