package cmd

import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
	"time"

	"github.com/caddyserver/certmagic"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
	"go.uber.org/zap"

	"github.com/apernet/hysteria/core/server"
	"github.com/apernet/hysteria/extras/auth"
	"github.com/apernet/hysteria/extras/obfs"
)

// serverCmd is the "server" subcommand. Invoking it calls runServer.
var serverCmd = &cobra.Command{
	Use:   "server",
	Short: "Server mode",
	Run:   runServer,
}

// init registers serverCmd as a subcommand of rootCmd.
func init() {
	rootCmd.AddCommand(serverCmd)
}

// serverConfig mirrors the server configuration file. runServer fills it via
// viper.Unmarshal (hence the mapstructure tags) and its Config method turns
// it into a server.Config.
type serverConfig struct {
	// Listen is the UDP address to listen on. fillConn uses ":443" when empty.
	Listen string `mapstructure:"listen"`
	// Obfs selects packet obfuscation. Type is matched case-insensitively:
	// "" or "plain" uses the raw socket, "salamander" wraps it with a
	// Salamander obfuscator built from Salamander.Password.
	Obfs   struct {
		Type       string `mapstructure:"type"`
		Salamander struct {
			Password string `mapstructure:"password"`
		} `mapstructure:"salamander"`
	} `mapstructure:"obfs"`
	// TLS and ACME are mutually exclusive certificate sources; fillTLSConfig
	// requires exactly one of them to be set.
	TLS  *serverConfigTLS  `mapstructure:"tls"`
	ACME *serverConfigACME `mapstructure:"acme"`
	// QUIC holds transport parameters that fillQUICConfig copies unchanged
	// into server.QUICConfig.
	QUIC struct {
		InitStreamReceiveWindow     uint64        `mapstructure:"initStreamReceiveWindow"`
		MaxStreamReceiveWindow      uint64        `mapstructure:"maxStreamReceiveWindow"`
		InitConnectionReceiveWindow uint64        `mapstructure:"initConnReceiveWindow"`
		MaxConnectionReceiveWindow  uint64        `mapstructure:"maxConnReceiveWindow"`
		MaxIdleTimeout              time.Duration `mapstructure:"maxIdleTimeout"`
		MaxIncomingStreams          int64         `mapstructure:"maxIncomingStreams"`
		DisablePathMTUDiscovery     bool          `mapstructure:"disablePathMTUDiscovery"`
	} `mapstructure:"quic"`
	// Bandwidth holds textual limits parsed by convBandwidth: Up becomes
	// BandwidthConfig.MaxTx and Down becomes BandwidthConfig.MaxRx. An empty
	// string leaves the corresponding limit untouched.
	Bandwidth struct {
		Up   string `mapstructure:"up"`
		Down string `mapstructure:"down"`
	} `mapstructure:"bandwidth"`
	// DisableUDP is copied as-is into server.Config.DisableUDP.
	DisableUDP bool `mapstructure:"disableUDP"`
	// Auth selects the authenticator. Type is required and matched
	// case-insensitively; only "password" is supported, which in turn
	// requires a non-empty Password.
	Auth       struct {
		Type     string `mapstructure:"type"`
		Password string `mapstructure:"password"`
	} `mapstructure:"auth"`
	// Masquerade selects the HTTP handler installed as MasqHandler. Type is
	// matched case-insensitively: "" or "404" answers with Not Found, "file"
	// serves File.Dir, and "proxy" reverse-proxies to Proxy.URL.
	Masquerade struct {
		Type string `mapstructure:"type"`
		File struct {
			Dir string `mapstructure:"dir"`
		} `mapstructure:"file"`
		Proxy struct {
			URL         string `mapstructure:"url"`
			// RewriteHost keeps the Host set by SetURL when true; when false
			// the outgoing Host is reset to the incoming request's Host.
			RewriteHost bool   `mapstructure:"rewriteHost"`
		} `mapstructure:"proxy"`
	} `mapstructure:"masquerade"`
}

