refactor: fill default values directly to configs

This commit is contained in:
Toby 2022-11-02 18:23:54 +00:00
parent 21b2830289
commit 2e84ca6ebe
3 changed files with 59 additions and 62 deletions

View File

@ -39,6 +39,7 @@ var clientPacketConnFuncFactoryMap = map[string]pktconns.ClientPacketConnFuncFac
func client(config *clientConfig) { func client(config *clientConfig) {
logrus.WithField("config", config.String()).Info("Client configuration loaded") logrus.WithField("config", config.String()).Info("Client configuration loaded")
config.Fill() // Fill default values
// Resolver // Resolver
if len(config.Resolver) > 0 { if len(config.Resolver) > 0 {
err := setResolver(config.Resolver) err := setResolver(config.Resolver)
@ -50,15 +51,11 @@ func client(config *clientConfig) {
} }
// TLS // TLS
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
NextProtos: []string{config.ALPN},
ServerName: config.ServerName, ServerName: config.ServerName,
InsecureSkipVerify: config.Insecure, InsecureSkipVerify: config.Insecure,
MinVersion: tls.VersionTLS13, MinVersion: tls.VersionTLS13,
} }
if config.ALPN != "" {
tlsConfig.NextProtos = []string{config.ALPN}
} else {
tlsConfig.NextProtos = []string{DefaultALPN}
}
// Load CA // Load CA
if len(config.CustomCA) > 0 { if len(config.CustomCA) > 0 {
bs, err := ioutil.ReadFile(config.CustomCA) bs, err := ioutil.ReadFile(config.CustomCA)
@ -83,24 +80,11 @@ func client(config *clientConfig) {
InitialConnectionReceiveWindow: config.ReceiveWindow, InitialConnectionReceiveWindow: config.ReceiveWindow,
MaxConnectionReceiveWindow: config.ReceiveWindow, MaxConnectionReceiveWindow: config.ReceiveWindow,
HandshakeIdleTimeout: time.Duration(config.HandshakeTimeout) * time.Second, HandshakeIdleTimeout: time.Duration(config.HandshakeTimeout) * time.Second,
MaxIdleTimeout: time.Duration(config.IdleTimeout) * time.Second,
KeepAlivePeriod: time.Duration(config.IdleTimeout) * time.Second * 2 / 5,
DisablePathMTUDiscovery: config.DisableMTUDiscovery, DisablePathMTUDiscovery: config.DisableMTUDiscovery,
EnableDatagrams: true, EnableDatagrams: true,
} }
if config.IdleTimeout == 0 {
quicConfig.MaxIdleTimeout = DefaultClientMaxIdleTimeout
quicConfig.KeepAlivePeriod = DefaultClientKeepAlivePeriod
} else {
quicConfig.MaxIdleTimeout = time.Duration(config.IdleTimeout) * time.Second
quicConfig.KeepAlivePeriod = quicConfig.MaxIdleTimeout * 2 / 5
}
if config.ReceiveWindowConn == 0 {
quicConfig.InitialStreamReceiveWindow = DefaultStreamReceiveWindow
quicConfig.MaxStreamReceiveWindow = DefaultStreamReceiveWindow
}
if config.ReceiveWindow == 0 {
quicConfig.InitialConnectionReceiveWindow = DefaultConnectionReceiveWindow
quicConfig.MaxConnectionReceiveWindow = DefaultConnectionReceiveWindow
}
if !quicConfig.DisablePathMTUDiscovery && pmtud.DisablePathMTUDiscovery { if !quicConfig.DisablePathMTUDiscovery && pmtud.DisablePathMTUDiscovery {
logrus.Info("Path MTU Discovery is not yet supported on this platform") logrus.Info("Path MTU Discovery is not yet supported on this platform")
} }
@ -135,11 +119,7 @@ func client(config *clientConfig) {
var err error var err error
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultClientTransport.ResolveIPAddr, aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultClientTransport.ResolveIPAddr,
func() (*geoip2.Reader, error) { func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB) return loadMMDBReader(config.MMDB)
} else {
return loadMMDBReader(DefaultMMDBFilename)
}
}) })
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"strconv" "strconv"
"time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/yosuke-furukawa/json5/encoding/json5" "github.com/yosuke-furukawa/json5/encoding/json5"
@ -15,17 +14,16 @@ const (
mbpsToBps = 125000 mbpsToBps = 125000
minSpeedBPS = 16384 minSpeedBPS = 16384
DefaultALPN = "hysteria"
DefaultStreamReceiveWindow = 15728640 // 15 MB/s DefaultStreamReceiveWindow = 15728640 // 15 MB/s
DefaultConnectionReceiveWindow = 67108864 // 64 MB/s DefaultConnectionReceiveWindow = 67108864 // 64 MB/s
DefaultMaxIncomingStreams = 1024 DefaultMaxIncomingStreams = 1024
DefaultALPN = "hysteria"
DefaultMMDBFilename = "GeoLite2-Country.mmdb" DefaultMMDBFilename = "GeoLite2-Country.mmdb"
ServerMaxIdleTimeout = 60 * time.Second ServerMaxIdleTimeoutSec = 60
DefaultClientMaxIdleTimeout = 20 * time.Second DefaultClientIdleTimeoutSec = 20
DefaultClientKeepAlivePeriod = 8 * time.Second
) )
var rateStringRegexp = regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`) var rateStringRegexp = regexp.MustCompile(`^(\d+)\s*([KMGT]?)([Bb])ps$`)
@ -98,10 +96,10 @@ func (c *serverConfig) Speed() (uint64, uint64, error) {
func (c *serverConfig) Check() error { func (c *serverConfig) Check() error {
if len(c.Listen) == 0 { if len(c.Listen) == 0 {
return errors.New("no listen address") return errors.New("missing listen address")
} }
if len(c.ACME.Domains) == 0 && (len(c.CertFile) == 0 || len(c.KeyFile) == 0) { if len(c.ACME.Domains) == 0 && (len(c.CertFile) == 0 || len(c.KeyFile) == 0) {
return errors.New("ACME domain or TLS cert not provided") return errors.New("need either ACME info or cert/key files")
} }
if up, down, err := c.Speed(); err != nil || (up != 0 && up < minSpeedBPS) || (down != 0 && down < minSpeedBPS) { if up, down, err := c.Speed(); err != nil || (up != 0 && up < minSpeedBPS) || (down != 0 && down < minSpeedBPS) {
return errors.New("invalid speed") return errors.New("invalid speed")
@ -116,6 +114,24 @@ func (c *serverConfig) Check() error {
return nil return nil
} }
func (c *serverConfig) Fill() {
if len(c.ALPN) == 0 {
c.ALPN = DefaultALPN
}
if c.ReceiveWindowConn == 0 {
c.ReceiveWindowConn = DefaultStreamReceiveWindow
}
if c.ReceiveWindowClient == 0 {
c.ReceiveWindowClient = DefaultConnectionReceiveWindow
}
if c.MaxConnClient == 0 {
c.MaxConnClient = DefaultMaxIncomingStreams
}
if len(c.MMDB) == 0 {
c.MMDB = DefaultMMDBFilename
}
}
func (c *serverConfig) String() string { func (c *serverConfig) String() string {
return fmt.Sprintf("%+v", *c) return fmt.Sprintf("%+v", *c)
} }
@ -128,10 +144,10 @@ type Relay struct {
func (r *Relay) Check() error { func (r *Relay) Check() error {
if len(r.Listen) == 0 { if len(r.Listen) == 0 {
return errors.New("no relay listen address") return errors.New("missing relay listen address")
} }
if len(r.Remote) == 0 { if len(r.Remote) == 0 {
return errors.New("no relay remote address") return errors.New("missing relay remote address")
} }
if r.Timeout != 0 && r.Timeout < 4 { if r.Timeout != 0 && r.Timeout < 4 {
return errors.New("invalid relay timeout") return errors.New("invalid relay timeout")
@ -252,10 +268,10 @@ func (c *clientConfig) Check() error {
return errors.New("invalid TUN timeout") return errors.New("invalid TUN timeout")
} }
if len(c.TCPRelay.Listen) > 0 && len(c.TCPRelay.Remote) == 0 { if len(c.TCPRelay.Listen) > 0 && len(c.TCPRelay.Remote) == 0 {
return errors.New("no TCP relay remote address") return errors.New("missing TCP relay remote address")
} }
if len(c.UDPRelay.Listen) > 0 && len(c.UDPRelay.Remote) == 0 { if len(c.UDPRelay.Listen) > 0 && len(c.UDPRelay.Remote) == 0 {
return errors.New("no UDP relay remote address") return errors.New("missing UDP relay remote address")
} }
if c.TCPRelay.Timeout != 0 && c.TCPRelay.Timeout < 4 { if c.TCPRelay.Timeout != 0 && c.TCPRelay.Timeout < 4 {
return errors.New("invalid TCP relay timeout") return errors.New("invalid TCP relay timeout")
@ -283,7 +299,7 @@ func (c *clientConfig) Check() error {
return errors.New("invalid TCP Redirect timeout") return errors.New("invalid TCP Redirect timeout")
} }
if len(c.Server) == 0 { if len(c.Server) == 0 {
return errors.New("no server address") return errors.New("missing server address")
} }
if up, down, err := c.Speed(); err != nil || up < minSpeedBPS || down < minSpeedBPS { if up, down, err := c.Speed(); err != nil || up < minSpeedBPS || down < minSpeedBPS {
return errors.New("invalid speed") return errors.New("invalid speed")
@ -301,6 +317,24 @@ func (c *clientConfig) Check() error {
return nil return nil
} }
func (c *clientConfig) Fill() {
if len(c.ALPN) == 0 {
c.ALPN = DefaultALPN
}
if c.ReceiveWindowConn == 0 {
c.ReceiveWindowConn = DefaultStreamReceiveWindow
}
if c.ReceiveWindow == 0 {
c.ReceiveWindow = DefaultConnectionReceiveWindow
}
if len(c.MMDB) == 0 {
c.MMDB = DefaultMMDBFilename
}
if c.IdleTimeout == 0 {
c.IdleTimeout = DefaultClientIdleTimeoutSec
}
}
func (c *clientConfig) String() string { func (c *clientConfig) String() string {
return fmt.Sprintf("%+v", *c) return fmt.Sprintf("%+v", *c)
} }

View File

@ -33,6 +33,7 @@ var serverPacketConnFuncFactoryMap = map[string]pktconns.ServerPacketConnFuncFac
func server(config *serverConfig) { func server(config *serverConfig) {
logrus.WithField("config", config.String()).Info("Server configuration loaded") logrus.WithField("config", config.String()).Info("Server configuration loaded")
config.Fill() // Fill default values
// Resolver // Resolver
if len(config.Resolver) > 0 { if len(config.Resolver) > 0 {
err := setResolver(config.Resolver) err := setResolver(config.Resolver)
@ -54,6 +55,7 @@ func server(config *serverConfig) {
"error": err, "error": err,
}).Fatal("Failed to get a certificate with ACME") }).Fatal("Failed to get a certificate with ACME")
} }
tc.NextProtos = []string{config.ALPN}
tc.MinVersion = tls.VersionTLS13 tc.MinVersion = tls.VersionTLS13
tlsConfig = tc tlsConfig = tc
} else { } else {
@ -68,14 +70,10 @@ func server(config *serverConfig) {
} }
tlsConfig = &tls.Config{ tlsConfig = &tls.Config{
GetCertificate: kpl.GetCertificateFunc(), GetCertificate: kpl.GetCertificateFunc(),
NextProtos: []string{config.ALPN},
MinVersion: tls.VersionTLS13, MinVersion: tls.VersionTLS13,
} }
} }
if config.ALPN != "" {
tlsConfig.NextProtos = []string{config.ALPN}
} else {
tlsConfig.NextProtos = []string{DefaultALPN}
}
// QUIC config // QUIC config
quicConfig := &quic.Config{ quicConfig := &quic.Config{
InitialStreamReceiveWindow: config.ReceiveWindowConn, InitialStreamReceiveWindow: config.ReceiveWindowConn,
@ -83,22 +81,11 @@ func server(config *serverConfig) {
InitialConnectionReceiveWindow: config.ReceiveWindowClient, InitialConnectionReceiveWindow: config.ReceiveWindowClient,
MaxConnectionReceiveWindow: config.ReceiveWindowClient, MaxConnectionReceiveWindow: config.ReceiveWindowClient,
MaxIncomingStreams: int64(config.MaxConnClient), MaxIncomingStreams: int64(config.MaxConnClient),
MaxIdleTimeout: ServerMaxIdleTimeout, MaxIdleTimeout: ServerMaxIdleTimeoutSec * time.Second,
KeepAlivePeriod: 0, // Keep alive should solely be client's responsibility KeepAlivePeriod: 0, // Keep alive should solely be client's responsibility
DisablePathMTUDiscovery: config.DisableMTUDiscovery, DisablePathMTUDiscovery: config.DisableMTUDiscovery,
EnableDatagrams: true, EnableDatagrams: true,
} }
if config.ReceiveWindowConn == 0 {
quicConfig.InitialStreamReceiveWindow = DefaultStreamReceiveWindow
quicConfig.MaxStreamReceiveWindow = DefaultStreamReceiveWindow
}
if config.ReceiveWindowClient == 0 {
quicConfig.InitialConnectionReceiveWindow = DefaultConnectionReceiveWindow
quicConfig.MaxConnectionReceiveWindow = DefaultConnectionReceiveWindow
}
if quicConfig.MaxIncomingStreams == 0 {
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
}
if !quicConfig.DisablePathMTUDiscovery && pmtud.DisablePathMTUDiscovery { if !quicConfig.DisablePathMTUDiscovery && pmtud.DisablePathMTUDiscovery {
logrus.Info("Path MTU Discovery is not yet supported on this platform") logrus.Info("Path MTU Discovery is not yet supported on this platform")
} }
@ -108,8 +95,8 @@ func server(config *serverConfig) {
switch authMode := config.Auth.Mode; authMode { switch authMode := config.Auth.Mode; authMode {
case "", "none": case "", "none":
if len(config.Obfs) == 0 { if len(config.Obfs) == 0 {
logrus.Warn("No authentication or obfuscation enabled. " + logrus.Warn("Neither authentication nor obfuscation is turned on. " +
"Your server could be accessed by anyone! Are you sure this is what you intended?") "Your server could be used by anyone! Are you sure this is what you want?")
} }
authFunc = func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { authFunc = func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
return true, "Welcome" return true, "Welcome"
@ -199,11 +186,7 @@ func server(config *serverConfig) {
return ipAddr, err return ipAddr, err
}, },
func() (*geoip2.Reader, error) { func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB) return loadMMDBReader(config.MMDB)
} else {
return loadMMDBReader(DefaultMMDBFilename)
}
}) })
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{