hysteria-dev/app/cmd/server.go
2023-07-09 16:37:18 -07:00

319 lines
10 KiB
Go

package cmd
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/caddyserver/certmagic"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"go.uber.org/zap"
"github.com/apernet/hysteria/core/server"
"github.com/apernet/hysteria/extras/auth"
"github.com/apernet/hysteria/extras/obfs"
)
// serverCmd is the "server" subcommand; it runs Hysteria in server mode via runServer.
var serverCmd = &cobra.Command{
	Run:   runServer,
	Short: "Server mode",
	Use:   "server",
}

// init attaches serverCmd to the root command so it is reachable from the CLI.
func init() {
	rootCmd.AddCommand(serverCmd)
}
// serverConfig is the server configuration as decoded by viper (via the
// mapstructure tags below). Its Config method validates it and converts it
// into a *server.Config.
type serverConfig struct {
	// Listen is the UDP listen address; Config falls back to ":443" when empty.
	Listen string `mapstructure:"listen"`
	// Obfs selects packet obfuscation: "" or "plain" (none), or "salamander"
	// (keyed by Salamander.Password). Type is matched case-insensitively.
	Obfs struct {
		Type       string `mapstructure:"type"`
		Salamander struct {
			Password string `mapstructure:"password"`
		} `mapstructure:"salamander"`
	} `mapstructure:"obfs"`
	// TLS (local cert/key) and ACME (automatic certificates) are mutually
	// exclusive: Config requires exactly one of them to be set.
	TLS  *serverConfigTLS  `mapstructure:"tls"`
	ACME *serverConfigACME `mapstructure:"acme"`
	// QUIC values are copied field-for-field into server.QUICConfig; zero
	// values are passed through as-is.
	QUIC struct {
		InitStreamReceiveWindow     uint64        `mapstructure:"initStreamReceiveWindow"`
		MaxStreamReceiveWindow      uint64        `mapstructure:"maxStreamReceiveWindow"`
		InitConnectionReceiveWindow uint64        `mapstructure:"initConnReceiveWindow"`
		MaxConnectionReceiveWindow  uint64        `mapstructure:"maxConnReceiveWindow"`
		MaxIdleTimeout              time.Duration `mapstructure:"maxIdleTimeout"`
		MaxIncomingStreams          int64         `mapstructure:"maxIncomingStreams"`
		DisablePathMTUDiscovery     bool          `mapstructure:"disablePathMTUDiscovery"`
	} `mapstructure:"quic"`
	// Bandwidth strings are parsed with convBandwidth: Up becomes
	// BandwidthConfig.MaxTx and Down becomes BandwidthConfig.MaxRx. An empty
	// string leaves the corresponding limit unset.
	Bandwidth struct {
		Up   string `mapstructure:"up"`
		Down string `mapstructure:"down"`
	} `mapstructure:"bandwidth"`
	// DisableUDP is copied directly to server.Config.DisableUDP.
	DisableUDP bool `mapstructure:"disableUDP"`
	// Auth selects the authenticator. Type is required; only "password"
	// (with a non-empty Password) is currently supported.
	Auth struct {
		Type     string `mapstructure:"type"`
		Password string `mapstructure:"password"`
	} `mapstructure:"auth"`
	// Masquerade selects the HTTP handler installed as server.Config.MasqHandler:
	// "" or "404" (not-found), "file" (static files from File.Dir), or "proxy"
	// (reverse proxy to Proxy.URL, keeping the client's Host header unless
	// Proxy.RewriteHost is true).
	Masquerade struct {
		Type string `mapstructure:"type"`
		File struct {
			Dir string `mapstructure:"dir"`
		} `mapstructure:"file"`
		Proxy struct {
			URL         string `mapstructure:"url"`
			RewriteHost bool   `mapstructure:"rewriteHost"`
		} `mapstructure:"proxy"`
	} `mapstructure:"masquerade"`
}
// serverConfigTLS points at a local certificate/private-key pair, which Config
// loads with tls.LoadX509KeyPair. Both paths are required.
type serverConfigTLS struct {
	Cert string `mapstructure:"cert"`
	Key  string `mapstructure:"key"`
}
// serverConfigACME configures automatic certificate management through
// certmagic's ACME issuer.
type serverConfigACME struct {
	// Domains to manage certificates for; Config rejects an empty list.
	Domains []string `mapstructure:"domains"`
	// Email is passed to certmagic.ACMEIssuer.Email.
	Email string `mapstructure:"email"`
	// CA selects the issuer (case-insensitive): "letsencrypt", "le" or ""
	// for Let's Encrypt production; "zerossl" or "zero" for ZeroSSL production.
	CA string `mapstructure:"ca"`
	// DisableHTTP / DisableTLSALPN map to the issuer's
	// DisableHTTPChallenge / DisableTLSALPNChallenge.
	DisableHTTP    bool `mapstructure:"disableHTTP"`
	DisableTLSALPN bool `mapstructure:"disableTLSALPN"`
	// AltHTTPPort / AltTLSALPNPort are passed unchanged to the issuer's
	// fields of the same name.
	AltHTTPPort    int `mapstructure:"altHTTPPort"`
	AltTLSALPNPort int `mapstructure:"altTLSALPNPort"`
	// Dir is the certmagic file-storage path; Config defaults it to "acme".
	Dir string `mapstructure:"dir"`
}
// Config validates the fields and returns a ready-to-use Hysteria server config.
//
// The UDP socket for the listen address is bound first. If any later section
// fails validation, that socket is closed before the error is returned, so a
// rejected config never leaks the port. Every validation failure is reported
// as a configError naming the offending config field.
func (c *serverConfig) Config() (*server.Config, error) {
	hyConfig := &server.Config{}
	if err := c.fillConn(hyConfig); err != nil {
		return nil, err
	}
	fillers := []func(*server.Config) error{
		c.fillTLSConfig,
		c.fillQUICConfig,
		c.fillBandwidthConfig,
		c.fillAuthenticator,
		c.fillMasqHandler,
	}
	for _, fill := range fillers {
		if err := fill(hyConfig); err != nil {
			// Release the socket bound by fillConn; the caller gets no config to close it.
			_ = hyConfig.Conn.Close()
			return nil, err
		}
	}
	hyConfig.DisableUDP = c.DisableUDP
	hyConfig.EventLogger = &serverLogger{}
	return hyConfig, nil
}