// serverConfigTLS points at a local certificate/key pair. Both paths must be
// non-empty; fillTLSConfig loads them with tls.LoadX509KeyPair.
type serverConfigTLS struct {
	Cert string `mapstructure:"cert"`
	Key  string `mapstructure:"key"`
}

// serverConfigACME configures automatic certificate management through
// certmagic, as consumed by fillTLSConfig.
type serverConfigACME struct {
	// Domains lists the names to manage certificates for; must not be empty.
	Domains        []string `mapstructure:"domains"`
	// Email is passed to the ACME issuer as the account email.
	Email          string   `mapstructure:"email"`
	// CA is matched case-insensitively: "", "le" or "letsencrypt" selects the
	// Let's Encrypt production CA; "zero" or "zerossl" selects the ZeroSSL
	// production CA. Any other value is rejected.
	CA             string   `mapstructure:"ca"`
	// DisableHTTP and DisableTLSALPN map to the issuer's DisableHTTPChallenge
	// and DisableTLSALPNChallenge settings.
	DisableHTTP    bool     `mapstructure:"disableHTTP"`
	DisableTLSALPN bool     `mapstructure:"disableTLSALPN"`
	// AltHTTPPort and AltTLSALPNPort are passed through to the issuer fields
	// of the same name.
	AltHTTPPort    int      `mapstructure:"altHTTPPort"`
	AltTLSALPNPort int      `mapstructure:"altTLSALPNPort"`
	// Dir is the certmagic file storage path; "acme" is used when empty.
	Dir            string   `mapstructure:"dir"`
}

// fillConn opens the UDP listening socket and stores it in hyConfig.Conn,
// optionally wrapped with an obfuscator.
//
// The listen address defaults to ":443" when unset. The obfuscation type is
// matched case-insensitively: "" and "plain" use the raw socket, while
// "salamander" wraps it with a Salamander obfuscator keyed by
// obfs.salamander.password.
//
// Every failure is returned as a configError naming the offending field. If
// the failure happens after the socket has been opened, the socket is closed
// before returning so a rejected configuration does not keep the port bound.
func (c *serverConfig) fillConn(hyConfig *server.Config) error {
	listenAddr := c.Listen
	if listenAddr == "" {
		listenAddr = ":443"
	}
	uAddr, err := net.ResolveUDPAddr("udp", listenAddr)
	if err != nil {
		return configError{Field: "listen", Err: err}
	}
	conn, err := net.ListenUDP("udp", uAddr)
	if err != nil {
		return configError{Field: "listen", Err: err}
	}
	switch strings.ToLower(c.Obfs.Type) {
	case "", "plain":
		hyConfig.Conn = conn
		return nil
	case "salamander":
		ob, err := obfs.NewSalamanderObfuscator([]byte(c.Obfs.Salamander.Password))
		if err != nil {
			// Best-effort cleanup: the obfuscator error is the one worth
			// reporting, so a Close failure is deliberately ignored.
			_ = conn.Close()
			return configError{Field: "obfs.salamander.password", Err: err}
		}
		hyConfig.Conn = obfs.WrapPacketConn(conn, ob)
		return nil
	default:
		// Same best-effort cleanup as above.
		_ = conn.Close()
		return configError{Field: "obfs.type", Err: errors.New("unsupported obfuscation type")}
	}
}

