feat: rework config parsing to use viper unmarshal

This commit is contained in:
tobyxdd
2023-06-30 13:16:01 -07:00
parent 8342827339
commit eb7e91e5ce
3 changed files with 341 additions and 345 deletions

View File

@@ -6,6 +6,7 @@ import (
"net"
"os"
"sync"
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
@@ -22,154 +23,174 @@ var clientCmd = &cobra.Command{
Run: runClient,
}
type modeFunc func(*viper.Viper, client.Client) error
var modeMap = map[string]modeFunc{
"socks5": clientSOCKS5,
"http": clientHTTP,
}
func init() {
rootCmd.AddCommand(clientCmd)
}
type clientConfig struct {
Server string `mapstructure:"server"`
Auth string `mapstructure:"auth"`
TLS struct {
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
CA string `mapstructure:"ca"`
} `mapstructure:"tls"`
QUIC struct {
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
} `mapstructure:"quic"`
Bandwidth struct {
Up string `mapstructure:"up"`
Down string `mapstructure:"down"`
} `mapstructure:"bandwidth"`
FastOpen bool `mapstructure:"fastOpen"`
SOCKS5 *socks5Config `mapstructure:"socks5"`
HTTP *httpConfig `mapstructure:"http"`
}
type socks5Config struct {
Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
DisableUDP bool `mapstructure:"disableUDP"`
}
type httpConfig struct {
Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Realm string `mapstructure:"realm"`
}
// Config validates the fields and returns a ready-to-use Hysteria client config
func (c *clientConfig) Config() (*client.Config, error) {
hyConfig := &client.Config{}
// ServerAddr
if c.Server == "" {
return nil, configError{Field: "server", Err: errors.New("server address is empty")}
}
host, hostPort := parseServerAddrString(c.Server)
addr, err := net.ResolveUDPAddr("udp", hostPort)
if err != nil {
return nil, configError{Field: "server", Err: err}
}
hyConfig.ServerAddr = addr
// Auth
hyConfig.Auth = c.Auth
// TLSConfig
if c.TLS.SNI == "" {
// Use server hostname as SNI
hyConfig.TLSConfig.ServerName = host
} else {
hyConfig.TLSConfig.ServerName = c.TLS.SNI
}
hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure
if c.TLS.CA != "" {
ca, err := os.ReadFile(c.TLS.CA)
if err != nil {
return nil, configError{Field: "tls.ca", Err: err}
}
cPool := x509.NewCertPool()
if !cPool.AppendCertsFromPEM(ca) {
return nil, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
}
hyConfig.TLSConfig.RootCAs = cPool
}
// QUICConfig
hyConfig.QUICConfig = client.QUICConfig{
InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
KeepAlivePeriod: c.QUIC.KeepAlivePeriod,
DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
}
// BandwidthConfig
if c.Bandwidth.Up == "" || c.Bandwidth.Down == "" {
return nil, configError{Field: "bandwidth", Err: errors.New("both up and down bandwidth must be set")}
}
hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
if err != nil {
return nil, configError{Field: "bandwidth.up", Err: err}
}
hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
if err != nil {
return nil, configError{Field: "bandwidth.down", Err: err}
}
// FastOpen
hyConfig.FastOpen = c.FastOpen
return hyConfig, nil
}
func runClient(cmd *cobra.Command, args []string) {
logger.Info("client mode")
if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read client config", zap.Error(err))
}
config, err := viperToClientConfig()
if err != nil {
var config clientConfig
if err := viper.Unmarshal(&config); err != nil {
logger.Fatal("failed to parse client config", zap.Error(err))
}
hyConfig, err := config.Config()
if err != nil {
logger.Fatal("failed to validate client config", zap.Error(err))
}
c, err := client.NewClient(config)
c, err := client.NewClient(hyConfig)
if err != nil {
logger.Fatal("failed to initialize client", zap.Error(err))
}
defer c.Close()
// Modes
var wg sync.WaitGroup
hasMode := false
for mode, fn := range modeMap {
v := viper.Sub(mode)
if v != nil {
hasMode = true
wg.Add(1)
go func(mode string, fn modeFunc) {
defer wg.Done()
if err := fn(v, c); err != nil {
logger.Fatal("failed to run mode", zap.String("mode", mode), zap.Error(err))
}
}(mode, fn)
}
if config.SOCKS5 != nil {
hasMode = true
wg.Add(1)
go func() {
defer wg.Done()
if err := clientSOCKS5(*config.SOCKS5, c); err != nil {
logger.Fatal("failed to run SOCKS5 server", zap.Error(err))
}
}()
}
if config.HTTP != nil {
hasMode = true
wg.Add(1)
go func() {
defer wg.Done()
if err := clientHTTP(*config.HTTP, c); err != nil {
logger.Fatal("failed to run HTTP proxy server", zap.Error(err))
}
}()
}
if !hasMode {
logger.Fatal("no mode specified")
}
wg.Wait()
}
func viperToClientConfig() (*client.Config, error) {
// Conn and address
addrStr := viper.GetString("server")
if addrStr == "" {
return nil, configError{Field: "server", Err: errors.New("server address is empty")}
}
host, hostPort := parseServerAddrString(addrStr)
addr, err := net.ResolveUDPAddr("udp", hostPort)
if err != nil {
return nil, configError{Field: "server", Err: err}
}
// TLS
tlsConfig, err := viperToClientTLSConfig(host)
if err != nil {
return nil, err
}
// QUIC
quicConfig := viperToClientQUICConfig()
// Bandwidth
bwConfig, err := viperToClientBandwidthConfig()
if err != nil {
return nil, err
}
return &client.Config{
ConnFactory: nil, // TODO
ServerAddr: addr,
Auth: viper.GetString("auth"),
TLSConfig: tlsConfig,
QUICConfig: quicConfig,
BandwidthConfig: bwConfig,
FastOpen: viper.GetBool("fastOpen"),
}, nil
}
func viperToClientTLSConfig(host string) (client.TLSConfig, error) {
config := client.TLSConfig{
ServerName: viper.GetString("tls.sni"),
InsecureSkipVerify: viper.GetBool("tls.insecure"),
}
if config.ServerName == "" {
// The user didn't specify a server name, fallback to the host part of the server address
config.ServerName = host
}
caPath := viper.GetString("tls.ca")
if caPath != "" {
ca, err := os.ReadFile(caPath)
if err != nil {
return client.TLSConfig{}, configError{Field: "tls.ca", Err: err}
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(ca) {
return client.TLSConfig{}, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
}
config.RootCAs = pool
}
return config, nil
}
func viperToClientQUICConfig() client.QUICConfig {
return client.QUICConfig{
InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"),
MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"),
InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"),
MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"),
MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"),
KeepAlivePeriod: viper.GetDuration("quic.keepAlivePeriod"),
DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"),
}
}
func viperToClientBandwidthConfig() (client.BandwidthConfig, error) {
bw := client.BandwidthConfig{}
upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
if upStr == "" || downStr == "" {
return client.BandwidthConfig{}, configError{Field: "bandwidth", Err: errors.New("bandwidth.up and bandwidth.down must be set")}
}
up, err := convBandwidth(upStr)
if err != nil {
return client.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err}
}
down, err := convBandwidth(downStr)
if err != nil {
return client.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err}
}
bw.MaxTx, bw.MaxRx = up, down
return bw, nil
}
func clientSOCKS5(v *viper.Viper, c client.Client) error {
listenAddr := v.GetString("listen")
if listenAddr == "" {
func clientSOCKS5(config socks5Config, c client.Client) error {
if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")}
}
l, err := net.Listen("tcp", listenAddr)
l, err := net.Listen("tcp", config.Listen)
if err != nil {
return configError{Field: "listen", Err: err}
}
var authFunc func(username, password string) bool
username, password := v.GetString("username"), v.GetString("password")
username, password := config.Username, config.Password
if username != "" && password != "" {
authFunc = func(u, p string) bool {
return u == username && p == password
@@ -178,47 +199,46 @@ func clientSOCKS5(v *viper.Viper, c client.Client) error {
s := socks5.Server{
HyClient: c,
AuthFunc: authFunc,
DisableUDP: viper.GetBool("disableUDP"),
DisableUDP: config.DisableUDP,
EventLogger: &socks5Logger{},
}
logger.Info("SOCKS5 server listening", zap.String("addr", listenAddr))
logger.Info("SOCKS5 server listening", zap.String("addr", config.Listen))
return s.Serve(l)
}
func clientHTTP(v *viper.Viper, c client.Client) error {
listenAddr := v.GetString("listen")
if listenAddr == "" {
func clientHTTP(config httpConfig, c client.Client) error {
if config.Listen == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")}
}
l, err := net.Listen("tcp", listenAddr)
l, err := net.Listen("tcp", config.Listen)
if err != nil {
return configError{Field: "listen", Err: err}
}
var authFunc func(username, password string) bool
username, password := v.GetString("username"), v.GetString("password")
username, password := config.Username, config.Password
if username != "" && password != "" {
authFunc = func(u, p string) bool {
return u == username && p == password
}
}
realm := v.GetString("realm")
if realm == "" {
realm = "Hysteria"
if config.Realm == "" {
config.Realm = "Hysteria"
}
h := http.Server{
HyClient: c,
AuthFunc: authFunc,
AuthRealm: realm,
AuthRealm: config.Realm,
EventLogger: &httpLogger{},
}
logger.Info("HTTP proxy server listening", zap.String("addr", listenAddr))
logger.Info("HTTP proxy server listening", zap.String("addr", config.Listen))
return h.Serve(l)
}
// parseServerAddrString parses server address string.
// Server address can be in either "host:port" or "host" format (in which case we assume port 443).
func parseServerAddrString(addrStr string) (host, hostPort string) {
h, _, err := net.SplitHostPort(addrStr)
if err != nil {
// No port provided, use default HTTPS port
return addrStr, net.JoinHostPort(addrStr, "443")
}
return h, addrStr