// fillConn binds the UDP listen address (":443" when unset) and stores the
// socket in hyConfig.Conn, wrapped with the configured obfuscator if any.
// On any error after binding, the socket is closed before returning.
func (c *serverConfig) fillConn(hyConfig *server.Config) error {
	listenAddr := c.Listen
	if listenAddr == "" {
		listenAddr = ":443"
	}
	uAddr, err := net.ResolveUDPAddr("udp", listenAddr)
	if err != nil {
		return configError{Field: "listen", Err: err}
	}
	conn, err := net.ListenUDP("udp", uAddr)
	if err != nil {
		return configError{Field: "listen", Err: err}
	}
	switch strings.ToLower(c.Obfs.Type) {
	case "", "plain":
		hyConfig.Conn = conn
	case "salamander":
		ob, err := obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password))
		if err != nil {
			_ = conn.Close()
			return configError{Field: "obfs.salamander.password", Err: err}
		}
		hyConfig.Conn = obfs.WrapPacketConn(conn, ob)
	default:
		_ = conn.Close()
		return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")}
	}
	return nil
}

// fillTLSConfig sets hyConfig.TLSConfig from exactly one of the tls section
// (local certificate and key files) or the acme section (certmagic).
func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
	if c.TLS == nil && c.ACME == nil {
		return configError{Field: "tls", Err: errors.New("must set either tls or acme")}
	}
	if c.TLS != nil && c.ACME != nil {
		return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
	}
	if c.ACME != nil {
		return c.fillACME(hyConfig)
	}
	// Local TLS cert
	if c.TLS.Cert == "" || c.TLS.Key == "" {
		return configError{Field: "tls", Err: errors.New("empty cert or key path")}
	}
	cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
	if err != nil {
		return configError{Field: "tls", Err: err}
	}
	hyConfig.TLSConfig.Certificates = []tls.Certificate{cert}
	return nil
}

