From eb7e91e5ce5decf13a0581860debc815cada6319 Mon Sep 17 00:00:00 2001 From: tobyxdd Date: Fri, 30 Jun 2023 13:16:01 -0700 Subject: [PATCH] feat: rework config parsing to use viper unmarshal --- app/cmd/client.go | 276 +++++++++++++++++--------------- app/cmd/ping.go | 12 +- app/cmd/server.go | 398 +++++++++++++++++++++------------------------- 3 files changed, 341 insertions(+), 345 deletions(-) diff --git a/app/cmd/client.go b/app/cmd/client.go index f27c0f4..bf20d5f 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -6,6 +6,7 @@ import ( "net" "os" "sync" + "time" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -22,154 +23,174 @@ var clientCmd = &cobra.Command{ Run: runClient, } -type modeFunc func(*viper.Viper, client.Client) error - -var modeMap = map[string]modeFunc{ - "socks5": clientSOCKS5, - "http": clientHTTP, -} - func init() { rootCmd.AddCommand(clientCmd) } +type clientConfig struct { + Server string `mapstructure:"server"` + Auth string `mapstructure:"auth"` + TLS struct { + SNI string `mapstructure:"sni"` + Insecure bool `mapstructure:"insecure"` + CA string `mapstructure:"ca"` + } `mapstructure:"tls"` + QUIC struct { + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` + } `mapstructure:"quic"` + Bandwidth struct { + Up string `mapstructure:"up"` + Down string `mapstructure:"down"` + } `mapstructure:"bandwidth"` + FastOpen bool `mapstructure:"fastOpen"` + SOCKS5 *socks5Config `mapstructure:"socks5"` + HTTP *httpConfig `mapstructure:"http"` +} + +type socks5Config struct { + Listen string `mapstructure:"listen"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + DisableUDP bool `mapstructure:"disableUDP"` +} + +type httpConfig struct { + Listen string `mapstructure:"listen"` + Username string `mapstructure:"username"` + Password string `mapstructure:"password"` + Realm string `mapstructure:"realm"` +} + +// Config validates the fields and returns a ready-to-use Hysteria client config +func (c *clientConfig) Config() (*client.Config, error) { + hyConfig := &client.Config{} + // ServerAddr + if c.Server == "" { + return nil, configError{Field: "server", Err: errors.New("server address is empty")} + } + host, hostPort := parseServerAddrString(c.Server) + addr, err := net.ResolveUDPAddr("udp", hostPort) + if err != nil { + return nil, configError{Field: "server", Err: err} + } + hyConfig.ServerAddr = addr + // Auth + hyConfig.Auth = c.Auth + // TLSConfig + if c.TLS.SNI == "" { + // Use server hostname as SNI + hyConfig.TLSConfig.ServerName = host + } else { + hyConfig.TLSConfig.ServerName = c.TLS.SNI + } + hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure + if c.TLS.CA != "" { + ca, err := os.ReadFile(c.TLS.CA) + if err != nil { + return nil, configError{Field: "tls.ca", Err: err} + } + cPool := x509.NewCertPool() + if !cPool.AppendCertsFromPEM(ca) { + return nil, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")} + } + hyConfig.TLSConfig.RootCAs = cPool + } + // QUICConfig + hyConfig.QUICConfig = client.QUICConfig{ + InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow, + MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow, + MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow, + MaxIdleTimeout: c.QUIC.MaxIdleTimeout, + KeepAlivePeriod: c.QUIC.KeepAlivePeriod, + DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery, + } + // BandwidthConfig + if c.Bandwidth.Up == "" || c.Bandwidth.Down == "" { + return nil, configError{Field: "bandwidth", Err: errors.New("both up and down bandwidth must be set")} + } + hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up) + if err != nil { + return nil, configError{Field: "bandwidth.up", Err: err} + } + hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down) + if err != nil { + return nil, configError{Field: "bandwidth.down", Err: err} + } + // FastOpen + hyConfig.FastOpen = c.FastOpen + + return hyConfig, nil +} + func runClient(cmd *cobra.Command, args []string) { logger.Info("client mode") if err := viper.ReadInConfig(); err != nil { logger.Fatal("failed to read client config", zap.Error(err)) } - config, err := viperToClientConfig() - if err != nil { + var config clientConfig + if err := viper.Unmarshal(&config); err != nil { logger.Fatal("failed to parse client config", zap.Error(err)) } + hyConfig, err := config.Config() + if err != nil { + logger.Fatal("failed to validate client config", zap.Error(err)) + } - c, err := client.NewClient(config) + c, err := client.NewClient(hyConfig) if err != nil { logger.Fatal("failed to initialize client", zap.Error(err)) } defer c.Close() + // Modes var wg sync.WaitGroup hasMode := false - for mode, fn := range modeMap { - v := viper.Sub(mode) - if v != nil { - hasMode = true - wg.Add(1) - go func(mode string, fn modeFunc) { - defer wg.Done() - if err := fn(v, c); err != nil { - logger.Fatal("failed to run mode", zap.String("mode", mode), zap.Error(err)) - } - }(mode, fn) - } + + if config.SOCKS5 != nil { + hasMode = true + wg.Add(1) + go func() { + defer wg.Done() + if err := clientSOCKS5(*config.SOCKS5, c); err != nil { + logger.Fatal("failed to run SOCKS5 server", zap.Error(err)) + } + }() } + if config.HTTP != nil { + hasMode = true + wg.Add(1) + go func() { + defer wg.Done() + if err := clientHTTP(*config.HTTP, c); err != nil { + logger.Fatal("failed to run HTTP proxy server", zap.Error(err)) + } + }() + } + if !hasMode { logger.Fatal("no mode specified") } wg.Wait() } -func viperToClientConfig() (*client.Config, error) { - // Conn and address - addrStr := viper.GetString("server") - if addrStr == "" { - return nil, configError{Field: "server", Err: errors.New("server address is empty")} - } - host, hostPort := parseServerAddrString(addrStr) - addr, err := net.ResolveUDPAddr("udp", hostPort) - if err != nil { - return nil, configError{Field: "server", Err: err} - } - // TLS - tlsConfig, err := viperToClientTLSConfig(host) - if err != nil { - return nil, err - } - // QUIC - quicConfig := viperToClientQUICConfig() - // Bandwidth - bwConfig, err := viperToClientBandwidthConfig() - if err != nil { - return nil, err - } - return &client.Config{ - ConnFactory: nil, // TODO - ServerAddr: addr, - Auth: viper.GetString("auth"), - TLSConfig: tlsConfig, - QUICConfig: quicConfig, - BandwidthConfig: bwConfig, - FastOpen: viper.GetBool("fastOpen"), - }, nil -} - -func viperToClientTLSConfig(host string) (client.TLSConfig, error) { - config := client.TLSConfig{ - ServerName: viper.GetString("tls.sni"), - InsecureSkipVerify: viper.GetBool("tls.insecure"), - } - if config.ServerName == "" { - // The user didn't specify a server name, fallback to the host part of the server address - config.ServerName = host - } - caPath := viper.GetString("tls.ca") - if caPath != "" { - ca, err := os.ReadFile(caPath) - if err != nil { - return client.TLSConfig{}, configError{Field: "tls.ca", Err: err} - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(ca) { - return client.TLSConfig{}, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")} - } - config.RootCAs = pool - } - return config, nil -} - -func viperToClientQUICConfig() client.QUICConfig { - return client.QUICConfig{ - InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"), - MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"), - InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"), - MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"), - MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"), - KeepAlivePeriod: viper.GetDuration("quic.keepAlivePeriod"), - DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"), - } -} - -func viperToClientBandwidthConfig() (client.BandwidthConfig, error) { - bw := client.BandwidthConfig{} - upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down") - if upStr == "" || downStr == "" { - return client.BandwidthConfig{}, configError{Field: "bandwidth", Err: errors.New("bandwidth.up and bandwidth.down must be set")} - } - up, err := convBandwidth(upStr) - if err != nil { - return client.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err} - } - down, err := convBandwidth(downStr) - if err != nil { - return client.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err} - } - bw.MaxTx, bw.MaxRx = up, down - return bw, nil -} - -func clientSOCKS5(v *viper.Viper, c client.Client) error { - listenAddr := v.GetString("listen") - if listenAddr == "" { +func clientSOCKS5(config socks5Config, c client.Client) error { + if config.Listen == "" { return configError{Field: "listen", Err: errors.New("listen address is empty")} } - l, err := net.Listen("tcp", listenAddr) + l, err := net.Listen("tcp", config.Listen) if err != nil { return configError{Field: "listen", Err: err} } var authFunc func(username, password string) bool - username, password := v.GetString("username"), v.GetString("password") + username, password := config.Username, config.Password if username != "" && password != "" { authFunc = func(u, p string) bool { return u == username && p == password @@ -178,47 +199,46 @@ func clientSOCKS5(v *viper.Viper, c client.Client) error { s := socks5.Server{ HyClient: c, AuthFunc: authFunc, - DisableUDP: viper.GetBool("disableUDP"), + DisableUDP: config.DisableUDP, EventLogger: &socks5Logger{}, } - logger.Info("SOCKS5 server listening", zap.String("addr", listenAddr)) + logger.Info("SOCKS5 server listening", zap.String("addr", config.Listen)) return s.Serve(l) } -func clientHTTP(v *viper.Viper, c client.Client) error { - listenAddr := v.GetString("listen") - if listenAddr == "" { +func clientHTTP(config httpConfig, c client.Client) error { + if config.Listen == "" { return configError{Field: "listen", Err: errors.New("listen address is empty")} } - l, err := net.Listen("tcp", listenAddr) + l, err := net.Listen("tcp", config.Listen) if err != nil { return configError{Field: "listen", Err: err} } var authFunc func(username, password string) bool - username, password := v.GetString("username"), v.GetString("password") + username, password := config.Username, config.Password if username != "" && password != "" { authFunc = func(u, p string) bool { return u == username && p == password } } - realm := v.GetString("realm") - if realm == "" { - realm = "Hysteria" + if config.Realm == "" { + config.Realm = "Hysteria" } h := http.Server{ HyClient: c, AuthFunc: authFunc, - AuthRealm: realm, + AuthRealm: config.Realm, EventLogger: &httpLogger{}, } - logger.Info("HTTP proxy server listening", zap.String("addr", listenAddr)) + logger.Info("HTTP proxy server listening", zap.String("addr", config.Listen)) return h.Serve(l) } +// parseServerAddrString parses server address string. +// Server address can be in either "host:port" or "host" format (in which case we assume port 443). func parseServerAddrString(addrStr string) (host, hostPort string) { h, _, err := net.SplitHostPort(addrStr) if err != nil { - // No port provided, use default HTTPS port return addrStr, net.JoinHostPort(addrStr, "443") } return h, addrStr diff --git a/app/cmd/ping.go b/app/cmd/ping.go index 0719ee0..66df837 100644 --- a/app/cmd/ping.go +++ b/app/cmd/ping.go @@ -26,19 +26,23 @@ func runPing(cmd *cobra.Command, args []string) { logger.Info("ping mode") if len(args) != 1 { - logger.Fatal("no address specified") + logger.Fatal("must specify one and only one address") } addr := args[0] if err := viper.ReadInConfig(); err != nil { logger.Fatal("failed to read client config", zap.Error(err)) } - config, err := viperToClientConfig() - if err != nil { + var config clientConfig + if err := viper.Unmarshal(&config); err != nil { logger.Fatal("failed to parse client config", zap.Error(err)) } + hyConfig, err := config.Config() + if err != nil { + logger.Fatal("failed to validate client config", zap.Error(err)) + } - c, err := client.NewClient(config) + c, err := client.NewClient(hyConfig) if err != nil { logger.Fatal("failed to initialize client", zap.Error(err)) } diff --git a/app/cmd/server.go b/app/cmd/server.go index 3658557..81745b5 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -9,6 +9,7 @@ import ( "net/http/httputil" "net/url" "strings" + "time" "github.com/apernet/hysteria/core/server" "github.com/apernet/hysteria/extras/auth" @@ -27,257 +28,202 @@ var serverCmd = &cobra.Command{ func init() { rootCmd.AddCommand(serverCmd) - initServerConfigDefaults() } -func initServerConfigDefaults() { - viper.SetDefault("listen", ":443") +type serverConfig struct { + Listen string `mapstructure:"listen"` + TLS *serverConfigTLS `mapstructure:"tls"` + ACME *serverConfigACME `mapstructure:"acme"` + QUIC struct { + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + MaxIncomingStreams int64 `mapstructure:"maxIncomingStreams"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` + } `mapstructure:"quic"` + Bandwidth struct { + Up string `mapstructure:"up"` + Down string `mapstructure:"down"` + } `mapstructure:"bandwidth"` + DisableUDP bool `mapstructure:"disableUDP"` + Auth struct { + Type string `mapstructure:"type"` + Password string `mapstructure:"password"` + } `mapstructure:"auth"` + Masquerade struct { + Type string `mapstructure:"type"` + File struct { + Dir string `mapstructure:"dir"` + } `mapstructure:"file"` + Proxy struct { + URL string `mapstructure:"url"` + RewriteHost bool `mapstructure:"rewriteHost"` + } `mapstructure:"proxy"` + } `mapstructure:"masquerade"` } -func runServer(cmd *cobra.Command, args []string) { - logger.Info("server mode") - - if err := viper.ReadInConfig(); err != nil { - logger.Fatal("failed to read server config", zap.Error(err)) - } - config, err := viperToServerConfig() - if err != nil { - logger.Fatal("failed to parse server config", zap.Error(err)) - } - - s, err := server.NewServer(config) - if err != nil { - logger.Fatal("failed to initialize server", zap.Error(err)) - } - logger.Info("server up and running") - - if err := s.Serve(); err != nil { - logger.Fatal("failed to serve", zap.Error(err)) - } +type serverConfigTLS struct { + Cert string `mapstructure:"cert"` + Key string `mapstructure:"key"` } -func viperToServerConfig() (*server.Config, error) { +type serverConfigACME struct { + Domains []string `mapstructure:"domains"` + Email string `mapstructure:"email"` + CA string `mapstructure:"ca"` + DisableHTTP bool `mapstructure:"disableHTTP"` + DisableTLSALPN bool `mapstructure:"disableTLSALPN"` + AltHTTPPort int `mapstructure:"altHTTPPort"` + AltTLSALPNPort int `mapstructure:"altTLSALPNPort"` + Dir string `mapstructure:"dir"` +} + +// Config validates the fields and returns a ready-to-use Hysteria server config +func (c *serverConfig) Config() (*server.Config, error) { + hyConfig := &server.Config{} // Conn - conn, err := viperToServerConn() - if err != nil { - return nil, err + listenAddr := c.Listen + if listenAddr == "" { + listenAddr = ":443" } - // TLS - tlsConfig, err := viperToServerTLSConfig() - if err != nil { - return nil, err - } - // QUIC - quicConfig := viperToServerQUICConfig() - // Bandwidth - bwConfig, err := viperToServerBandwidthConfig() - if err != nil { - return nil, err - } - // Disable UDP - disableUDP := viper.GetBool("disableUDP") - // Authenticator - authenticator, err := viperToAuthenticator() - if err != nil { - return nil, err - } - // Masquerade - masqHandler, err := viperToMasqHandler() - if err != nil { - return nil, err - } - // Config - config := &server.Config{ - TLSConfig: tlsConfig, - QUICConfig: quicConfig, - Conn: conn, - Outbound: nil, // TODO - BandwidthConfig: bwConfig, - DisableUDP: disableUDP, - Authenticator: authenticator, - EventLogger: &serverLogger{}, - MasqHandler: masqHandler, - } - return config, nil -} - -func viperToServerConn() (net.PacketConn, error) { - listen := viper.GetString("listen") - if listen == "" { - return nil, configError{Field: "listen", Err: errors.New("empty listen address")} - } - uAddr, err := net.ResolveUDPAddr("udp", listen) + uAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { return nil, configError{Field: "listen", Err: err} } - conn, err := net.ListenUDP("udp", uAddr) + hyConfig.Conn, err = net.ListenUDP("udp", uAddr) if err != nil { return nil, configError{Field: "listen", Err: err} } - return conn, nil -} - -func viperToServerTLSConfig() (server.TLSConfig, error) { - vTLS, vACME := viper.Sub("tls"), viper.Sub("acme") - if vTLS == nil && vACME == nil { - return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("must set either tls or acme")} + // TLSConfig + if c.TLS == nil && c.ACME == nil { + return nil, configError{Field: "tls", Err: errors.New("must set either tls or acme")} } - if vTLS != nil && vACME != nil { - return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} + if c.TLS != nil && c.ACME != nil { + return nil, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} } - if vTLS != nil { - return viperToServerTLSConfigLocal(vTLS) + if c.TLS != nil { + // Local TLS cert + if c.TLS.Cert == "" || c.TLS.Key == "" { + return nil, configError{Field: "tls", Err: errors.New("empty cert or key path")} + } + cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key) + if err != nil { + return nil, configError{Field: "tls", Err: err} + } + hyConfig.TLSConfig.Certificates = []tls.Certificate{cert} } else { - return viperToServerTLSConfigACME(vACME) - } -} - -func viperToServerTLSConfigLocal(v *viper.Viper) (server.TLSConfig, error) { - certPath, keyPath := v.GetString("cert"), v.GetString("key") - if certPath == "" || keyPath == "" { - return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("empty cert or key path")} - } - cert, err := tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return server.TLSConfig{}, configError{Field: "tls", Err: err} - } - return server.TLSConfig{ - Certificates: []tls.Certificate{cert}, - }, nil -} - -func viperToServerTLSConfigACME(v *viper.Viper) (server.TLSConfig, error) { - dataDir := v.GetString("dir") - if dataDir == "" { - dataDir = "acme" - } - - cfg := &certmagic.Config{ - RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio, - KeySource: certmagic.DefaultKeyGenerator, - Storage: &certmagic.FileStorage{Path: dataDir}, - Logger: logger, - } - issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{ - Email: v.GetString("email"), - Agreed: true, - DisableHTTPChallenge: v.GetBool("disableHTTP"), - DisableTLSALPNChallenge: v.GetBool("disableTLSALPN"), - AltHTTPPort: v.GetInt("altHTTPPort"), - AltTLSALPNPort: v.GetInt("altTLSALPNPort"), - Logger: logger, - }) - switch strings.ToLower(v.GetString("ca")) { - case "letsencrypt", "le", "": - // Default to Let's Encrypt - issuer.CA = certmagic.LetsEncryptProductionCA - case "zerossl", "zero": - issuer.CA = certmagic.ZeroSSLProductionCA - default: - return server.TLSConfig{}, configError{Field: "acme.ca", Err: errors.New("unknown CA")} - } - cfg.Issuers = []certmagic.Issuer{issuer} - - cache := certmagic.NewCache(certmagic.CacheOptions{ - GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) { - return cfg, nil - }, - Logger: logger, - }) - cfg = certmagic.New(cache, *cfg) - - domains := v.GetStringSlice("domains") - if len(domains) == 0 { - return server.TLSConfig{}, configError{Field: "acme.domains", Err: errors.New("empty domains")} - } - err := cfg.ManageSync(context.Background(), domains) - if err != nil { - return server.TLSConfig{}, configError{Field: "acme", Err: err} - } - return server.TLSConfig{ - GetCertificate: cfg.GetCertificate, - }, nil -} - -func viperToServerQUICConfig() server.QUICConfig { - return server.QUICConfig{ - InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"), - MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"), - InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"), - MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"), - MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"), - MaxIncomingStreams: viper.GetInt64("quic.maxIncomingStreams"), - DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"), - } -} - -func viperToServerBandwidthConfig() (server.BandwidthConfig, error) { - bw := server.BandwidthConfig{} - upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down") - if upStr != "" { - up, err := convBandwidth(upStr) - if err != nil { - return server.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err} + // ACME + dataDir := c.ACME.Dir + if dataDir == "" { + dataDir = "acme" } - bw.MaxTx = up - } - if downStr != "" { - down, err := convBandwidth(downStr) - if err != nil { - return server.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err} + cmCfg := &certmagic.Config{ + RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio, + KeySource: certmagic.DefaultKeyGenerator, + Storage: &certmagic.FileStorage{Path: dataDir}, + Logger: logger, } - bw.MaxRx = down - } - return bw, nil -} + cmIssuer := certmagic.NewACMEIssuer(cmCfg, certmagic.ACMEIssuer{ + Email: c.ACME.Email, + Agreed: true, + DisableHTTPChallenge: c.ACME.DisableHTTP, + DisableTLSALPNChallenge: c.ACME.DisableTLSALPN, + AltHTTPPort: c.ACME.AltHTTPPort, + AltTLSALPNPort: c.ACME.AltTLSALPNPort, + Logger: logger, + }) + switch strings.ToLower(c.ACME.CA) { + case "letsencrypt", "le", "": + // Default to Let's Encrypt + cmIssuer.CA = certmagic.LetsEncryptProductionCA + case "zerossl", "zero": + cmIssuer.CA = certmagic.ZeroSSLProductionCA + default: + return nil, configError{Field: "acme.ca", Err: errors.New("unknown CA")} + } + cmCfg.Issuers = []certmagic.Issuer{cmIssuer} + cmCache := certmagic.NewCache(certmagic.CacheOptions{ + GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) { + return cmCfg, nil + }, + Logger: logger, + }) + cmCfg = certmagic.New(cmCache, *cmCfg) -func viperToAuthenticator() (server.Authenticator, error) { - authType := viper.GetString("auth.type") - if authType == "" { + if len(c.ACME.Domains) == 0 { + return nil, configError{Field: "acme.domains", Err: errors.New("empty domains")} + } + err = cmCfg.ManageSync(context.Background(), c.ACME.Domains) + if err != nil { + return nil, configError{Field: "acme.domains", Err: err} + } + hyConfig.TLSConfig.GetCertificate = cmCfg.GetCertificate + } + // QUICConfig + hyConfig.QUICConfig = server.QUICConfig{ + InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow, + MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow, + InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow, + MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow, + MaxIdleTimeout: c.QUIC.MaxIdleTimeout, + MaxIncomingStreams: c.QUIC.MaxIncomingStreams, + DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery, + } + // BandwidthConfig + if c.Bandwidth.Up != "" { + hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up) + if err != nil { + return nil, configError{Field: "bandwidth.up", Err: err} + } + } + if c.Bandwidth.Down != "" { + hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down) + if err != nil { + return nil, configError{Field: "bandwidth.down", Err: err} + } + } + // DisableUDP + hyConfig.DisableUDP = c.DisableUDP + // Authenticator + if c.Auth.Type == "" { return nil, configError{Field: "auth.type", Err: errors.New("empty auth type")} } - switch authType { + switch strings.ToLower(c.Auth.Type) { case "password": - pw := viper.GetString("auth.password") - if pw == "" { + if c.Auth.Password == "" { return nil, configError{Field: "auth.password", Err: errors.New("empty auth password")} } - return &auth.PasswordAuthenticator{Password: pw}, nil + hyConfig.Authenticator = &auth.PasswordAuthenticator{Password: c.Auth.Password} default: return nil, configError{Field: "auth.type", Err: errors.New("unsupported auth type")} } -} - -func viperToMasqHandler() (http.Handler, error) { - masqType := viper.GetString("masquerade.type") - if masqType == "" { - // Default to use the 404 handler - return http.NotFoundHandler(), nil - } - switch masqType { - case "404": - return http.NotFoundHandler(), nil + // EventLogger + hyConfig.EventLogger = &serverLogger{} + // MasqHandler + switch strings.ToLower(c.Masquerade.Type) { + case "", "404": + hyConfig.MasqHandler = http.NotFoundHandler() case "file": - dir := viper.GetString("masquerade.file.dir") - if dir == "" { - return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty directory")} + if c.Masquerade.File.Dir == "" { + return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty file directory")} } - return http.FileServer(http.Dir(dir)), nil + hyConfig.MasqHandler = http.FileServer(http.Dir(c.Masquerade.File.Dir)) case "proxy": - urlStr := viper.GetString("masquerade.proxy.url") - if urlStr == "" { - return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty url")} + if c.Masquerade.Proxy.URL == "" { + return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty proxy url")} } - u, err := url.Parse(urlStr) + u, err := url.Parse(c.Masquerade.Proxy.URL) if err != nil { return nil, configError{Field: "masquerade.proxy.url", Err: err} } - proxy := &httputil.ReverseProxy{ + hyConfig.MasqHandler = &httputil.ReverseProxy{ Rewrite: func(r *httputil.ProxyRequest) { r.SetURL(u) // SetURL rewrites the Host header, // but we don't want that if rewriteHost is false - if !viper.GetBool("masquerade.proxy.rewriteHost") { + if !c.Masquerade.Proxy.RewriteHost { r.Out.Host = r.In.Host } }, @@ -286,10 +232,36 @@ func viperToMasqHandler() (http.Handler, error) { w.WriteHeader(http.StatusBadGateway) }, } - return proxy, nil default: return nil, configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")} } + return hyConfig, nil +} + +func runServer(cmd *cobra.Command, args []string) { + logger.Info("server mode") + + if err := viper.ReadInConfig(); err != nil { + logger.Fatal("failed to read server config", zap.Error(err)) + } + var config serverConfig + if err := viper.Unmarshal(&config); err != nil { + logger.Fatal("failed to parse server config", zap.Error(err)) + } + hyConfig, err := config.Config() + if err != nil { + logger.Fatal("failed to validate server config", zap.Error(err)) + } + + s, err := server.NewServer(hyConfig) + if err != nil { + logger.Fatal("failed to initialize server", zap.Error(err)) + } + logger.Info("server up and running") + + if err := s.Serve(); err != nil { + logger.Fatal("failed to serve", zap.Error(err)) + } } type serverLogger struct{}