package server

import (
	"crypto/tls"
	"net"
	"net/http"
	"time"

	"github.com/apernet/hysteria/core/errors"
	"github.com/apernet/hysteria/core/internal/pmtud"
)

// Default values that Config.fill substitutes for zero-valued QUICConfig fields.
const (
	defaultStreamReceiveWindow = 8 << 20                            // 8 MiB
	defaultConnReceiveWindow   = 5 * defaultStreamReceiveWindow / 2 // 20 MiB
	defaultMaxIdleTimeout      = 30 * time.Second
	defaultMaxIncomingStreams  = 1 << 10
)

// Config holds the settings for a server. Config.fill validates it and
// applies defaults: TLSConfig, Conn and Authenticator are required, while the
// remaining fields may be left at their zero values.
type Config struct {
	// TLSConfig is required: at least one of Certificates or GetCertificate
	// must be set.
	TLSConfig       TLSConfig
	// QUICConfig is optional; zero-valued fields are replaced with package
	// defaults and non-zero values are range-checked by fill.
	QUICConfig      QUICConfig
	// Conn is required (fill rejects nil). NOTE(review): presumably the
	// packet socket the QUIC server is served on — confirm.
	Conn            net.PacketConn
	// Outbound is optional; when nil, fill installs defaultOutbound, which
	// connects to remote servers directly via the net package.
	Outbound        Outbound
	// BandwidthConfig is optional; each limit must be 0 or at least 65536.
	BandwidthConfig BandwidthConfig
	// DisableUDP is not checked by fill. NOTE(review): presumably disables
	// UDP relaying for clients — confirm.
	DisableUDP      bool
	// Authenticator is required (fill rejects nil).
	Authenticator   Authenticator
	// EventLogger is not required by fill.
	EventLogger     EventLogger
	// TrafficLogger is not required by fill.
	TrafficLogger   TrafficLogger
	// MasqHandler is not required by fill. NOTE(review): presumably serves
	// HTTP requests that are not from valid clients (masquerading) — confirm.
	MasqHandler     http.Handler
}

// fill validates c and substitutes defaults for optional fields the user left
// at their zero value. It returns an errors.ConfigError describing the first
// required field that is missing or the first field whose value is out of
// range. c is updated in place as fields are processed. On error, fields
// handled before the offending one may therefore already hold defaults.
func (c *Config) fill() error {
	tlsCfg := &c.TLSConfig
	if tlsCfg.GetCertificate == nil && len(tlsCfg.Certificates) == 0 {
		return errors.ConfigError{Field: "TLSConfig", Reason: "must set at least one of Certificates or GetCertificate"}
	}

	q := &c.QUICConfig

	// All four flow-control windows follow the same rule: 0 selects the
	// default, and any other value must be at least 16384. They are visited
	// in a fixed order, so the first invalid one is the one reported.
	windows := [...]struct {
		field string
		ptr   *uint64
		def   uint64
	}{
		{"QUICConfig.InitialStreamReceiveWindow", &q.InitialStreamReceiveWindow, defaultStreamReceiveWindow},
		{"QUICConfig.MaxStreamReceiveWindow", &q.MaxStreamReceiveWindow, defaultStreamReceiveWindow},
		{"QUICConfig.InitialConnectionReceiveWindow", &q.InitialConnectionReceiveWindow, defaultConnReceiveWindow},
		{"QUICConfig.MaxConnectionReceiveWindow", &q.MaxConnectionReceiveWindow, defaultConnReceiveWindow},
	}
	for _, w := range windows {
		switch {
		case *w.ptr == 0:
			*w.ptr = w.def
		case *w.ptr < 16384:
			return errors.ConfigError{Field: w.field, Reason: "must be at least 16384"}
		}
	}

	switch idle := q.MaxIdleTimeout; {
	case idle == 0:
		q.MaxIdleTimeout = defaultMaxIdleTimeout
	case idle < 4*time.Second, idle > 120*time.Second:
		return errors.ConfigError{Field: "QUICConfig.MaxIdleTimeout", Reason: "must be between 4s and 120s"}
	}

	switch streams := q.MaxIncomingStreams; {
	case streams == 0:
		q.MaxIncomingStreams = defaultMaxIncomingStreams
	case streams < 8:
		return errors.ConfigError{Field: "QUICConfig.MaxIncomingStreams", Reason: "must be at least 8"}
	}

	// The pmtud package can only force path MTU discovery off. It never
	// re-enables discovery that the user disabled.
	if pmtud.DisablePathMTUDiscovery {
		q.DisablePathMTUDiscovery = true
	}

	if c.Conn == nil {
		return errors.ConfigError{Field: "Conn", Reason: "must be set"}
	}
	if c.Outbound == nil {
		c.Outbound = &defaultOutbound{}
	}

	// A bandwidth limit of 0 is accepted as-is. Non-zero limits have a floor.
	bw := c.BandwidthConfig
	if bw.MaxTx != 0 && bw.MaxTx < 65536 {
		return errors.ConfigError{Field: "BandwidthConfig.MaxTx", Reason: "must be at least 65536"}
	}
	if bw.MaxRx != 0 && bw.MaxRx < 65536 {
		return errors.ConfigError{Field: "BandwidthConfig.MaxRx", Reason: "must be at least 65536"}
	}

	if c.Authenticator == nil {
		return errors.ConfigError{Field: "Authenticator", Reason: "must be set"}
	}
	return nil
}