// fillTLSConfig configures the certificate source on hyConfig.TLSConfig.
//
// Exactly one of the "tls" and "acme" sections must be present:
//   - tls: the key pair at tls.cert / tls.key is loaded and installed as the
//     only entry of TLSConfig.Certificates.
//   - acme: certificates for acme.domains are obtained synchronously through
//     certmagic (stored under acme.dir, "acme" by default) and certmagic's
//     GetCertificate callback is installed on TLSConfig.
//
// Every failure is returned as a configError naming the offending field.
func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
	if c.TLS == nil && c.ACME == nil {
		return configError{Field: "tls", Err: errors.New("must set either tls or acme")}
	}
	if c.TLS != nil && c.ACME != nil {
		return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
	}
	if c.TLS != nil {
		// Local TLS cert
		if c.TLS.Cert == "" || c.TLS.Key == "" {
			return configError{Field: "tls", Err: errors.New("empty cert or key path")}
		}
		cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
		if err != nil {
			return configError{Field: "tls", Err: err}
		}
		hyConfig.TLSConfig.Certificates = []tls.Certificate{cert}
		return nil
	}

	// ACME
	dataDir := c.ACME.Dir
	if dataDir == "" {
		dataDir = "acme"
	}
	cmCfg := &certmagic.Config{
		RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
		KeySource:          certmagic.DefaultKeyGenerator,
		Storage:            &certmagic.FileStorage{Path: dataDir},
		Logger:             logger,
	}
	cmIssuer := certmagic.NewACMEIssuer(cmCfg, certmagic.ACMEIssuer{
		Email:                   c.ACME.Email,
		Agreed:                  true,
		DisableHTTPChallenge:    c.ACME.DisableHTTP,
		DisableTLSALPNChallenge: c.ACME.DisableTLSALPN,
		AltHTTPPort:             c.ACME.AltHTTPPort,
		AltTLSALPNPort:          c.ACME.AltTLSALPNPort,
		Logger:                  logger,
	})
	switch strings.ToLower(c.ACME.CA) {
	case "letsencrypt", "le", "":
		// Default to Let's Encrypt
		cmIssuer.CA = certmagic.LetsEncryptProductionCA
	case "zerossl", "zero":
		cmIssuer.CA = certmagic.ZeroSSLProductionCA
	default:
		return configError{Field: "acme.ca", Err: errors.New("unknown CA")}
	}
	// Finish all pure validation before the certificate cache is created, so
	// a rejected configuration never leaves a cache (and its background
	// maintenance) behind.
	if len(c.ACME.Domains) == 0 {
		return configError{Field: "acme.domains", Err: errors.New("empty domains")}
	}
	cmCfg.Issuers = []certmagic.Issuer{cmIssuer}
	cmCache := certmagic.NewCache(certmagic.CacheOptions{
		// The closure reads cmCfg at call time, so it hands out the
		// cache-bound config assigned just below.
		GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
			return cmCfg, nil
		},
		Logger: logger,
	})
	cmCfg = certmagic.New(cmCache, *cmCfg)

	if err := cmCfg.ManageSync(context.Background(), c.ACME.Domains); err != nil {
		// Nothing will use the cache after this point; shut it down instead
		// of leaving it running.
		cmCache.Stop()
		return configError{Field: "acme.domains", Err: err}
	}
	hyConfig.TLSConfig.GetCertificate = cmCfg.GetCertificate
	return nil
}

// fillQUICConfig replaces hyConfig.QUICConfig with the values from the "quic"
// section, copied field by field without validation. It never fails.
func (c *serverConfig) fillQUICConfig(hyConfig *server.Config) error {
	q := c.QUIC
	quicConfig := server.QUICConfig{
		InitialStreamReceiveWindow:     q.InitStreamReceiveWindow,
		MaxStreamReceiveWindow:         q.MaxStreamReceiveWindow,
		InitialConnectionReceiveWindow: q.InitConnectionReceiveWindow,
		MaxConnectionReceiveWindow:     q.MaxConnectionReceiveWindow,
		MaxIdleTimeout:                 q.MaxIdleTimeout,
		MaxIncomingStreams:             q.MaxIncomingStreams,
		DisablePathMTUDiscovery:        q.DisablePathMTUDiscovery,
	}
	hyConfig.QUICConfig = quicConfig
	return nil
}

// fillBandwidthConfig parses the textual limits in the "bandwidth" section
// with convBandwidth: "up" is stored in BandwidthConfig.MaxTx and "down" in
// BandwidthConfig.MaxRx. An empty string leaves the corresponding limit
// untouched; a parse failure is returned as a configError for that field.
func (c *serverConfig) fillBandwidthConfig(hyConfig *server.Config) error {
	var convErr error

	if up := c.Bandwidth.Up; up != "" {
		if hyConfig.BandwidthConfig.MaxTx, convErr = convBandwidth(up); convErr != nil {
			return configError{Field: "bandwidth.up", Err: convErr}
		}
	}

	if down := c.Bandwidth.Down; down != "" {
		if hyConfig.BandwidthConfig.MaxRx, convErr = convBandwidth(down); convErr != nil {
			return configError{Field: "bandwidth.down", Err: convErr}
		}
	}

	return nil
}

