package cmd

import (
	"context"
	"crypto/tls"
	"errors"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"

	"github.com/apernet/hysteria/core/server"
	"github.com/apernet/hysteria/extras/auth"

	"github.com/caddyserver/certmagic"
	"github.com/spf13/cobra"
	"github.com/spf13/viper"
	"go.uber.org/zap"
)

var serverCmd = &cobra.Command{
	Use:   "server",
	Short: "Server mode",
	Run:   runServer,
}

func init() {
	rootCmd.AddCommand(serverCmd)
	initServerConfigDefaults()
}

func initServerConfigDefaults() {
	viper.SetDefault("listen", ":443")
}

func runServer(cmd *cobra.Command, args []string) {
	logger.Info("server mode")

	if err := viper.ReadInConfig(); err != nil {
		logger.Fatal("failed to read server config", zap.Error(err))
	}
	config, err := viperToServerConfig()
	if err != nil {
		logger.Fatal("failed to parse server config", zap.Error(err))
	}

	s, err := server.NewServer(config)
	if err != nil {
		logger.Fatal("failed to initialize server", zap.Error(err))
	}
	logger.Info("server up and running")

	if err := s.Serve(); err != nil {
		logger.Fatal("failed to serve", zap.Error(err))
	}
}

func viperToServerConfig() (*server.Config, error) {
	// Conn
	conn, err := viperToServerConn()
	if err != nil {
		return nil, err
	}
	// TLS
	tlsConfig, err := viperToServerTLSConfig()
	if err != nil {
		return nil, err
	}
	// QUIC
	quicConfig := viperToServerQUICConfig()
	// Bandwidth
	bwConfig, err := viperToServerBandwidthConfig()
	if err != nil {
		return nil, err
	}
	// Disable UDP
	disableUDP := viper.GetBool("disableUDP")
	// Authenticator
	authenticator, err := viperToAuthenticator()
	if err != nil {
		return nil, err
	}
	// Masquerade
	masqHandler, err := viperToMasqHandler()
	if err != nil {
		return nil, err
	}
	// Config
	config := &server.Config{
		TLSConfig:       tlsConfig,
		QUICConfig:      quicConfig,
		Conn:            conn,
		Outbound:        nil, // TODO
		BandwidthConfig: bwConfig,
		DisableUDP:      disableUDP,
		Authenticator:   authenticator,
		EventLogger:     &serverLogger{},
		MasqHandler:     masqHandler,
	}
	return config, nil
}

func viperToServerConn() (net.PacketConn, error) {
	listen := viper.GetString("listen")
	if listen == "" {
		return nil, configError{Field: "listen", Err: errors.New("empty listen address")}
	}
	uAddr, err := net.ResolveUDPAddr("udp", listen)
	if err != nil {
		return nil, configError{Field: "listen", Err: err}
	}
	conn, err := net.ListenUDP("udp", uAddr)
	if err != nil {
		return nil, configError{Field: "listen", Err: err}
	}
	return conn, nil
}

func viperToServerTLSConfig() (server.TLSConfig, error) {
	vTLS, vACME := viper.Sub("tls"), viper.Sub("acme")
	if vTLS == nil && vACME == nil {
		return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("must set either tls or acme")}
	}
	if vTLS != nil && vACME != nil {
		return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
	}
	if vTLS != nil {
		return viperToServerTLSConfigLocal(vTLS)
	} else {
		return viperToServerTLSConfigACME(vACME)
	}
}

func viperToServerTLSConfigLocal(v *viper.Viper) (server.TLSConfig, error) {
	certPath, keyPath := v.GetString("cert"), v.GetString("key")
	if certPath == "" || keyPath == "" {
		return server.TLSConfig{}, configError{Field: "tls", Err: errors.New("empty cert or key path")}
	}
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return server.TLSConfig{}, configError{Field: "tls", Err: err}
	}
	return server.TLSConfig{
		Certificates: []tls.Certificate{cert},
	}, nil
}

func viperToServerTLSConfigACME(v *viper.Viper) (server.TLSConfig, error) {
	dataDir := v.GetString("dir")
	if dataDir == "" {
		dataDir = "acme"
	}

	cfg := &certmagic.Config{
		RenewalWindowRatio: certmagic.DefaultRenewalWindowRatio,
		KeySource:          certmagic.DefaultKeyGenerator,
		Storage:            &certmagic.FileStorage{Path: dataDir},
		Logger:             logger,
	}
	issuer := certmagic.NewACMEIssuer(cfg, certmagic.ACMEIssuer{
		Email:                   v.GetString("email"),
		Agreed:                  true,
		DisableHTTPChallenge:    v.GetBool("disableHTTP"),
		DisableTLSALPNChallenge: v.GetBool("disableTLSALPN"),
		AltHTTPPort:             v.GetInt("altHTTPPort"),
		AltTLSALPNPort:          v.GetInt("altTLSALPNPort"),
		Logger:                  logger,
	})
	switch strings.ToLower(v.GetString("ca")) {
	case "letsencrypt", "le", "":
		// Default to Let's Encrypt
		issuer.CA = certmagic.LetsEncryptProductionCA
	case "zerossl", "zero":
		issuer.CA = certmagic.ZeroSSLProductionCA
	default:
		return server.TLSConfig{}, configError{Field: "acme.ca", Err: errors.New("unknown CA")}
	}
	cfg.Issuers = []certmagic.Issuer{issuer}

	cache := certmagic.NewCache(certmagic.CacheOptions{
		GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
			return cfg, nil
		},
		Logger: logger,
	})
	cfg = certmagic.New(cache, *cfg)

	domains := v.GetStringSlice("domains")
	if len(domains) == 0 {
		return server.TLSConfig{}, configError{Field: "acme.domains", Err: errors.New("empty domains")}
	}
	err := cfg.ManageSync(context.Background(), domains)
	if err != nil {
		return server.TLSConfig{}, configError{Field: "acme", Err: err}
	}
	return server.TLSConfig{
		GetCertificate: cfg.GetCertificate,
	}, nil
}