// TLSConfig contains the TLS configuration fields that we want to expose to the user.
// Config.fill requires at least one of Certificates or GetCertificate to be set.
type TLSConfig struct {
	// Certificates is a static list of server certificates.
	Certificates   []tls.Certificate
	// GetCertificate, if non-nil, returns a certificate for a given
	// ClientHello. NOTE(review): presumably forwarded to
	// tls.Config.GetCertificate, whose signature it matches — confirm.
	GetCertificate func(info *tls.ClientHelloInfo) (*tls.Certificate, error)
}

// QUICConfig contains the QUIC configuration fields that we want to expose to the user.
// For every field except DisablePathMTUDiscovery, Config.fill replaces a zero
// value with a package default and range-checks any non-zero value.
type QUICConfig struct {
	// Flow-control receive windows, in bytes. Each must be 0 (use the
	// default) or at least 16384. The stream windows default to
	// defaultStreamReceiveWindow and the connection windows to
	// defaultConnReceiveWindow.
	InitialStreamReceiveWindow     uint64
	MaxStreamReceiveWindow         uint64
	InitialConnectionReceiveWindow uint64
	MaxConnectionReceiveWindow     uint64
	// MaxIdleTimeout must be 0 (use defaultMaxIdleTimeout) or between 4s and 120s.
	MaxIdleTimeout                 time.Duration
	// MaxIncomingStreams must be 0 (use defaultMaxIncomingStreams) or at least 8.
	MaxIncomingStreams             int64
	DisablePathMTUDiscovery        bool // The server may still override this to true on unsupported platforms.
}

// Outbound provides the implementation of how the server should connect to remote servers.
// When Config.Outbound is nil, Config.fill installs defaultOutbound, which
// connects directly using the net package.
type Outbound interface {
	// DialTCP opens a TCP connection to reqAddr. NOTE(review): reqAddr is
	// presumably a "host:port" string requested by the client — confirm.
	DialTCP(reqAddr string) (net.Conn, error)
	// ListenUDP opens a new UDP socket and returns it as a UDPConn.
	ListenUDP() (UDPConn, error)
}

// UDPConn is like net.PacketConn, but uses strings for addresses.
type UDPConn interface {
	// ReadFrom reads one datagram into b and returns the number of bytes
	// read together with the sender's address as a string.
	ReadFrom(b []byte) (int, string, error)
	// WriteTo sends b to addr, given as a string, and returns the number of
	// bytes written.
	WriteTo(b []byte, addr string) (int, error)
	// Close closes the connection.
	Close() error
}

// defaultOutboundDialTimeout bounds how long defaultOutbound waits for an
// outbound TCP connection to be established. Without a timeout, a dial relies
// only on the operating system's TCP connect timeout, which is often minutes.
// A request to an unreachable target would hang for that long.
const defaultOutboundDialTimeout = 10 * time.Second

// defaultOutbound is the Outbound that Config.fill installs when the user does
// not provide one. It connects to remote servers directly from this host.
type defaultOutbound struct{}