// fillDisableUDP copies the disableUDP flag into hyConfig. It never fails;
// the error return only exists so it fits the filler list used by Config.
func (c *serverConfig) fillDisableUDP(hyConfig *server.Config) error {
	hyConfig.DisableUDP = c.DisableUDP
	return nil
}

// fillAuthenticator installs the authenticator selected by auth.type, which
// is matched case-insensitively. Only "password" is supported and it needs a
// non-empty auth.password. A missing or unsupported type, or a missing
// password, is returned as a configError for the relevant field.
func (c *serverConfig) fillAuthenticator(hyConfig *server.Config) error {
	switch strings.ToLower(c.Auth.Type) {
	case "":
		return configError{Field: "auth.type", Err: errors.New("empty auth type")}
	case "password":
		password := c.Auth.Password
		if password == "" {
			return configError{Field: "auth.password", Err: errors.New("empty auth password")}
		}
		hyConfig.Authenticator = &auth.PasswordAuthenticator{Password: password}
		return nil
	}
	return configError{Field: "auth.type", Err: errors.New("unsupported auth type")}
}

// fillEventLogger installs a serverLogger as the server's event logger.
// It never fails.
func (c *serverConfig) fillEventLogger(hyConfig *server.Config) error {
	hyConfig.EventLogger = new(serverLogger)
	return nil
}

// fillMasqHandler installs the HTTP handler selected by masquerade.type,
// which is matched case-insensitively:
//   - "" or "404": every request is answered with 404 Not Found.
//   - "file": serves static files from masquerade.file.dir (required).
//   - "proxy": reverse-proxies to masquerade.proxy.url (required, must be an
//     http or https URL).
//
// Any other type, a missing required value, or an unusable proxy URL is
// returned as a configError naming the offending field.
func (c *serverConfig) fillMasqHandler(hyConfig *server.Config) error {
	switch strings.ToLower(c.Masquerade.Type) {
	case "", "404":
		hyConfig.MasqHandler = http.NotFoundHandler()
		return nil
	case "file":
		if c.Masquerade.File.Dir == "" {
			return configError{Field: "masquerade.file.dir", Err: errors.New("empty file directory")}
		}
		hyConfig.MasqHandler = http.FileServer(http.Dir(c.Masquerade.File.Dir))
		return nil
	case "proxy":
		if c.Masquerade.Proxy.URL == "" {
			return configError{Field: "masquerade.proxy.url", Err: errors.New("empty proxy url")}
		}
		u, err := url.Parse(c.Masquerade.Proxy.URL)
		if err != nil {
			return configError{Field: "masquerade.proxy.url", Err: err}
		}
		// url.Parse accepts scheme-less input such as "example.com". Reject
		// anything the reverse proxy cannot forward to here, rather than
		// failing every masquerade request at runtime.
		if u.Scheme != "http" && u.Scheme != "https" {
			return configError{Field: "masquerade.proxy.url", Err: errors.New("unsupported protocol scheme")}
		}
		hyConfig.MasqHandler = &httputil.ReverseProxy{
			Rewrite: func(r *httputil.ProxyRequest) {
				r.SetURL(u)
				// SetURL rewrites the Host header,
				// but we don't want that if rewriteHost is false
				if !c.Masquerade.Proxy.RewriteHost {
					r.Out.Host = r.In.Host
				}
			},
			// Log upstream failures and answer with 502 Bad Gateway.
			ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
				logger.Error("HTTP reverse proxy error", zap.Error(err))
				w.WriteHeader(http.StatusBadGateway)
			},
		}
		return nil
	default:
		return configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")}
	}
}