// fillACME configures certmagic to manage certificates for acme.domains, with
// state kept under acme.dir ("acme" by default). It blocks in ManageSync until
// the initial certificates are managed, then installs certmagic's
// GetCertificate callback on hyConfig.TLSConfig.
func (c *serverConfig) fillACME(hyConfig *server.Config) error {
	// Validate before certmagic.NewCache, which starts a background maintenance goroutine.
	if len(c.ACME.Domains) == 0 {
		return configError{Field: "acme.domains", Err: errors.New("empty domains")}
	}
	dataDir := c.ACME.Dir
	if dataDir == "" {
		dataDir = "acme"
	}
	cmCfg := &certmagic.Config{
		RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
		KeySource:          certmagic.DefaultKeyGenerator,
		Storage:            &certmagic.FileStorage{Path: dataDir},
		Logger:             logger,
	}
	cmIssuer := certmagic.NewACMEIssuer(cmCfg, certmagic.ACMEIssuer{
		Email:                   c.ACME.Email,
		Agreed:                  true,
		DisableHTTPChallenge:    c.ACME.DisableHTTP,
		DisableTLSALPNChallenge: c.ACME.DisableTLSALPN,
		AltHTTPPort:             c.ACME.AltHTTPPort,
		AltTLSALPNPort:          c.ACME.AltTLSALPNPort,
		Logger:                  logger,
	})
	switch strings.ToLower(c.ACME.CA) {
	case "letsencrypt", "le", "":
		// Default to Let's Encrypt
		cmIssuer.CA = certmagic.LetsEncryptProductionCA
	case "zerossl", "zero":
		cmIssuer.CA = certmagic.ZeroSSLProductionCA
	default:
		return configError{Field: "acme.ca", Err: errors.New("unknown CA")}
	}
	cmCfg.Issuers = []certmagic.Issuer{cmIssuer}
	cmCache := certmagic.NewCache(certmagic.CacheOptions{
		// The closure reads the cmCfg variable when called, so after the
		// reassignment below it returns the cache-bound config from certmagic.New.
		GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
			return cmCfg, nil
		},
		Logger: logger,
	})
	cmCfg = certmagic.New(cmCache, *cmCfg)
	if err := cmCfg.ManageSync(context.Background(), c.ACME.Domains); err != nil {
		// The config is rejected; stop the cache's maintenance goroutine.
		cmCache.Stop()
		return configError{Field: "acme.domains", Err: err}
	}
	hyConfig.TLSConfig.GetCertificate = cmCfg.GetCertificate
	return nil
}

// fillQUICConfig copies the quic section verbatim into hyConfig.QUICConfig.
// It never fails; the error return keeps the signature uniform with the other fillers.
func (c *serverConfig) fillQUICConfig(hyConfig *server.Config) error {
	hyConfig.QUICConfig = server.QUICConfig{
		InitialStreamReceiveWindow:     c.QUIC.InitStreamReceiveWindow,
		MaxStreamReceiveWindow:         c.QUIC.MaxStreamReceiveWindow,
		InitialConnectionReceiveWindow: c.QUIC.InitConnectionReceiveWindow,
		MaxConnectionReceiveWindow:     c.QUIC.MaxConnectionReceiveWindow,
		MaxIdleTimeout:                 c.QUIC.MaxIdleTimeout,
		MaxIncomingStreams:             c.QUIC.MaxIncomingStreams,
		DisablePathMTUDiscovery:        c.QUIC.DisablePathMTUDiscovery,
	}
	return nil
}

// fillBandwidthConfig parses bandwidth.up into MaxTx and bandwidth.down into
// MaxRx; empty strings leave the corresponding limit at its zero value.
func (c *serverConfig) fillBandwidthConfig(hyConfig *server.Config) error {
	var err error
	if c.Bandwidth.Up != "" {
		hyConfig.BandwidthConfig.MaxTx, err = convBandwidth(c.Bandwidth.Up)
		if err != nil {
			return configError{Field: "bandwidth.up", Err: err}
		}
	}
	if c.Bandwidth.Down != "" {
		hyConfig.BandwidthConfig.MaxRx, err = convBandwidth(c.Bandwidth.Down)
		if err != nil {
			return configError{Field: "bandwidth.down", Err: err}
		}
	}
	return nil
}

// fillAuthenticator builds hyConfig.Authenticator from the auth section.
// Only the "password" type is supported, and it requires a non-empty password.
func (c *serverConfig) fillAuthenticator(hyConfig *server.Config) error {
	if c.Auth.Type == "" {
		return configError{Field: "auth.type", Err: errors.New("empty auth type")}
	}
	switch strings.ToLower(c.Auth.Type) {
	case "password":
		if c.Auth.Password == "" {
			return configError{Field: "auth.password", Err: errors.New("empty auth password")}
		}
		hyConfig.Authenticator = &auth.PasswordAuthenticator{Password: c.Auth.Password}
	default:
		return configError{Field: "auth.type", Err: errors.New("unsupported auth type")}
	}
	return nil
}