// DialTCP opens a TCP connection to reqAddr. It fails if the connection is not
// established within defaultOutboundDialTimeout.
func (o *defaultOutbound) DialTCP(reqAddr string) (net.Conn, error) {
	d := net.Dialer{Timeout: defaultOutboundDialTimeout}
	return d.Dial("tcp", reqAddr)
}

// ListenUDP opens an unconnected UDP socket on an automatically chosen local
// address and port, and wraps it as a UDPConn.
func (o *defaultOutbound) ListenUDP() (UDPConn, error) {
	conn, err := net.ListenUDP("udp", nil)
	if err != nil {
		return nil, err
	}
	return &defaultUDPConn{conn}, nil
}

// defaultUDPConn adapts a *net.UDPConn to the UDPConn interface. It converts
// between net.Addr values and their string forms.
type defaultUDPConn struct {
	*net.UDPConn
}

// ReadFrom reads one datagram into b and reports the sender as a string.
// The sender is "" when the underlying read returns no address.
func (c *defaultUDPConn) ReadFrom(b []byte) (int, string, error) {
	n, from, err := c.UDPConn.ReadFrom(b)
	var sender string
	if from != nil {
		sender = from.String()
	}
	return n, sender, err
}

// WriteTo resolves addr as a UDP address on every call and sends b to it.
// A resolution failure is returned with a byte count of 0.
func (c *defaultUDPConn) WriteTo(b []byte, addr string) (int, error) {
	dst, resolveErr := net.ResolveUDPAddr("udp", addr)
	if resolveErr != nil {
		return 0, resolveErr
	}
	return c.UDPConn.WriteTo(b, dst)
}

// BandwidthConfig describes the maximum bandwidth that the server can use, in bytes per second.
// Config.fill accepts 0 for either field and otherwise requires at least 65536.
// NOTE(review): 0 presumably means "no limit" — confirm against the consumer.
type BandwidthConfig struct {
	// MaxTx is presumably the server's transmit (send) limit — confirm.
	MaxTx uint64
	// MaxRx is presumably the server's receive limit — confirm.
	MaxRx uint64
}

// Authenticator is an interface that provides authentication logic.
// Config.fill rejects a Config whose Authenticator is nil.
type Authenticator interface {
	// Authenticate reports whether a client is accepted (ok) and, if so, an
	// id for it. NOTE(review): addr is presumably the client's remote
	// address, auth its credential string, and tx a client-reported
	// bandwidth figure — confirm against the caller.
	Authenticate(addr net.Addr, auth string, tx uint64) (ok bool, id string)
}

// EventLogger is an interface that provides logging logic.
// Config.fill does not require an EventLogger to be set.
// NOTE(review): id is presumably the value returned by
// Authenticator.Authenticate for the client at addr — confirm.
type EventLogger interface {
	// Connect and Disconnect presumably mark the start and end of a client
	// session. err carries the reason for disconnection.
	Connect(addr net.Addr, id string, tx uint64)
	Disconnect(addr net.Addr, id string, err error)
	// TCPRequest and TCPError concern a TCP request to reqAddr.
	TCPRequest(addr net.Addr, id, reqAddr string)
	TCPError(addr net.Addr, id, reqAddr string, err error)
	// UDPRequest and UDPError concern the UDP session identified by sessionID.
	UDPRequest(addr net.Addr, id string, sessionID uint32)
	UDPError(addr net.Addr, id string, sessionID uint32, err error)
}

// TrafficLogger is an interface that provides traffic logging logic.
// Tx/Rx here are from the server-remote (proxy target) perspective:
// Tx is the number of bytes sent from the server to the remote, and
// Rx is the number of bytes received by the server from the remote.
// Besides logging, Log can return false to signal that the client
// should be disconnected. This can be used to implement bandwidth
// limits or post-connection authentication, for example.
// Implementations must be safe for concurrent use.
// Config.fill does not require a TrafficLogger to be set.
type TrafficLogger interface {
	// Log records tx/rx byte counts for the client identified by id and
	// returns false if that client should be disconnected.
	// NOTE(review): counts are presumably per-call increments rather than
	// running totals — confirm against the caller.
	Log(id string, tx, rx uint64) (ok bool)
}