func viperToServerQUICConfig() server.QUICConfig {
	return server.QUICConfig{
		InitialStreamReceiveWindow:     viper.GetUint64("quic.initStreamReceiveWindow"),
		MaxStreamReceiveWindow:         viper.GetUint64("quic.maxStreamReceiveWindow"),
		InitialConnectionReceiveWindow: viper.GetUint64("quic.initConnReceiveWindow"),
		MaxConnectionReceiveWindow:     viper.GetUint64("quic.maxConnReceiveWindow"),
		MaxIdleTimeout:                 viper.GetDuration("quic.maxIdleTimeout"),
		MaxIncomingStreams:             viper.GetInt64("quic.maxIncomingStreams"),
		DisablePathMTUDiscovery:        viper.GetBool("quic.disablePathMTUDiscovery"),
	}
}

func viperToServerBandwidthConfig() (server.BandwidthConfig, error) {
	bw := server.BandwidthConfig{}
	upStr, downStr := viper.GetString("bandwidth.up"), viper.GetString("bandwidth.down")
	if upStr != "" {
		up, err := convBandwidth(upStr)
		if err != nil {
			return server.BandwidthConfig{}, configError{Field: "bandwidth.up", Err: err}
		}
		bw.MaxTx = up
	}
	if downStr != "" {
		down, err := convBandwidth(downStr)
		if err != nil {
			return server.BandwidthConfig{}, configError{Field: "bandwidth.down", Err: err}
		}
		bw.MaxRx = down
	}
	return bw, nil
}

func viperToAuthenticator() (server.Authenticator, error) {
	authType := viper.GetString("auth.type")
	if authType == "" {
		return nil, configError{Field: "auth.type", Err: errors.New("empty auth type")}
	}
	switch authType {
	case "password":
		pw := viper.GetString("auth.password")
		if pw == "" {
			return nil, configError{Field: "auth.password", Err: errors.New("empty auth password")}
		}
		return &auth.PasswordAuthenticator{Password: pw}, nil
	default:
		return nil, configError{Field: "auth.type", Err: errors.New("unsupported auth type")}
	}
}

func viperToMasqHandler() (http.Handler, error) {
	masqType := viper.GetString("masquerade.type")
	if masqType == "" {
		// Default to use the 404 handler
		return http.NotFoundHandler(), nil
	}
	switch masqType {
	case "404":
		return http.NotFoundHandler(), nil
	case "file":
		dir := viper.GetString("masquerade.file.dir")
		if dir == "" {
			return nil, configError{Field: "masquerade.file.dir", Err: errors.New("empty directory")}
		}
		return http.FileServer(http.Dir(dir)), nil
	case "proxy":
		urlStr := viper.GetString("masquerade.proxy.url")
		if urlStr == "" {
			return nil, configError{Field: "masquerade.proxy.url", Err: errors.New("empty url")}
		}
		u, err := url.Parse(urlStr)
		if err != nil {
			return nil, configError{Field: "masquerade.proxy.url", Err: err}
		}
		proxy := &httputil.ReverseProxy{
			Rewrite: func(r *httputil.ProxyRequest) {
				r.SetURL(u)
				// SetURL rewrites the Host header,
				// but we don't want that if rewriteHost is false
				if !viper.GetBool("masquerade.proxy.rewriteHost") {
					r.Out.Host = r.In.Host
				}
			},
			ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
				logger.Error("HTTP reverse proxy error", zap.Error(err))
				w.WriteHeader(http.StatusBadGateway)
			},
		}
		return proxy, nil
	default:
		return nil, configError{Field: "masquerade.type", Err: errors.New("unsupported masquerade type")}
	}
}

type serverLogger struct{}

func (l *serverLogger) Connect(addr net.Addr, id string, tx uint64) {
	logger.Info("client connected", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint64("tx", tx))
}

func (l *serverLogger) Disconnect(addr net.Addr, id string, err error) {
	logger.Info("client disconnected", zap.String("addr", addr.String()), zap.String("id", id), zap.Error(err))
}

func (l *serverLogger) TCPRequest(addr net.Addr, id, reqAddr string) {
	logger.Debug("TCP request", zap.String("addr", addr.String()), zap.String("id", id), zap.String("reqAddr", reqAddr))
}

func (l *serverLogger) TCPError(addr net.Addr, id, reqAddr string, err error) {
	if err == nil {
		logger.Debug("TCP closed", zap.String("addr", addr.String()), zap.String("id", id), zap.String("reqAddr", reqAddr))
	} else {
		logger.Error("TCP error", zap.String("addr", addr.String()), zap.String("id", id), zap.String("reqAddr", reqAddr), zap.Error(err))
	}
}

func (l *serverLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) {
	logger.Debug("UDP request", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID))
}

func (l *serverLogger) UDPError(addr net.Addr, id string, sessionID uint32, err error) {
	if err == nil {
		logger.Debug("UDP closed", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID))
	} else {
		logger.Error("UDP error", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID), zap.Error(err))
	}
}