feat: rework config parsing to use viper unmarshal

This commit is contained in:
tobyxdd 2023-06-30 13:16:01 -07:00
parent 8342827339
commit eb7e91e5ce
3 changed files with 341 additions and 345 deletions

View File

@ -6,6 +6,7 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
@ -22,154 +23,174 @@ var clientCmd = &cobra.Command{
Run: runClient, Run: runClient,
} }
type modeFunc func(*viper.Viper, client.Client) error
var modeMap = map[string]modeFunc{
"socks5": clientSOCKS5,
"http": clientHTTP,
}
func init() { func init() {
rootCmd.AddCommand(clientCmd) rootCmd.AddCommand(clientCmd)
} }
type clientConfig struct {
Server string `mapstructure:"server"`
Auth string `mapstructure:"auth"`
TLS struct {
SNI string `mapstructure:"sni"`
Insecure bool `mapstructure:"insecure"`
CA string `mapstructure:"ca"`
} `mapstructure:"tls"`
QUIC struct {
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
} `mapstructure:"quic"`
Bandwidth struct {
Up string `mapstructure:"up"`
Down string `mapstructure:"down"`
} `mapstructure:"bandwidth"`
FastOpen bool `mapstructure:"fastOpen"`
SOCKS5 *socks5Config `mapstructure:"socks5"`
HTTP *httpConfig `mapstructure:"http"`
}
type socks5Config struct {
Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
DisableUDP bool `mapstructure:"disableUDP"`
}
type httpConfig struct {
Listen string `mapstructure:"listen"`
Username string `mapstructure:"username"`
Password string `mapstructure:"password"`
Realm string `mapstructure:"realm"`
}
// Config validates the fields and returns a ready-to-use Hysteria client config
func (c *clientConfig) Config() (*client.Config, error) {
hyConfig := &client.Config{}
// ServerAddr
if c.Server == "" {
return nil, configError{Field: "server", Err: errors.New("server address is empty")}
}
host, hostPort := parseServerAddrString(c.Server)
addr, err := net.ResolveUDPAddr("udp", hostPort)
if err != nil {
return nil, configError{Field: "server", Err: err}
}
hyConfig.ServerAddr = addr
// Auth
hyConfig.Auth = c.Auth
// TLSConfig
if c.TLS.SNI == "" {
// Use server hostname as SNI
hyConfig.TLSConfig.ServerName = host
} else {
hyConfig.TLSConfig.ServerName = c.TLS.SNI
}
hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure
if c.TLS.CA != "" {
ca, err := os.ReadFile(c.TLS.CA)
if err != nil {
return nil, configError{Field: "tls.ca", Err: err}
}
cPool := x509.NewCertPool()
if !cPool.AppendCertsFromPEM(ca) {
return nil, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
}
hyConfig.TLSConfig.RootCAs = cPool
}
// QUICConfig
hyConfig.QUICConfig = client.QUICConfig{
InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
KeepAlivePeriod: c.QUIC.KeepAlivePeriod,
DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
}
// BandwidthConfig
if c.Bandwidth.Up == "" || c.Bandwidth.Down == "" {
return nil, configError{Field: "bandwidth", Err: errors.New("both up and down bandwidth must be set")}
}
hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
if err != nil {
return nil, configError{Field: "bandwidth.up", Err: err}
}
hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
if err != nil {
return nil, configError{Field: "bandwidth.down", Err: err}
}
// FastOpen
hyConfig.FastOpen = c.FastOpen
return hyConfig, nil
}
func runClient(cmd *cobra.Command, args []string) { func runClient(cmd *cobra.Command, args []string) {
logger.Info("client mode") logger.Info("client mode")
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read client config", zap.Error(err)) logger.Fatal("failed to read client config", zap.Error(err))
} }
config, err := viperToClientConfig() var config clientConfig
if err != nil { if err := viper.Unmarshal(&config); err != nil {
logger.Fatal("failed to parse client config", zap.Error(err)) logger.Fatal("failed to parse client config", zap.Error(err))
} }
hyConfig, err := config.Config()
if err != nil {
logger.Fatal("failed to validate client config", zap.Error(err))
}
c, err := client.NewClient(config) c, err := client.NewClient(hyConfig)
if err != nil { if err != nil {
logger.Fatal("failed to initialize client", zap.Error(err)) logger.Fatal("failed to initialize client", zap.Error(err))
} }
defer c.Close() defer c.Close()
// Modes
var wg sync.WaitGroup var wg sync.WaitGroup
hasMode := false hasMode := false
for mode, fn := range modeMap {
v := viper.Sub(mode) if config.SOCKS5 != nil {
if v != nil {
hasMode = true hasMode = true
wg.Add(1) wg.Add(1)
go func(mode string, fn modeFunc) { go func() {
defer wg.Done() defer wg.Done()
if err := fn(v, c); err != nil { if err := clientSOCKS5(*config.SOCKS5, c); err != nil {
logger.Fatal("failed to run mode", zap.String("mode", mode), zap.Error(err)) logger.Fatal("failed to run SOCKS5 server", zap.Error(err))
} }
}(mode, fn) }()
} }
if config.HTTP != nil {
hasMode = true
wg.Add(1)
go func() {
defer wg.Done()
if err := clientHTTP(*config.HTTP, c); err != nil {
logger.Fatal("failed to run HTTP proxy server", zap.Error(err))
} }
}()
}
if !hasMode { if !hasMode {
logger.Fatal("no mode specified") logger.Fatal("no mode specified")
} }
wg.Wait() wg.Wait()
} }
func viperToClientConfig() (*client.Config, error) { func clientSOCKS5(config socks5Config, c client.Client) error {
// Conn and address if config.Listen == "" {
addrStr := viper.GetString("server")
if addrStr == "" {
return nil, configError{Field: "server", Err: errors.New("server address is empty")}
}
host, hostPort := parseServerAddrString(addrStr)
addr, err := net.ResolveUDPAddr("udp", hostPort)
if err != nil {
return nil, configError{Field: "server", Err: err}
}
// TLS
tlsConfig, err := viperToClientTLSConfig(host)
if err != nil {
return nil, err
}
// QUIC
quicConfig := viperToClientQUICConfig()
// Bandwidth
bwConfig, err := viperToClientBandwidthConfig()
if err != nil {
return nil, err
}
return &client.Config{
ConnFactory: nil, // TODO
ServerAddr: addr,
Auth: viper.GetString("auth"),
TLSConfig: tlsConfig,
QUICConfig: quicConfig,
BandwidthConfig: bwConfig,
FastOpen: viper.GetBool("fastOpen"),
}, nil
}
func viperToClientTLSConfig(host string) (client.TLSConfig, error) {
config := client.TLSConfig{
ServerName: viper.GetString("tls.sni"),
InsecureSkipVerify: viper.GetBool("tls.insecure"),
}
if config.ServerName == "" {
// The user didn't specify a server name, fallback to the host part of the server address
config.ServerName = host
}
caPath := viper.GetString("tls.ca")
if caPath != "" {
ca, err := os.ReadFile(caPath)
if err != nil {
return client.TLSConfig{}, configError{Field: "tls.ca", Err: err}
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(ca) {
return client.TLSConfig{}, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
}
config.RootCAs = pool
}
return config, nil
}
func viperToClientQUICConfig() client.QUICConfig {
return client.QUICConfig{
InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"),
MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"),
InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"),
MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"),
MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"),
KeepAlivePeriod: viper.GetDuration("quic.keepAlivePeriod"),
DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"),
}
}
func viperToClientBandwidthConfig() (client.BandwidthConfig, error) {
bw := client.BandwidthConfig{}
upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
if upStr == "" || downStr == "" {
return client.BandwidthConfig{}, configError{Field: "bandwidth", Err: errors.New("bandwidth.up and bandwidth.down must be set")}
}
up, err := convBandwidth(upStr)
if err != nil {
return client.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err}
}
down, err := convBandwidth(downStr)
if err != nil {
return client.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err}
}
bw.MaxTx, bw.MaxRx = up, down
return bw, nil
}
func clientSOCKS5(v *viper.Viper, c client.Client) error {
listenAddr := v.GetString("listen")
if listenAddr == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")} return configError{Field: "listen", Err: errors.New("listen address is empty")}
} }
l, err := net.Listen("tcp", listenAddr) l, err := net.Listen("tcp", config.Listen)
if err != nil { if err != nil {
return configError{Field: "listen", Err: err} return configError{Field: "listen", Err: err}
} }
var authFunc func(username, password string) bool var authFunc func(username, password string) bool
username, password := v.GetString("username"), v.GetString("password") username, password := config.Username, config.Password
if username != "" && password != "" { if username != "" && password != "" {
authFunc = func(u, p string) bool { authFunc = func(u, p string) bool {
return u == username && p == password return u == username && p == password
@ -178,47 +199,46 @@ func clientSOCKS5(v *viper.Viper, c client.Client) error {
s := socks5.Server{ s := socks5.Server{
HyClient: c, HyClient: c,
AuthFunc: authFunc, AuthFunc: authFunc,
DisableUDP: viper.GetBool("disableUDP"), DisableUDP: config.DisableUDP,
EventLogger: &socks5Logger{}, EventLogger: &socks5Logger{},
} }
logger.Info("SOCKS5 server listening", zap.String("addr", listenAddr)) logger.Info("SOCKS5 server listening", zap.String("addr", config.Listen))
return s.Serve(l) return s.Serve(l)
} }
func clientHTTP(v *viper.Viper, c client.Client) error { func clientHTTP(config httpConfig, c client.Client) error {
listenAddr := v.GetString("listen") if config.Listen == "" {
if listenAddr == "" {
return configError{Field: "listen", Err: errors.New("listen address is empty")} return configError{Field: "listen", Err: errors.New("listen address is empty")}
} }
l, err := net.Listen("tcp", listenAddr) l, err := net.Listen("tcp", config.Listen)
if err != nil { if err != nil {
return configError{Field: "listen", Err: err} return configError{Field: "listen", Err: err}
} }
var authFunc func(username, password string) bool var authFunc func(username, password string) bool
username, password := v.GetString("username"), v.GetString("password") username, password := config.Username, config.Password
if username != "" && password != "" { if username != "" && password != "" {
authFunc = func(u, p string) bool { authFunc = func(u, p string) bool {
return u == username && p == password return u == username && p == password
} }
} }
realm := v.GetString("realm") if config.Realm == "" {
if realm == "" { config.Realm = "Hysteria"
realm = "Hysteria"
} }
h := http.Server{ h := http.Server{
HyClient: c, HyClient: c,
AuthFunc: authFunc, AuthFunc: authFunc,
AuthRealm: realm, AuthRealm: config.Realm,
EventLogger: &httpLogger{}, EventLogger: &httpLogger{},
} }
logger.Info("HTTP proxy server listening", zap.String("addr", listenAddr)) logger.Info("HTTP proxy server listening", zap.String("addr", config.Listen))
return h.Serve(l) return h.Serve(l)
} }
// parseServerAddrString parses server address string.
// Server address can be in either "host:port" or "host" format (in which case we assume port 443).
func parseServerAddrString(addrStr string) (host, hostPort string) { func parseServerAddrString(addrStr string) (host, hostPort string) {
h, _, err := net.SplitHostPort(addrStr) h, _, err := net.SplitHostPort(addrStr)
if err != nil { if err != nil {
// No port provided, use default HTTPS port
return addrStr, net.JoinHostPort(addrStr, "443") return addrStr, net.JoinHostPort(addrStr, "443")
} }
return h, addrStr return h, addrStr

View File

@ -26,19 +26,23 @@ func runPing(cmd *cobra.Command, args []string) {
logger.Info("ping mode") logger.Info("ping mode")
if len(args) != 1 { if len(args) != 1 {
logger.Fatal("no address specified") logger.Fatal("must specify one and only one address")
} }
addr := args[0] addr := args[0]
if err := viper.ReadInConfig(); err != nil { if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read client config", zap.Error(err)) logger.Fatal("failed to read client config", zap.Error(err))
} }
config, err := viperToClientConfig() var config clientConfig
if err != nil { if err := viper.Unmarshal(&config); err != nil {
logger.Fatal("failed to parse client config", zap.Error(err)) logger.Fatal("failed to parse client config", zap.Error(err))
} }
hyConfig, err := config.Config()
if err != nil {
logger.Fatal("failed to validate client config", zap.Error(err))
}
c, err := client.NewClient(config) c, err := client.NewClient(hyConfig)
if err != nil { if err != nil {
logger.Fatal("failed to initialize client", zap.Error(err)) logger.Fatal("failed to initialize client", zap.Error(err))
} }

View File

@ -9,6 +9,7 @@ import (
"net/http/httputil" "net/http/httputil"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/apernet/hysteria/core/server" "github.com/apernet/hysteria/core/server"
"github.com/apernet/hysteria/extras/auth" "github.com/apernet/hysteria/extras/auth"
@ -27,257 +28,202 @@ var serverCmd = &cobra.Command{
func init() { func init() {
rootCmd.AddCommand(serverCmd) rootCmd.AddCommand(serverCmd)
initServerConfigDefaults()
} }
func initServerConfigDefaults() { type serverConfig struct {
viper.SetDefault("listen", ":443") Listen string `mapstructure:"listen"`
TLS *serverConfigTLS `mapstructure:"tls"`
ACME *serverConfigACME `mapstructure:"acme"`
QUIC struct {
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
MaxIncomingStreams int64 `mapstructure:"maxIncomingStreams"`
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
} `mapstructure:"quic"`
Bandwidth struct {
Up string `mapstructure:"up"`
Down string `mapstructure:"down"`
} `mapstructure:"bandwidth"`
DisableUDP bool `mapstructure:"disableUDP"`
Auth struct {
Type string `mapstructure:"type"`
Password string `mapstructure:"password"`
} `mapstructure:"auth"`
Masquerade struct {
Type string `mapstructure:"type"`
File struct {
Dir string `mapstructure:"dir"`
} `mapstructure:"file"`
Proxy struct {
URL string `mapstructure:"url"`
RewriteHost bool `mapstructure:"rewriteHost"`
} `mapstructure:"proxy"`
} `mapstructure:"masquerade"`
} }
func runServer(cmd *cobra.Command, args []string) { type serverConfigTLS struct {
logger.Info("server mode") Cert string `mapstructure:"cert"`
Key string `mapstructure:"key"`
if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read server config", zap.Error(err))
}
config, err := viperToServerConfig()
if err != nil {
logger.Fatal("failed to parse server config", zap.Error(err))
} }
s, err := server.NewServer(config) type serverConfigACME struct {
if err != nil { Domains []string `mapstructure:"domains"`
logger.Fatal("failed to initialize server", zap.Error(err)) Email string `mapstructure:"email"`
} CA string `mapstructure:"ca"`
logger.Info("server up and running") DisableHTTP bool `mapstructure:"disableHTTP"`
DisableTLSALPN bool `mapstructure:"disableTLSALPN"`
if err := s.Serve(); err != nil { AltHTTPPort int `mapstructure:"altHTTPPort"`
logger.Fatal("failed to serve", zap.Error(err)) AltTLSALPNPort int `mapstructure:"altTLSALPNPort"`
} Dir string `mapstructure:"dir"`
} }
func viperToServerConfig() (*server.Config, error) { // Config validates the fields and returns a ready-to-use Hysteria server config
func (c *serverConfig) Config() (*server.Config, error) {
hyConfig := &server.Config{}
// Conn // Conn
conn, err := viperToServerConn() listenAddr := c.Listen
if err != nil { if listenAddr == "" {
return nil, err listenAddr = ":443"
} }
// TLS uAddr, err := net.ResolveUDPAddr("udp", listenAddr)
tlsConfig, err := viperToServerTLSConfig()
if err != nil {
return nil, err
}
// QUIC
quicConfig := viperToServerQUICConfig()
// Bandwidth
bwConfig, err := viperToServerBandwidthConfig()
if err != nil {
return nil, err
}
// Disable UDP
disableUDP := viper.GetBool("disableUDP")
// Authenticator
authenticator, err := viperToAuthenticator()
if err != nil {
return nil, err
}
// Masquerade
masqHandler, err := viperToMasqHandler()
if err != nil {
return nil, err
}
// Config
config := &server.Config{
TLSConfig: tlsConfig,
QUICConfig: quicConfig,
Conn: conn,
Outbound: nil, // TODO
BandwidthConfig: bwConfig,
DisableUDP: disableUDP,
Authenticator: authenticator,
EventLogger: &serverLogger{},
MasqHandler: masqHandler,
}
return config, nil
}
func viperToServerConn() (net.PacketConn, error) {
listen := viper.GetString("listen")
if listen == "" {
return nil, configError{Field: "listen", Err: errors.New("empty listen address")}
}
uAddr, err := net.ResolveUDPAddr("udp", listen)
if err != nil { if err != nil {
return nil, configError{Field: "listen", Err: err} return nil, configError{Field: "listen", Err: err}
} }
conn, err := net.ListenUDP("udp", uAddr) hyConfig.Conn, err = net.ListenUDP("udp", uAddr)
if err != nil { if err != nil {
return nil, configError{Field: "listen", Err: err} return nil, configError{Field: "listen", Err: err}
} }
return conn, nil // TLSConfig
if c.TLS == nil && c.ACME == nil {
return nil, configError{Field: "tls", Err: errors.New("must set either tls or acme")}
} }
if c.TLS != nil && c.ACME != nil {
func viperToServerTLSConfig() (server.TLSConfig, error) { return nil, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
vTLS, vACME := viper.Sub("tls"), viper.Sub("acme")
if vTLS == nil && vACME == nil {
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("must set either tls or acme")}
} }
if vTLS != nil && vACME != nil { if c.TLS != nil {
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} // Local TLS cert
if c.TLS.Cert == "" || c.TLS.Key == "" {
return nil, configError{Field: "tls", Err: errors.New("empty cert or key path")}
} }
if vTLS != nil { cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
return viperToServerTLSConfigLocal(vTLS) if err != nil {
return nil, configError{Field: "tls", Err: err}
}
hyConfig.TLSConfig.Certificates = []tls.Certificate{cert}
} else { } else {
return viperToServerTLSConfigACME(vACME) // ACME
} dataDir := c.ACME.Dir
}
func viperToServerTLSConfigLocal(v *viper.Viper) (server.TLSConfig, error) {
certPath, keyPath := v.GetString("cert"), v.GetString("key")
if certPath == "" || keyPath == "" {
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("empty cert or key path")}
}
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return server.TLSConfig{}, configError{Field: "tls", Err: err}
}
return server.TLSConfig{
Certificates: []tls.Certificate{cert},
}, nil
}
func viperToServerTLSConfigACME(v *viper.Viper) (server.TLSConfig, error) {
dataDir := v.GetString("dir")
if dataDir == "" { if dataDir == "" {
dataDir = "acme" dataDir = "acme"
} }
cmCfg := &certmagic.Config{
cfg := &certmagic.Config{
RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio, RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
KeySource: certmagic.DefaultKeyGenerator, KeySource: certmagic.DefaultKeyGenerator,
Storage: &certmagic.FileStorage{Path: dataDir}, Storage: &certmagic.FileStorage{Path: dataDir},
Logger: logger, Logger: logger,
} }
issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{ cmIssuer := certmagic.NewACMEIssuer(cmCfg, certmagic.ACMEIssuer{
Email: v.GetString("email"), Email: c.ACME.Email,
Agreed: true, Agreed: true,
DisableHTTPChallenge: v.GetBool("disableHTTP"), DisableHTTPChallenge: c.ACME.DisableHTTP,
DisableTLSALPNChallenge: v.GetBool("disableTLSALPN"), DisableTLSALPNChallenge: c.ACME.DisableTLSALPN,
AltHTTPPort: v.GetInt("altHTTPPort"), AltHTTPPort: c.ACME.AltHTTPPort,
AltTLSALPNPort: v.GetInt("altTLSALPNPort"), AltTLSALPNPort: c.ACME.AltTLSALPNPort,
Logger: logger, Logger: logger,
}) })
switch strings.ToLower(v.GetString("ca")) { switch strings.ToLower(c.ACME.CA) {
case "letsencrypt", "le", "": case "letsencrypt", "le", "":
// Default to Let's Encrypt // Default to Let's Encrypt
issuer.CA = certmagic.LetsEncryptProductionCA cmIssuer.CA = certmagic.LetsEncryptProductionCA
case "zerossl", "zero": case "zerossl", "zero":
issuer.CA = certmagic.ZeroSSLProductionCA cmIssuer.CA = certmagic.ZeroSSLProductionCA
default: default:
return server.TLSConfig{}, configError{Field: "acme.ca", Err: errors.New("unknown CA")} return nil, configError{Field: "acme.ca", Err: errors.New("unknown CA")}
} }
cfg.Issuers = []certmagic.Issuer{issuer} cmCfg.Issuers = []certmagic.Issuer{cmIssuer}
cmCache := certmagic.NewCache(certmagic.CacheOptions{
cache := certmagic.NewCache(certmagic.CacheOptions{
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) { GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
return cfg, nil return cmCfg, nil
}, },
Logger: logger, Logger: logger,
}) })
cfg = certmagic.New(cache, *cfg) cmCfg = certmagic.New(cmCache, *cmCfg)
domains := v.GetStringSlice("domains") if len(c.ACME.Domains) == 0 {
if len(domains) == 0 { return nil, configError{Field: "acme.domains", Err: errors.New("empty domains")}
return server.TLSConfig{}, configError{Field: "acme.domains", Err: errors.New("empty domains")}
} }
err := cfg.ManageSync(context.Background(), domains) err = cmCfg.ManageSync(context.Background(), c.ACME.Domains)
if err != nil { if err != nil {
return server.TLSConfig{}, configError{Field: "acme", Err: err} return nil, configError{Field: "acme.domains", Err: err}
} }
return server.TLSConfig{ hyConfig.TLSConfig.GetCertificate = cmCfg.GetCertificate
GetCertificate: cfg.GetCertificate,
}, nil
} }
// QUICConfig
func viperToServerQUICConfig() server.QUICConfig { hyConfig.QUICConfig = server.QUICConfig{
return server.QUICConfig{ InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"), MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"), InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"), MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"), MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"), MaxIncomingStreams: c.QUIC.MaxIncomingStreams,
MaxIncomingStreams: viper.GetInt64("quic.maxIncomingStreams"), DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"),
} }
} // BandwidthConfig
if c.Bandwidth.Up != "" {
func viperToServerBandwidthConfig() (server.BandwidthConfig, error) { hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
bw := server.BandwidthConfig{}
upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
if upStr != "" {
up, err := convBandwidth(upStr)
if err != nil { if err != nil {
return server.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err} return nil, configError{Field: "bandwidth.up", Err: err}
} }
bw.MaxTx = up
} }
if downStr != "" { if c.Bandwidth.Down != "" {
down, err := convBandwidth(downStr) hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
if err != nil { if err != nil {
return server.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err} return nil, configError{Field: "bandwidth.down", Err: err}
} }
bw.MaxRx = down
} }
return bw, nil // DisableUDP
} hyConfig.DisableUDP = c.DisableUDP
// Authenticator
func viperToAuthenticator() (server.Authenticator, error) { if c.Auth.Type == "" {
authType := viper.GetString("auth.type")
if authType == "" {
return nil, configError{Field: "auth.type", Err: errors.New("empty auth type")} return nil, configError{Field: "auth.type", Err: errors.New("empty auth type")}
} }
switch authType { switch strings.ToLower(c.Auth.Type) {
case "password": case "password":
pw := viper.GetString("auth.password") if c.Auth.Password == "" {
if pw == "" {
return nil, configError{Field: "auth.password", Err: errors.New("empty auth password")} return nil, configError{Field: "auth.password", Err: errors.New("empty auth password")}
} }
return &auth.PasswordAuthenticator{Password: pw}, nil hyConfig.Authenticator = &auth.PasswordAuthenticator{Password: c.Auth.Password}
default: default:
return nil, configError{Field: "auth.type", Err: errors.New("unsupported auth type")} return nil, configError{Field: "auth.type", Err: errors.New("unsupported auth type")}
} }
} // EventLogger
hyConfig.EventLogger = &serverLogger{}
func viperToMasqHandler() (http.Handler, error) { // MasqHandler
masqType := viper.GetString("masquerade.type") switch strings.ToLower(c.Masquerade.Type) {
if masqType == "" { case "", "404":
// Default to use the 404 handler hyConfig.MasqHandler = http.NotFoundHandler()
return http.NotFoundHandler(), nil
}
switch masqType {
case "404":
return http.NotFoundHandler(), nil
case "file": case "file":
dir := viper.GetString("masquerade.file.dir") if c.Masquerade.File.Dir == "" {
if dir == "" { return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty file directory")}
return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty directory")}
} }
return http.FileServer(http.Dir(dir)), nil hyConfig.MasqHandler = http.FileServer(http.Dir(c.Masquerade.File.Dir))
case "proxy": case "proxy":
urlStr := viper.GetString("masquerade.proxy.url") if c.Masquerade.Proxy.URL == "" {
if urlStr == "" { return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty proxy url")}
return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty url")}
} }
u, err := url.Parse(urlStr) u, err := url.Parse(c.Masquerade.Proxy.URL)
if err != nil { if err != nil {
return nil, configError{Field: "masquerade.proxy.url", Err: err} return nil, configError{Field: "masquerade.proxy.url", Err: err}
} }
proxy := &httputil.ReverseProxy{ hyConfig.MasqHandler = &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) { Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(u) r.SetURL(u)
// SetURL rewrites the Host header, // SetURL rewrites the Host header,
// but we don't want that if rewriteHost is false // but we don't want that if rewriteHost is false
if !viper.GetBool("masquerade.proxy.rewriteHost") { if !c.Masquerade.Proxy.RewriteHost {
r.Out.Host = r.In.Host r.Out.Host = r.In.Host
} }
}, },
@ -286,10 +232,36 @@ func viperToMasqHandler() (http.Handler, error) {
w.WriteHeader(http.StatusBadGateway) w.WriteHeader(http.StatusBadGateway)
}, },
} }
return proxy, nil
default: default:
return nil, configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")} return nil, configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")}
} }
return hyConfig, nil
}
func runServer(cmd *cobra.Command, args []string) {
logger.Info("server mode")
if err := viper.ReadInConfig(); err != nil {
logger.Fatal("failed to read server config", zap.Error(err))
}
var config serverConfig
if err := viper.Unmarshal(&config); err != nil {
logger.Fatal("failed to parse server config", zap.Error(err))
}
hyConfig, err := config.Config()
if err != nil {
logger.Fatal("failed to validate server config", zap.Error(err))
}
s, err := server.NewServer(hyConfig)
if err != nil {
logger.Fatal("failed to initialize server", zap.Error(err))
}
logger.Info("server up and running")
if err := s.Serve(); err != nil {
logger.Fatal("failed to serve", zap.Error(err))
}
} }
type serverLogger struct{} type serverLogger struct{}