mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-06-09 13:59:54 +00:00
feat: rework config parsing to use viper unmarshal
This commit is contained in:
parent
8342827339
commit
eb7e91e5ce
@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
@ -22,154 +23,174 @@ var clientCmd = &cobra.Command{
|
||||
Run: runClient,
|
||||
}
|
||||
|
||||
type modeFunc func(*viper.Viper, client.Client) error
|
||||
|
||||
var modeMap = map[string]modeFunc{
|
||||
"socks5": clientSOCKS5,
|
||||
"http": clientHTTP,
|
||||
}
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(clientCmd)
|
||||
}
|
||||
|
||||
type clientConfig struct {
|
||||
Server string `mapstructure:"server"`
|
||||
Auth string `mapstructure:"auth"`
|
||||
TLS struct {
|
||||
SNI string `mapstructure:"sni"`
|
||||
Insecure bool `mapstructure:"insecure"`
|
||||
CA string `mapstructure:"ca"`
|
||||
} `mapstructure:"tls"`
|
||||
QUIC struct {
|
||||
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
|
||||
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
|
||||
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
|
||||
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
|
||||
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
|
||||
KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"`
|
||||
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
|
||||
} `mapstructure:"quic"`
|
||||
Bandwidth struct {
|
||||
Up string `mapstructure:"up"`
|
||||
Down string `mapstructure:"down"`
|
||||
} `mapstructure:"bandwidth"`
|
||||
FastOpen bool `mapstructure:"fastOpen"`
|
||||
SOCKS5 *socks5Config `mapstructure:"socks5"`
|
||||
HTTP *httpConfig `mapstructure:"http"`
|
||||
}
|
||||
|
||||
type socks5Config struct {
|
||||
Listen string `mapstructure:"listen"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
DisableUDP bool `mapstructure:"disableUDP"`
|
||||
}
|
||||
|
||||
type httpConfig struct {
|
||||
Listen string `mapstructure:"listen"`
|
||||
Username string `mapstructure:"username"`
|
||||
Password string `mapstructure:"password"`
|
||||
Realm string `mapstructure:"realm"`
|
||||
}
|
||||
|
||||
// Config validates the fields and returns a ready-to-use Hysteria client config
|
||||
func (c *clientConfig) Config() (*client.Config, error) {
|
||||
hyConfig := &client.Config{}
|
||||
// ServerAddr
|
||||
if c.Server == "" {
|
||||
return nil, configError{Field: "server", Err: errors.New("server address is empty")}
|
||||
}
|
||||
host, hostPort := parseServerAddrString(c.Server)
|
||||
addr, err := net.ResolveUDPAddr("udp", hostPort)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "server", Err: err}
|
||||
}
|
||||
hyConfig.ServerAddr = addr
|
||||
// Auth
|
||||
hyConfig.Auth = c.Auth
|
||||
// TLSConfig
|
||||
if c.TLS.SNI == "" {
|
||||
// Use server hostname as SNI
|
||||
hyConfig.TLSConfig.ServerName = host
|
||||
} else {
|
||||
hyConfig.TLSConfig.ServerName = c.TLS.SNI
|
||||
}
|
||||
hyConfig.TLSConfig.InsecureSkipVerify = c.TLS.Insecure
|
||||
if c.TLS.CA != "" {
|
||||
ca, err := os.ReadFile(c.TLS.CA)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "tls.ca", Err: err}
|
||||
}
|
||||
cPool := x509.NewCertPool()
|
||||
if !cPool.AppendCertsFromPEM(ca) {
|
||||
return nil, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
|
||||
}
|
||||
hyConfig.TLSConfig.RootCAs = cPool
|
||||
}
|
||||
// QUICConfig
|
||||
hyConfig.QUICConfig = client.QUICConfig{
|
||||
InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
|
||||
MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
|
||||
KeepAlivePeriod: c.QUIC.KeepAlivePeriod,
|
||||
DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
|
||||
}
|
||||
// BandwidthConfig
|
||||
if c.Bandwidth.Up == "" || c.Bandwidth.Down == "" {
|
||||
return nil, configError{Field: "bandwidth", Err: errors.New("both up and down bandwidth must be set")}
|
||||
}
|
||||
hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "bandwidth.up", Err: err}
|
||||
}
|
||||
hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "bandwidth.down", Err: err}
|
||||
}
|
||||
// FastOpen
|
||||
hyConfig.FastOpen = c.FastOpen
|
||||
|
||||
return hyConfig, nil
|
||||
}
|
||||
|
||||
func runClient(cmd *cobra.Command, args []string) {
|
||||
logger.Info("client mode")
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
logger.Fatal("failed to read client config", zap.Error(err))
|
||||
}
|
||||
config, err := viperToClientConfig()
|
||||
if err != nil {
|
||||
var config clientConfig
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
logger.Fatal("failed to parse client config", zap.Error(err))
|
||||
}
|
||||
hyConfig, err := config.Config()
|
||||
if err != nil {
|
||||
logger.Fatal("failed to validate client config", zap.Error(err))
|
||||
}
|
||||
|
||||
c, err := client.NewClient(config)
|
||||
c, err := client.NewClient(hyConfig)
|
||||
if err != nil {
|
||||
logger.Fatal("failed to initialize client", zap.Error(err))
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Modes
|
||||
var wg sync.WaitGroup
|
||||
hasMode := false
|
||||
for mode, fn := range modeMap {
|
||||
v := viper.Sub(mode)
|
||||
if v != nil {
|
||||
|
||||
if config.SOCKS5 != nil {
|
||||
hasMode = true
|
||||
wg.Add(1)
|
||||
go func(mode string, fn modeFunc) {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := fn(v, c); err != nil {
|
||||
logger.Fatal("failed to run mode", zap.String("mode", mode), zap.Error(err))
|
||||
if err := clientSOCKS5(*config.SOCKS5, c); err != nil {
|
||||
logger.Fatal("failed to run SOCKS5 server", zap.Error(err))
|
||||
}
|
||||
}(mode, fn)
|
||||
}()
|
||||
}
|
||||
if config.HTTP != nil {
|
||||
hasMode = true
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := clientHTTP(*config.HTTP, c); err != nil {
|
||||
logger.Fatal("failed to run HTTP proxy server", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if !hasMode {
|
||||
logger.Fatal("no mode specified")
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func viperToClientConfig() (*client.Config, error) {
|
||||
// Conn and address
|
||||
addrStr := viper.GetString("server")
|
||||
if addrStr == "" {
|
||||
return nil, configError{Field: "server", Err: errors.New("server address is empty")}
|
||||
}
|
||||
host, hostPort := parseServerAddrString(addrStr)
|
||||
addr, err := net.ResolveUDPAddr("udp", hostPort)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "server", Err: err}
|
||||
}
|
||||
// TLS
|
||||
tlsConfig, err := viperToClientTLSConfig(host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// QUIC
|
||||
quicConfig := viperToClientQUICConfig()
|
||||
// Bandwidth
|
||||
bwConfig, err := viperToClientBandwidthConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &client.Config{
|
||||
ConnFactory: nil, // TODO
|
||||
ServerAddr: addr,
|
||||
Auth: viper.GetString("auth"),
|
||||
TLSConfig: tlsConfig,
|
||||
QUICConfig: quicConfig,
|
||||
BandwidthConfig: bwConfig,
|
||||
FastOpen: viper.GetBool("fastOpen"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func viperToClientTLSConfig(host string) (client.TLSConfig, error) {
|
||||
config := client.TLSConfig{
|
||||
ServerName: viper.GetString("tls.sni"),
|
||||
InsecureSkipVerify: viper.GetBool("tls.insecure"),
|
||||
}
|
||||
if config.ServerName == "" {
|
||||
// The user didn't specify a server name, fallback to the host part of the server address
|
||||
config.ServerName = host
|
||||
}
|
||||
caPath := viper.GetString("tls.ca")
|
||||
if caPath != "" {
|
||||
ca, err := os.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return client.TLSConfig{}, configError{Field: "tls.ca", Err: err}
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(ca) {
|
||||
return client.TLSConfig{}, configError{Field: "tls.ca", Err: errors.New("failed to parse CA certificate")}
|
||||
}
|
||||
config.RootCAs = pool
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func viperToClientQUICConfig() client.QUICConfig {
|
||||
return client.QUICConfig{
|
||||
InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"),
|
||||
MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"),
|
||||
InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"),
|
||||
MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"),
|
||||
MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"),
|
||||
KeepAlivePeriod: viper.GetDuration("quic.keepAlivePeriod"),
|
||||
DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"),
|
||||
}
|
||||
}
|
||||
|
||||
func viperToClientBandwidthConfig() (client.BandwidthConfig, error) {
|
||||
bw := client.BandwidthConfig{}
|
||||
upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
|
||||
if upStr == "" || downStr == "" {
|
||||
return client.BandwidthConfig{}, configError{Field: "bandwidth", Err: errors.New("bandwidth.up and bandwidth.down must be set")}
|
||||
}
|
||||
up, err := convBandwidth(upStr)
|
||||
if err != nil {
|
||||
return client.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err}
|
||||
}
|
||||
down, err := convBandwidth(downStr)
|
||||
if err != nil {
|
||||
return client.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err}
|
||||
}
|
||||
bw.MaxTx, bw.MaxRx = up, down
|
||||
return bw, nil
|
||||
}
|
||||
|
||||
func clientSOCKS5(v *viper.Viper, c client.Client) error {
|
||||
listenAddr := v.GetString("listen")
|
||||
if listenAddr == "" {
|
||||
func clientSOCKS5(config socks5Config, c client.Client) error {
|
||||
if config.Listen == "" {
|
||||
return configError{Field: "listen", Err: errors.New("listen address is empty")}
|
||||
}
|
||||
l, err := net.Listen("tcp", listenAddr)
|
||||
l, err := net.Listen("tcp", config.Listen)
|
||||
if err != nil {
|
||||
return configError{Field: "listen", Err: err}
|
||||
}
|
||||
var authFunc func(username, password string) bool
|
||||
username, password := v.GetString("username"), v.GetString("password")
|
||||
username, password := config.Username, config.Password
|
||||
if username != "" && password != "" {
|
||||
authFunc = func(u, p string) bool {
|
||||
return u == username && p == password
|
||||
@ -178,47 +199,46 @@ func clientSOCKS5(v *viper.Viper, c client.Client) error {
|
||||
s := socks5.Server{
|
||||
HyClient: c,
|
||||
AuthFunc: authFunc,
|
||||
DisableUDP: viper.GetBool("disableUDP"),
|
||||
DisableUDP: config.DisableUDP,
|
||||
EventLogger: &socks5Logger{},
|
||||
}
|
||||
logger.Info("SOCKS5 server listening", zap.String("addr", listenAddr))
|
||||
logger.Info("SOCKS5 server listening", zap.String("addr", config.Listen))
|
||||
return s.Serve(l)
|
||||
}
|
||||
|
||||
func clientHTTP(v *viper.Viper, c client.Client) error {
|
||||
listenAddr := v.GetString("listen")
|
||||
if listenAddr == "" {
|
||||
func clientHTTP(config httpConfig, c client.Client) error {
|
||||
if config.Listen == "" {
|
||||
return configError{Field: "listen", Err: errors.New("listen address is empty")}
|
||||
}
|
||||
l, err := net.Listen("tcp", listenAddr)
|
||||
l, err := net.Listen("tcp", config.Listen)
|
||||
if err != nil {
|
||||
return configError{Field: "listen", Err: err}
|
||||
}
|
||||
var authFunc func(username, password string) bool
|
||||
username, password := v.GetString("username"), v.GetString("password")
|
||||
username, password := config.Username, config.Password
|
||||
if username != "" && password != "" {
|
||||
authFunc = func(u, p string) bool {
|
||||
return u == username && p == password
|
||||
}
|
||||
}
|
||||
realm := v.GetString("realm")
|
||||
if realm == "" {
|
||||
realm = "Hysteria"
|
||||
if config.Realm == "" {
|
||||
config.Realm = "Hysteria"
|
||||
}
|
||||
h := http.Server{
|
||||
HyClient: c,
|
||||
AuthFunc: authFunc,
|
||||
AuthRealm: realm,
|
||||
AuthRealm: config.Realm,
|
||||
EventLogger: &httpLogger{},
|
||||
}
|
||||
logger.Info("HTTP proxy server listening", zap.String("addr", listenAddr))
|
||||
logger.Info("HTTP proxy server listening", zap.String("addr", config.Listen))
|
||||
return h.Serve(l)
|
||||
}
|
||||
|
||||
// parseServerAddrString parses server address string.
|
||||
// Server address can be in either "host:port" or "host" format (in which case we assume port 443).
|
||||
func parseServerAddrString(addrStr string) (host, hostPort string) {
|
||||
h, _, err := net.SplitHostPort(addrStr)
|
||||
if err != nil {
|
||||
// No port provided, use default HTTPS port
|
||||
return addrStr, net.JoinHostPort(addrStr, "443")
|
||||
}
|
||||
return h, addrStr
|
||||
|
@ -26,19 +26,23 @@ func runPing(cmd *cobra.Command, args []string) {
|
||||
logger.Info("ping mode")
|
||||
|
||||
if len(args) != 1 {
|
||||
logger.Fatal("no address specified")
|
||||
logger.Fatal("must specify one and only one address")
|
||||
}
|
||||
addr := args[0]
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
logger.Fatal("failed to read client config", zap.Error(err))
|
||||
}
|
||||
config, err := viperToClientConfig()
|
||||
if err != nil {
|
||||
var config clientConfig
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
logger.Fatal("failed to parse client config", zap.Error(err))
|
||||
}
|
||||
hyConfig, err := config.Config()
|
||||
if err != nil {
|
||||
logger.Fatal("failed to validate client config", zap.Error(err))
|
||||
}
|
||||
|
||||
c, err := client.NewClient(config)
|
||||
c, err := client.NewClient(hyConfig)
|
||||
if err != nil {
|
||||
logger.Fatal("failed to initialize client", zap.Error(err))
|
||||
}
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/apernet/hysteria/core/server"
|
||||
"github.com/apernet/hysteria/extras/auth"
|
||||
@ -27,257 +28,202 @@ var serverCmd = &cobra.Command{
|
||||
|
||||
func init() {
|
||||
rootCmd.AddCommand(serverCmd)
|
||||
initServerConfigDefaults()
|
||||
}
|
||||
|
||||
func initServerConfigDefaults() {
|
||||
viper.SetDefault("listen", ":443")
|
||||
type serverConfig struct {
|
||||
Listen string `mapstructure:"listen"`
|
||||
TLS *serverConfigTLS `mapstructure:"tls"`
|
||||
ACME *serverConfigACME `mapstructure:"acme"`
|
||||
QUIC struct {
|
||||
InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"`
|
||||
MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"`
|
||||
InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"`
|
||||
MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"`
|
||||
MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"`
|
||||
MaxIncomingStreams int64 `mapstructure:"maxIncomingStreams"`
|
||||
DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"`
|
||||
} `mapstructure:"quic"`
|
||||
Bandwidth struct {
|
||||
Up string `mapstructure:"up"`
|
||||
Down string `mapstructure:"down"`
|
||||
} `mapstructure:"bandwidth"`
|
||||
DisableUDP bool `mapstructure:"disableUDP"`
|
||||
Auth struct {
|
||||
Type string `mapstructure:"type"`
|
||||
Password string `mapstructure:"password"`
|
||||
} `mapstructure:"auth"`
|
||||
Masquerade struct {
|
||||
Type string `mapstructure:"type"`
|
||||
File struct {
|
||||
Dir string `mapstructure:"dir"`
|
||||
} `mapstructure:"file"`
|
||||
Proxy struct {
|
||||
URL string `mapstructure:"url"`
|
||||
RewriteHost bool `mapstructure:"rewriteHost"`
|
||||
} `mapstructure:"proxy"`
|
||||
} `mapstructure:"masquerade"`
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) {
|
||||
logger.Info("server mode")
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
logger.Fatal("failed to read server config", zap.Error(err))
|
||||
}
|
||||
config, err := viperToServerConfig()
|
||||
if err != nil {
|
||||
logger.Fatal("failed to parse server config", zap.Error(err))
|
||||
type serverConfigTLS struct {
|
||||
Cert string `mapstructure:"cert"`
|
||||
Key string `mapstructure:"key"`
|
||||
}
|
||||
|
||||
s, err := server.NewServer(config)
|
||||
if err != nil {
|
||||
logger.Fatal("failed to initialize server", zap.Error(err))
|
||||
}
|
||||
logger.Info("server up and running")
|
||||
|
||||
if err := s.Serve(); err != nil {
|
||||
logger.Fatal("failed to serve", zap.Error(err))
|
||||
}
|
||||
type serverConfigACME struct {
|
||||
Domains []string `mapstructure:"domains"`
|
||||
Email string `mapstructure:"email"`
|
||||
CA string `mapstructure:"ca"`
|
||||
DisableHTTP bool `mapstructure:"disableHTTP"`
|
||||
DisableTLSALPN bool `mapstructure:"disableTLSALPN"`
|
||||
AltHTTPPort int `mapstructure:"altHTTPPort"`
|
||||
AltTLSALPNPort int `mapstructure:"altTLSALPNPort"`
|
||||
Dir string `mapstructure:"dir"`
|
||||
}
|
||||
|
||||
func viperToServerConfig() (*server.Config, error) {
|
||||
// Config validates the fields and returns a ready-to-use Hysteria server config
|
||||
func (c *serverConfig) Config() (*server.Config, error) {
|
||||
hyConfig := &server.Config{}
|
||||
// Conn
|
||||
conn, err := viperToServerConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
listenAddr := c.Listen
|
||||
if listenAddr == "" {
|
||||
listenAddr = ":443"
|
||||
}
|
||||
// TLS
|
||||
tlsConfig, err := viperToServerTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// QUIC
|
||||
quicConfig := viperToServerQUICConfig()
|
||||
// Bandwidth
|
||||
bwConfig, err := viperToServerBandwidthConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Disable UDP
|
||||
disableUDP := viper.GetBool("disableUDP")
|
||||
// Authenticator
|
||||
authenticator, err := viperToAuthenticator()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Masquerade
|
||||
masqHandler, err := viperToMasqHandler()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Config
|
||||
config := &server.Config{
|
||||
TLSConfig: tlsConfig,
|
||||
QUICConfig: quicConfig,
|
||||
Conn: conn,
|
||||
Outbound: nil, // TODO
|
||||
BandwidthConfig: bwConfig,
|
||||
DisableUDP: disableUDP,
|
||||
Authenticator: authenticator,
|
||||
EventLogger: &serverLogger{},
|
||||
MasqHandler: masqHandler,
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func viperToServerConn() (net.PacketConn, error) {
|
||||
listen := viper.GetString("listen")
|
||||
if listen == "" {
|
||||
return nil, configError{Field: "listen", Err: errors.New("empty listen address")}
|
||||
}
|
||||
uAddr, err := net.ResolveUDPAddr("udp", listen)
|
||||
uAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "listen", Err: err}
|
||||
}
|
||||
conn, err := net.ListenUDP("udp", uAddr)
|
||||
hyConfig.Conn, err = net.ListenUDP("udp", uAddr)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "listen", Err: err}
|
||||
}
|
||||
return conn, nil
|
||||
// TLSConfig
|
||||
if c.TLS == nil && c.ACME == nil {
|
||||
return nil, configError{Field: "tls", Err: errors.New("must set either tls or acme")}
|
||||
}
|
||||
|
||||
func viperToServerTLSConfig() (server.TLSConfig, error) {
|
||||
vTLS, vACME := viper.Sub("tls"), viper.Sub("acme")
|
||||
if vTLS == nil && vACME == nil {
|
||||
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("must set either tls or acme")}
|
||||
if c.TLS != nil && c.ACME != nil {
|
||||
return nil, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
||||
}
|
||||
if vTLS != nil && vACME != nil {
|
||||
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
||||
if c.TLS != nil {
|
||||
// Local TLS cert
|
||||
if c.TLS.Cert == "" || c.TLS.Key == "" {
|
||||
return nil, configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
||||
}
|
||||
if vTLS != nil {
|
||||
return viperToServerTLSConfigLocal(vTLS)
|
||||
cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "tls", Err: err}
|
||||
}
|
||||
hyConfig.TLSConfig.Certificates = []tls.Certificate{cert}
|
||||
} else {
|
||||
return viperToServerTLSConfigACME(vACME)
|
||||
}
|
||||
}
|
||||
|
||||
func viperToServerTLSConfigLocal(v *viper.Viper) (server.TLSConfig, error) {
|
||||
certPath, keyPath := v.GetString("cert"), v.GetString("key")
|
||||
if certPath == "" || keyPath == "" {
|
||||
return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return server.TLSConfig{}, configError{Field: "tls", Err: err}
|
||||
}
|
||||
return server.TLSConfig{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func viperToServerTLSConfigACME(v *viper.Viper) (server.TLSConfig, error) {
|
||||
dataDir := v.GetString("dir")
|
||||
// ACME
|
||||
dataDir := c.ACME.Dir
|
||||
if dataDir == "" {
|
||||
dataDir = "acme"
|
||||
}
|
||||
|
||||
cfg := &certmagic.Config{
|
||||
cmCfg := &certmagic.Config{
|
||||
RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
|
||||
KeySource: certmagic.DefaultKeyGenerator,
|
||||
Storage: &certmagic.FileStorage{Path: dataDir},
|
||||
Logger: logger,
|
||||
}
|
||||
issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
|
||||
Email: v.GetString("email"),
|
||||
cmIssuer := certmagic.NewACMEIssuer(cmCfg, certmagic.ACMEIssuer{
|
||||
Email: c.ACME.Email,
|
||||
Agreed: true,
|
||||
DisableHTTPChallenge: v.GetBool("disableHTTP"),
|
||||
DisableTLSALPNChallenge: v.GetBool("disableTLSALPN"),
|
||||
AltHTTPPort: v.GetInt("altHTTPPort"),
|
||||
AltTLSALPNPort: v.GetInt("altTLSALPNPort"),
|
||||
DisableHTTPChallenge: c.ACME.DisableHTTP,
|
||||
DisableTLSALPNChallenge: c.ACME.DisableTLSALPN,
|
||||
AltHTTPPort: c.ACME.AltHTTPPort,
|
||||
AltTLSALPNPort: c.ACME.AltTLSALPNPort,
|
||||
Logger: logger,
|
||||
})
|
||||
switch strings.ToLower(v.GetString("ca")) {
|
||||
switch strings.ToLower(c.ACME.CA) {
|
||||
case "letsencrypt", "le", "":
|
||||
// Default to Let's Encrypt
|
||||
issuer.CA = certmagic.LetsEncryptProductionCA
|
||||
cmIssuer.CA = certmagic.LetsEncryptProductionCA
|
||||
case "zerossl", "zero":
|
||||
issuer.CA = certmagic.ZeroSSLProductionCA
|
||||
cmIssuer.CA = certmagic.ZeroSSLProductionCA
|
||||
default:
|
||||
return server.TLSConfig{}, configError{Field: "acme.ca", Err: errors.New("unknown CA")}
|
||||
return nil, configError{Field: "acme.ca", Err: errors.New("unknown CA")}
|
||||
}
|
||||
cfg.Issuers = []certmagic.Issuer{issuer}
|
||||
|
||||
cache := certmagic.NewCache(certmagic.CacheOptions{
|
||||
cmCfg.Issuers = []certmagic.Issuer{cmIssuer}
|
||||
cmCache := certmagic.NewCache(certmagic.CacheOptions{
|
||||
GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
|
||||
return cfg, nil
|
||||
return cmCfg, nil
|
||||
},
|
||||
Logger: logger,
|
||||
})
|
||||
cfg = certmagic.New(cache, *cfg)
|
||||
cmCfg = certmagic.New(cmCache, *cmCfg)
|
||||
|
||||
domains := v.GetStringSlice("domains")
|
||||
if len(domains) == 0 {
|
||||
return server.TLSConfig{}, configError{Field: "acme.domains", Err: errors.New("empty domains")}
|
||||
if len(c.ACME.Domains) == 0 {
|
||||
return nil, configError{Field: "acme.domains", Err: errors.New("empty domains")}
|
||||
}
|
||||
err := cfg.ManageSync(context.Background(), domains)
|
||||
err = cmCfg.ManageSync(context.Background(), c.ACME.Domains)
|
||||
if err != nil {
|
||||
return server.TLSConfig{}, configError{Field: "acme", Err: err}
|
||||
return nil, configError{Field: "acme.domains", Err: err}
|
||||
}
|
||||
return server.TLSConfig{
|
||||
GetCertificate: cfg.GetCertificate,
|
||||
}, nil
|
||||
hyConfig.TLSConfig.GetCertificate = cmCfg.GetCertificate
|
||||
}
|
||||
|
||||
func viperToServerQUICConfig() server.QUICConfig {
|
||||
return server.QUICConfig{
|
||||
InitialStreamReceiveWindow: viper.GetUint64("quic.initStreamReceiveWindow"),
|
||||
MaxStreamReceiveWindow: viper.GetUint64("quic.maxStreamReceiveWindow"),
|
||||
InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"),
|
||||
MaxConnectionReceiveWindow: viper.GetUint64("quic.maxConnReceiveWindow"),
|
||||
MaxIdleTimeout: viper.GetDuration("quic.maxIdleTimeout"),
|
||||
MaxIncomingStreams: viper.GetInt64("quic.maxIncomingStreams"),
|
||||
DisablePathMTUDiscovery: viper.GetBool("quic.disablePathMTUDiscovery"),
|
||||
// QUICConfig
|
||||
hyConfig.QUICConfig = server.QUICConfig{
|
||||
InitialStreamReceiveWindow: c.QUIC.InitStreamReceiveWindow,
|
||||
MaxStreamReceiveWindow: c.QUIC.MaxStreamReceiveWindow,
|
||||
InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
|
||||
MaxConnectionReceiveWindow: c.QUIC.MaxConnectionReceiveWindow,
|
||||
MaxIdleTimeout: c.QUIC.MaxIdleTimeout,
|
||||
MaxIncomingStreams: c.QUIC.MaxIncomingStreams,
|
||||
DisablePathMTUDiscovery: c.QUIC.DisablePathMTUDiscovery,
|
||||
}
|
||||
}
|
||||
|
||||
func viperToServerBandwidthConfig() (server.BandwidthConfig, error) {
|
||||
bw := server.BandwidthConfig{}
|
||||
upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
|
||||
if upStr != "" {
|
||||
up, err := convBandwidth(upStr)
|
||||
// BandwidthConfig
|
||||
if c.Bandwidth.Up != "" {
|
||||
hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
|
||||
if err != nil {
|
||||
return server.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err}
|
||||
return nil, configError{Field: "bandwidth.up", Err: err}
|
||||
}
|
||||
bw.MaxTx = up
|
||||
}
|
||||
if downStr != "" {
|
||||
down, err := convBandwidth(downStr)
|
||||
if c.Bandwidth.Down != "" {
|
||||
hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
|
||||
if err != nil {
|
||||
return server.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err}
|
||||
return nil, configError{Field: "bandwidth.down", Err: err}
|
||||
}
|
||||
bw.MaxRx = down
|
||||
}
|
||||
return bw, nil
|
||||
}
|
||||
|
||||
func viperToAuthenticator() (server.Authenticator, error) {
|
||||
authType := viper.GetString("auth.type")
|
||||
if authType == "" {
|
||||
// DisableUDP
|
||||
hyConfig.DisableUDP = c.DisableUDP
|
||||
// Authenticator
|
||||
if c.Auth.Type == "" {
|
||||
return nil, configError{Field: "auth.type", Err: errors.New("empty auth type")}
|
||||
}
|
||||
switch authType {
|
||||
switch strings.ToLower(c.Auth.Type) {
|
||||
case "password":
|
||||
pw := viper.GetString("auth.password")
|
||||
if pw == "" {
|
||||
if c.Auth.Password == "" {
|
||||
return nil, configError{Field: "auth.password", Err: errors.New("empty auth password")}
|
||||
}
|
||||
return &auth.PasswordAuthenticator{Password: pw}, nil
|
||||
hyConfig.Authenticator = &auth.PasswordAuthenticator{Password: c.Auth.Password}
|
||||
default:
|
||||
return nil, configError{Field: "auth.type", Err: errors.New("unsupported auth type")}
|
||||
}
|
||||
}
|
||||
|
||||
func viperToMasqHandler() (http.Handler, error) {
|
||||
masqType := viper.GetString("masquerade.type")
|
||||
if masqType == "" {
|
||||
// Default to use the 404 handler
|
||||
return http.NotFoundHandler(), nil
|
||||
}
|
||||
switch masqType {
|
||||
case "404":
|
||||
return http.NotFoundHandler(), nil
|
||||
// EventLogger
|
||||
hyConfig.EventLogger = &serverLogger{}
|
||||
// MasqHandler
|
||||
switch strings.ToLower(c.Masquerade.Type) {
|
||||
case "", "404":
|
||||
hyConfig.MasqHandler = http.NotFoundHandler()
|
||||
case "file":
|
||||
dir := viper.GetString("masquerade.file.dir")
|
||||
if dir == "" {
|
||||
return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty directory")}
|
||||
if c.Masquerade.File.Dir == "" {
|
||||
return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty file directory")}
|
||||
}
|
||||
return http.FileServer(http.Dir(dir)), nil
|
||||
hyConfig.MasqHandler = http.FileServer(http.Dir(c.Masquerade.File.Dir))
|
||||
case "proxy":
|
||||
urlStr := viper.GetString("masquerade.proxy.url")
|
||||
if urlStr == "" {
|
||||
return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty url")}
|
||||
if c.Masquerade.Proxy.URL == "" {
|
||||
return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty proxy url")}
|
||||
}
|
||||
u, err := url.Parse(urlStr)
|
||||
u, err := url.Parse(c.Masquerade.Proxy.URL)
|
||||
if err != nil {
|
||||
return nil, configError{Field: "masquerade.proxy.url", Err: err}
|
||||
}
|
||||
proxy := &httputil.ReverseProxy{
|
||||
hyConfig.MasqHandler = &httputil.ReverseProxy{
|
||||
Rewrite: func(r *httputil.ProxyRequest) {
|
||||
r.SetURL(u)
|
||||
// SetURL rewrites the Host header,
|
||||
// but we don't want that if rewriteHost is false
|
||||
if !viper.GetBool("masquerade.proxy.rewriteHost") {
|
||||
if !c.Masquerade.Proxy.RewriteHost {
|
||||
r.Out.Host = r.In.Host
|
||||
}
|
||||
},
|
||||
@ -286,10 +232,36 @@ func viperToMasqHandler() (http.Handler, error) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
},
|
||||
}
|
||||
return proxy, nil
|
||||
default:
|
||||
return nil, configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")}
|
||||
}
|
||||
return hyConfig, nil
|
||||
}
|
||||
|
||||
func runServer(cmd *cobra.Command, args []string) {
|
||||
logger.Info("server mode")
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
logger.Fatal("failed to read server config", zap.Error(err))
|
||||
}
|
||||
var config serverConfig
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
logger.Fatal("failed to parse server config", zap.Error(err))
|
||||
}
|
||||
hyConfig, err := config.Config()
|
||||
if err != nil {
|
||||
logger.Fatal("failed to validate server config", zap.Error(err))
|
||||
}
|
||||
|
||||
s, err := server.NewServer(hyConfig)
|
||||
if err != nil {
|
||||
logger.Fatal("failed to initialize server", zap.Error(err))
|
||||
}
|
||||
logger.Info("server up and running")
|
||||
|
||||
if err := s.Serve(); err != nil {
|
||||
logger.Fatal("failed to serve", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
type serverLogger struct{}
|
||||
|
Loading…
x
Reference in New Issue
Block a user