// fillMasqHandler builds hyConfig.MasqHandler from the masquerade section:
// "" or "404" (not-found handler), "file" (static file server rooted at
// masquerade.file.dir), or "proxy" (reverse proxy to an http/https URL).
func (c *serverConfig) fillMasqHandler(hyConfig *server.Config) error {
	switch strings.ToLower(c.Masquerade.Type) {
	case "", "404":
		hyConfig.MasqHandler = http.NotFoundHandler()
	case "file":
		if c.Masquerade.File.Dir == "" {
			return configError{Field: "masquerade.file.dir", Err: errors.New("empty file directory")}
		}
		hyConfig.MasqHandler = http.FileServer(http.Dir(c.Masquerade.File.Dir))
	case "proxy":
		if c.Masquerade.Proxy.URL == "" {
			return configError{Field: "masquerade.proxy.url", Err: errors.New("empty proxy url")}
		}
		u, err := url.Parse(c.Masquerade.Proxy.URL)
		if err != nil {
			return configError{Field: "masquerade.proxy.url", Err: err}
		}
		// url.Parse accepts scheme-less input: "example.com" parses as a path
		// and "host:8080" as scheme "host". Such a target would make every
		// proxied request fail, so reject it up front.
		if u.Scheme != "http" && u.Scheme != "https" {
			return configError{Field: "masquerade.proxy.url", Err: errors.New("unsupported protocol scheme, must be http or https")}
		}
		if u.Host == "" {
			return configError{Field: "masquerade.proxy.url", Err: errors.New("missing host in proxy url")}
		}
		rewriteHost := c.Masquerade.Proxy.RewriteHost
		hyConfig.MasqHandler = &httputil.ReverseProxy{
			Rewrite: func(r *httputil.ProxyRequest) {
				r.SetURL(u)
				// SetURL rewrites the Host header,
				// but we don't want that if rewriteHost is false
				if !rewriteHost {
					r.Out.Host = r.In.Host
				}
			},
			ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
				logger.Error("HTTP reverse proxy error", zap.Error(err))
				w.WriteHeader(http.StatusBadGateway)
			},
		}
	default:
		return configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")}
	}
	return nil
}
// runServer is the Run hook of the "server" subcommand. It reads and decodes
// the viper configuration, converts it into a core server config, constructs
// the server and serves until Serve returns. Every failure is fatal.
func runServer(cmd *cobra.Command, args []string) {
	logger.Info("server mode")

	err := viper.ReadInConfig()
	if err != nil {
		logger.Fatal("failed to read server config", zap.Error(err))
	}

	cfg := serverConfig{}
	err = viper.Unmarshal(&cfg)
	if err != nil {
		logger.Fatal("failed to parse server config", zap.Error(err))
	}

	hyConfig, err := cfg.Config()
	if err != nil {
		logger.Fatal("failed to load server config", zap.Error(err))
	}

	srv, err := server.NewServer(hyConfig)
	if err != nil {
		logger.Fatal("failed to initialize server", zap.Error(err))
	}

	logger.Info("server up and running")
	if err = srv.Serve(); err != nil {
		logger.Fatal("failed to serve", zap.Error(err))
	}
}
// serverLogger is installed as the server's EventLogger and forwards every
// event to the package logger. Connects and disconnects are logged at Info.
// Requests and clean (nil-error) closes are logged at Debug. Failures are
// logged at Error.
type serverLogger struct{}

// Connect logs a newly connected client together with its tx value.
func (*serverLogger) Connect(addr net.Addr, id string, tx uint64) {
	logger.Info("client connected",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Uint64("tx", tx),
	)
}

// Disconnect logs a client leaving; err is attached as-is (it may be nil).
func (*serverLogger) Disconnect(addr net.Addr, id string, err error) {
	logger.Info("client disconnected",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Error(err),
	)
}

// TCPRequest logs a client's TCP request to reqAddr.
func (*serverLogger) TCPRequest(addr net.Addr, id, reqAddr string) {
	logger.Debug("TCP request",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.String("reqAddr", reqAddr),
	)
}

// TCPError logs the end of a TCP request: a clean close when err is nil,
// otherwise an error.
func (*serverLogger) TCPError(addr net.Addr, id, reqAddr string, err error) {
	fields := []zap.Field{
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.String("reqAddr", reqAddr),
	}
	if err != nil {
		logger.Error("TCP error", append(fields, zap.Error(err))...)
		return
	}
	logger.Debug("TCP closed", fields...)
}

// UDPRequest logs a client opening a UDP session.
func (*serverLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) {
	logger.Debug("UDP request",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Uint32("sessionID", sessionID),
	)
}

// UDPError logs the end of a UDP session: a clean close when err is nil,
// otherwise an error.
func (*serverLogger) UDPError(addr net.Addr, id string, sessionID uint32, err error) {
	fields := []zap.Field{
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Uint32("sessionID", sessionID),
	}
	if err != nil {
		logger.Error("UDP error", append(fields, zap.Error(err))...)
		return
	}
	logger.Debug("UDP closed", fields...)
}