// Config validates the fields and returns a ready-to-use Hysteria server config.
//
// The fill* methods run in a fixed order, each populating one part of the
// result. The first one to fail aborts the process: its error is returned
// together with a nil config, and the listening socket opened by fillConn (if
// any) is closed so that a rejected configuration does not keep the port bound.
func (c *serverConfig) Config() (*server.Config, error) {
	hyConfig := &server.Config{}
	fillers := []func(*server.Config) error{
		c.fillConn,
		c.fillTLSConfig,
		c.fillQUICConfig,
		c.fillBandwidthConfig,
		c.fillDisableUDP,
		c.fillAuthenticator,
		c.fillEventLogger,
		c.fillMasqHandler,
	}
	for _, f := range fillers {
		if err := f(hyConfig); err != nil {
			if hyConfig.Conn != nil {
				// Best-effort cleanup: the filler's error is the one worth
				// reporting, so a Close failure is deliberately ignored.
				_ = hyConfig.Conn.Close()
			}
			return nil, err
		}
	}
	return hyConfig, nil
}

// runServer is the entry point of the "server" subcommand. It reads the
// configuration through viper, converts it into a server.Config, creates the
// server and serves until Serve returns. Every failure along the way is
// reported through logger.Fatal.
func runServer(cmd *cobra.Command, args []string) {
	logger.Info("server mode")

	// Load and decode the configuration file.
	err := viper.ReadInConfig()
	if err != nil {
		logger.Fatal("failed to read server config", zap.Error(err))
	}
	var config serverConfig
	err = viper.Unmarshal(&config)
	if err != nil {
		logger.Fatal("failed to parse server config", zap.Error(err))
	}

	// Validate it and build the runtime configuration.
	hyConfig, err := config.Config()
	if err != nil {
		logger.Fatal("failed to load server config", zap.Error(err))
	}

	s, err := server.NewServer(hyConfig)
	if err != nil {
		logger.Fatal("failed to initialize server", zap.Error(err))
	}
	logger.Info("server up and running")

	err = s.Serve()
	if err != nil {
		logger.Fatal("failed to serve", zap.Error(err))
	}
}

// serverLogger is the stateless event logger that fillEventLogger installs on
// the server config; each method writes one entry to the package logger.
type serverLogger struct{}

// Connect logs, at info level, that a client connected, with its address, id
// and the tx value passed in.
func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) {
	logger.Info(
		"client connected",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Uint64("tx", tx),
	)
}

// Disconnect logs, at info level, that a client disconnected, with its
// address, id and the error passed in.
func (l *serverLogger) Disconnect(addr net.Addr, id string, err error) {
	logger.Info(
		"client disconnected",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Error(err),
	)
}

// TCPRequest logs, at debug level, a TCP request from the client identified
// by addr and id to reqAddr.
func (l *serverLogger) TCPRequest(addr net.Addr, id, reqAddr string) {
	logger.Debug(
		"TCP request",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.String("reqAddr", reqAddr),
	)
}

// TCPError logs the end of a TCP request: "TCP closed" at debug level when
// err is nil, otherwise "TCP error" at error level with err attached.
func (l *serverLogger) TCPError(addr net.Addr, id, reqAddr string, err error) {
	fields := []zap.Field{
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.String("reqAddr", reqAddr),
	}
	if err != nil {
		logger.Error("TCP error", append(fields, zap.Error(err))...)
		return
	}
	logger.Debug("TCP closed", fields...)
}

// UDPRequest logs, at debug level, a UDP request from the client identified
// by addr and id, together with its session ID.
func (l *serverLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) {
	logger.Debug(
		"UDP request",
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Uint32("sessionID", sessionID),
	)
}

// UDPError logs the end of a UDP session: "UDP closed" at debug level when
// err is nil, otherwise "UDP error" at error level with err attached.
func (l *serverLogger) UDPError(addr net.Addr, id string, sessionID uint32, err error) {
	fields := []zap.Field{
		zap.String("addr", addr.String()),
		zap.String("id", id),
		zap.Uint32("sessionID", sessionID),
	}
	if err != nil {
		logger.Error("UDP error", append(fields, zap.Error(err))...)
		return
	}
	logger.Debug("UDP closed", fields...)
}