package server

import (
	"context"
	"crypto/tls"
	"errors"
	"io"
	"math/rand"
	"net/http"
	"sync"

	"github.com/apernet/hysteria/core/internal/congestion"
	"github.com/apernet/hysteria/core/internal/frag"
	"github.com/apernet/hysteria/core/internal/protocol"
	"github.com/apernet/hysteria/core/internal/utils"

	"github.com/quic-go/quic-go"
	"github.com/quic-go/quic-go/http3"
)

// Server is the handle to a Hysteria server created by NewServer.
type Server interface {
	// Serve runs the accept loop, serving each client connection on its own
	// goroutine. In this file's implementation it blocks until accepting
	// fails and then returns that error.
	Serve() error
	// Close stops the QUIC listener and closes the packet connection the
	// server was created on (config.Conn in this file's implementation).
	Close() error
}

// NewServer prepares config (config.fill may reject it and its error is
// returned as-is), then starts a QUIC listener on config.Conn. TLS settings are
// adapted for HTTP/3 via http3.ConfigureTLSConfig, and QUIC datagrams are
// enabled because proxied UDP traffic travels as QUIC messages.
//
// The returned Server takes ownership of config.Conn (its Close closes it). If
// the listener cannot be created, config.Conn is closed before the error is
// returned.
func NewServer(config *Config) (Server, error) {
	if err := config.fill(); err != nil {
		return nil, err
	}
	quicConf := &quic.Config{
		InitialStreamReceiveWindow:     config.QUICConfig.InitialStreamReceiveWindow,
		MaxStreamReceiveWindow:         config.QUICConfig.MaxStreamReceiveWindow,
		InitialConnectionReceiveWindow: config.QUICConfig.InitialConnectionReceiveWindow,
		MaxConnectionReceiveWindow:     config.QUICConfig.MaxConnectionReceiveWindow,
		MaxIdleTimeout:                 config.QUICConfig.MaxIdleTimeout,
		MaxIncomingStreams:             config.QUICConfig.MaxIncomingStreams,
		DisablePathMTUDiscovery:        config.QUICConfig.DisablePathMTUDiscovery,
		EnableDatagrams:                true,
	}
	tlsConf := http3.ConfigureTLSConfig(&tls.Config{
		Certificates:   config.TLSConfig.Certificates,
		GetCertificate: config.TLSConfig.GetCertificate,
	})

	ln, err := quic.Listen(config.Conn, tlsConf, quicConf)
	if err != nil {
		// We own the conn from here on; don't leak it on failure.
		_ = config.Conn.Close()
		return nil, err
	}
	srv := &serverImpl{config: config, listener: ln}
	return srv, nil
}

// serverImpl implements Server on top of a single quic.Listener.
type serverImpl struct {
	// config is the filled configuration shared with every connection
	// handler; Close also closes config.Conn.
	// listener accepts the client QUIC connections served by handleClient.
	config   *Config
	listener *quic.Listener
}

// Serve accepts QUIC connections forever, handing each one to its own
// handleClient goroutine. It returns only when Accept fails, passing that
// error back to the caller.
func (s *serverImpl) Serve() error {
	ctx := context.Background()
	for {
		qc, acceptErr := s.listener.Accept(ctx)
		if acceptErr != nil {
			return acceptErr
		}
		go s.handleClient(qc)
	}
}

// Close shuts down the QUIC listener and afterwards the packet connection it
// was created on. Only the listener's error is reported; closing config.Conn
// is best-effort and its error is discarded.
func (s *serverImpl) Close() error {
	defer func() { _ = s.config.Conn.Close() }()
	return s.listener.Close()
}

// handleClient serves HTTP/3 (with datagrams and proxy-stream hijacking) on a
// single accepted QUIC connection until ServeQUICConn returns. It then logs a
// Disconnect event for clients that had authenticated and closes the
// connection with application error code 0.
func (s *serverImpl) handleClient(conn quic.Connection) {
	h := newH3sHandler(s.config, conn)
	h3srv := http3.Server{
		EnableDatagrams: true,
		Handler:         h,
		StreamHijacker:  h.ProxyStreamHijacker,
	}
	serveErr := h3srv.ServeQUICConn(conn)

	// Only authenticated clients produced a Connect event, so only they get
	// a matching Disconnect event.
	if logger := s.config.EventLogger; h.authenticated && logger != nil {
		logger.Disconnect(conn.RemoteAddr(), h.authID, serveErr)
	}
	_ = conn.CloseWithError(0, "")
}

// h3sHandler is the per-connection HTTP/3 handler created by handleClient.
// It answers authentication requests, hijacks TCP/UDP proxy streams once the
// client has authenticated, runs the UDP datagram loop, and hands everything
// else to the masquerade handler.
type h3sHandler struct {
	// Shared server configuration and the QUIC connection this handler serves.
	config *Config
	conn   quic.Connection

	// Set by ServeHTTP after a successful authentication; authID is the ID
	// returned by the Authenticator and is passed to event-logger calls.
	// NOTE(review): both are written in ServeHTTP and read by
	// ProxyStreamHijacker and handleClient on other goroutines without any
	// synchronization; consider guarding them with a mutex.
	authenticated bool
	authID        string

	// udpOnce guarantees udpLoop is started at most once per connection;
	// udpSM holds this connection's UDP sessions.
	udpOnce sync.Once
	udpSM   udpSessionManager
}

// newH3sHandler builds the handler for one QUIC connection, with an empty UDP
// session table whose sessions are opened through config.Outbound.ListenUDP.
func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler {
	h := &h3sHandler{
		config: config,
		conn:   conn,
	}
	h.udpSM.listenFunc = config.Outbound.ListenUDP
	h.udpSM.m = make(map[uint32]*udpSessionEntry)
	return h
}

// udpSessionEntry is the state udpSessionManager keeps for one UDP session.
type udpSessionEntry struct {
	// Conn is the session's outbound UDP socket; D reassembles fragmented
	// client messages before they are written to Conn; Closed is set (under
	// the manager's mutex) once the session's close function has run.
	Conn   UDPConn
	D      *frag.Defragger
	Closed bool
}

// udpSessionManager tracks the UDP sessions of one client connection, keyed
// by session ID.
type udpSessionManager struct {
	// listenFunc opens the outbound UDP socket for each new session.
	// mutex guards m, nextID and each entry's Closed flag.
	// m maps session IDs to their entries.
	// nextID is the ID handed to the next session; being a uint32 counter it
	// wraps around after 2^32 sessions.
	// NOTE(review): after wrap-around a new session could reuse the ID of one
	// that is still open and overwrite its map entry.
	listenFunc func() (UDPConn, error)
	mutex      sync.RWMutex
	m          map[uint32]*udpSessionEntry
	nextID     uint32
}

// Add opens a new outbound UDP socket via listenFunc and registers it as a
// session under the next sequential ID. It returns the session ID, the socket,
// and a close function that closes the socket and removes the session. The
// close function takes the manager's lock and may be called any number of
// times from any goroutine; calls after the first are no-ops. If listenFunc
// fails, nothing is registered and its error is returned.
func (m *udpSessionManager) Add() (uint32, UDPConn, func(), error) {
	conn, err := m.listenFunc()
	if err != nil {
		return 0, nil, nil, err
	}
	entry := &udpSessionEntry{Conn: conn, D: &frag.Defragger{}}

	// Only the ID allocation and map insertion need the lock.
	m.mutex.Lock()
	id := m.nextID
	m.nextID++
	m.m[id] = entry
	m.mutex.Unlock()

	closeSession := func() {
		m.mutex.Lock()
		defer m.mutex.Unlock()
		if entry.Closed {
			// Already closed
			return
		}
		entry.Closed = true
		_ = conn.Close()
		delete(m.m, id)
	}
	return id, conn, closeSession, nil
}

// Feed routes one client UDP message to its session. The message goes through
// the session's defragger; once a complete message is available it is written
// to the session's outbound socket, addressed to the message's Addr. Messages
// for unknown session IDs are dropped and write errors are ignored.
//
// The read lock is held for the whole call, including the write, so the
// session's close function (which takes the write lock) cannot close the
// socket while a write to it is in progress. Within this file Feed is only
// called from udpLoop, so a session's defragger is never fed concurrently.
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	if entry, ok := m.m[msg.SessionID]; ok {
		if full := entry.D.Feed(msg); full != nil {
			_, _ = entry.Conn.WriteTo(full.Data, full.Addr)
		}
	}
}

// ServeHTTP handles every regular (non-hijacked) HTTP/3 request on the
// connection.
//
// A POST to protocol.URLHost + protocol.URLPath is an authentication request.
// The negotiated send rate is the client's advertised receive rate, capped at
// the server's MaxTx when MaxTx is non-zero. On success the handler:
//   - marks the connection as authenticated;
//   - installs the Brutal congestion controller when the rate is non-zero;
//   - replies with protocol.StatusAuthOK, advertising the server's MaxRx;
//   - logs a Connect event;
//   - starts the UDP receive loop unless UDP is disabled.
//
// A repeated auth request on an already-authenticated connection is answered
// with StatusAuthOK without re-checking. Everything else, including a failed
// authentication, goes to the masquerade handler so the server looks like an
// ordinary web server.
//
// NOTE(review): authenticated/authID are written here without a lock even
// though ServeHTTP may run on several goroutines at once, and the fields are
// read by ProxyStreamHijacker and handleClient. Two concurrent auth requests
// can both pass the "already authenticated" check.
func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	isAuthRequest := r.Method == http.MethodPost &&
		r.Host == protocol.URLHost &&
		r.URL.Path == protocol.URLPath
	if !isAuthRequest {
		// Not an auth request, pretend to be a normal HTTP server
		h.masqHandler(w, r)
		return
	}

	replyAuthOK := func() {
		protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx)
		w.WriteHeader(protocol.StatusAuthOK)
	}

	if h.authenticated {
		// Already authenticated
		replyAuthOK()
		return
	}

	auth, clientRx := protocol.AuthRequestDataFromHeader(r.Header)
	// actualTx = min(serverTx, clientRx); a zero MaxTx means "no server cap".
	actualTx := clientRx
	if serverTx := h.config.BandwidthConfig.MaxTx; serverTx > 0 && actualTx > serverTx {
		actualTx = serverTx
	}

	ok, id := h.config.Authenticator.Authenticate(h.conn.RemoteAddr(), auth, actualTx)
	if !ok {
		// Auth failed, pretend to be a normal HTTP server
		h.masqHandler(w, r)
		return
	}

	h.authenticated = true
	h.authID = id
	if actualTx > 0 {
		h.conn.SetCongestionControl(congestion.NewBrutalSender(actualTx))
	}
	replyAuthOK()
	if h.config.EventLogger != nil {
		h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx)
	}
	// ServeHTTP may be called by multiple goroutines simultaneously, so
	// sync.Once keeps the UDP receive loop to a single goroutine.
	if !h.config.DisableUDP {
		h.udpOnce.Do(func() { go h.udpLoop() })
	}
}

// ProxyStreamHijacker is the http3.Server StreamHijacker hook. For an
// authenticated connection it takes over streams that start with a TCP or UDP
// proxy request frame, wraps them in utils.QStream and serves each one on its
// own goroutine. In every other case (hook error, unauthenticated client,
// other frame type) it returns false, leaving the stream to regular HTTP/3
// processing.
func (h *h3sHandler) ProxyStreamHijacker(ft http3.FrameType, conn quic.Connection, stream quic.Stream, err error) (bool, error) {
	if err != nil || !h.authenticated {
		return false, nil
	}

	var serve func(quic.Stream)
	switch ft {
	case protocol.FrameTypeTCPRequest:
		serve = h.handleTCPRequest
	case protocol.FrameTypeUDPRequest:
		serve = h.handleUDPRequest
	default:
		return false, nil
	}

	// QStream wraps the stream so that Close() is handled properly.
	go serve(&utils.QStream{Stream: stream})
	return true, nil
}

// handleTCPRequest serves one TCP proxy stream. It reads the target address,
// logs the request, dials the target through config.Outbound and reports the
// outcome to the client.
//
// On a dial failure the error message is sent back and the dial error is
// logged via EventLogger.TCPError. Otherwise data is relayed in both
// directions until either direction stops. Both ends are then closed, and the
// first copy's result (nil when io.Copy ended on EOF) is logged via
// EventLogger.TCPError.
func (h *h3sHandler) handleTCPRequest(stream quic.Stream) {
	reqAddr, err := protocol.ReadTCPRequest(stream)
	if err != nil {
		_ = stream.Close()
		return
	}
	if h.config.EventLogger != nil {
		h.config.EventLogger.TCPRequest(h.conn.RemoteAddr(), h.authID, reqAddr)
	}
	logTCPError := func(e error) {
		if h.config.EventLogger != nil {
			h.config.EventLogger.TCPError(h.conn.RemoteAddr(), h.authID, reqAddr, e)
		}
	}

	tConn, err := h.config.Outbound.DialTCP(reqAddr)
	if err != nil {
		_ = protocol.WriteTCPResponse(stream, false, err.Error())
		_ = stream.Close()
		logTCPError(err)
		return
	}
	_ = protocol.WriteTCPResponse(stream, true, "")

	// Room for both results so the copier that finishes second never blocks.
	done := make(chan error, 2)
	relay := func(dst io.Writer, src io.Reader) {
		_, copyErr := io.Copy(dst, src)
		done <- copyErr
	}
	go relay(tConn, stream)
	go relay(stream, tConn)

	// The first direction to stop ends the whole relay; closing both ends
	// below is what terminates the other copier.
	logTCPError(<-done)
	_ = tConn.Close()
	_ = stream.Close()
}

// handleUDPRequest serves one UDP proxy stream. The stream only carries the
// session setup (request and response) and then represents the session's
// lifetime; the datagrams themselves travel as QUIC messages. The
// client -> remote direction is handled by udpLoop / udpSessionManager.Feed.
// This function starts the remote -> client relay and then drains the stream
// until the client closes it, at which point the session is torn down.
func (h *h3sHandler) handleUDPRequest(stream quic.Stream) {
	if h.config.DisableUDP {
		// UDP is disabled, send error message and close the stream
		_ = protocol.WriteUDPResponse(stream, false, 0, "UDP is disabled on this server")
		_ = stream.Close()
		return
	}
	// Read request
	if err := protocol.ReadUDPRequest(stream); err != nil {
		_ = stream.Close()
		return
	}
	// Add to session manager
	sessionID, conn, connCloseFunc, err := h.udpSM.Add()
	if err != nil {
		_ = protocol.WriteUDPResponse(stream, false, 0, err.Error())
		_ = stream.Close()
		return
	}
	// Send response
	_ = protocol.WriteUDPResponse(stream, true, sessionID, "")
	// Call event logger
	if h.config.EventLogger != nil {
		h.config.EventLogger.UDPRequest(h.conn.RemoteAddr(), h.authID, sessionID)
	}

	// client <- remote direction. When reading from the outbound socket fails
	// (including because connCloseFunc below closed it), tear the session down.
	go func() {
		h.relayUDPToClient(sessionID, conn)
		connCloseFunc()
		_ = stream.Close()
	}()

	// Hold (drain) the stream until the client closes it.
	// Closing the stream is the signal to stop the UDP session.
	_, err = io.Copy(io.Discard, stream)
	// Call event logger
	if h.config.EventLogger != nil {
		h.config.EventLogger.UDPError(h.conn.RemoteAddr(), h.authID, sessionID, err)
	}

	// Cleanup. connCloseFunc is idempotent, so racing with the relay
	// goroutine's own cleanup is fine.
	connCloseFunc()
	_ = stream.Close()
}

// relayUDPToClient reads datagrams from the session's outbound socket and
// forwards each one to the client as a QUIC message, until ReadFrom returns an
// error. The error is checked after any returned data has been forwarded, so
// a read that yields both data and an error still ends the loop. (Previously a
// dropped oversized datagram skipped that check.)
func (h *h3sHandler) relayUDPToClient(sessionID uint32, conn UDPConn) {
	udpBuf := make([]byte, protocol.MaxUDPSize)
	msgBuf := make([]byte, protocol.MaxUDPSize)
	for {
		udpN, rAddr, err := conn.ReadFrom(udpBuf)
		if udpN > 0 {
			// Try no frag first
			h.sendUDPMessage(msgBuf, protocol.UDPMessage{
				SessionID: sessionID,
				PacketID:  0,
				FragID:    0,
				FragCount: 1,
				Addr:      rAddr,
				Data:      udpBuf[:udpN],
			})
		}
		if err != nil {
			return
		}
	}
}

// sendUDPMessage serializes msg into buf and sends it to the client as a
// single QUIC message. Messages that do not fit in buf are dropped. If the
// connection reports the message as too large for one datagram
// (quic.ErrMessageTooLarge), it is split with frag.FragUDPMessage under a
// random non-zero packet ID and the fragments are sent individually. Send
// errors are otherwise ignored (best-effort delivery).
func (h *h3sHandler) sendUDPMessage(buf []byte, msg protocol.UDPMessage) {
	n := msg.Serialize(buf)
	if n < 0 {
		// Message even larger than MaxUDPSize, drop it
		return
	}
	sendErr := h.conn.SendMessage(buf[:n])
	var errTooLarge quic.ErrMessageTooLarge
	if !errors.As(sendErr, &errTooLarge) {
		return
	}
	// Message too large, try fragmentation. The packet ID is in 1..0xFFFF;
	// presumably 0 marks an unfragmented message (as in relayUDPToClient) —
	// confirm against the frag package.
	msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
	for _, fMsg := range frag.FragUDPMessage(msg, int(errTooLarge)) {
		n = fMsg.Serialize(buf)
		if n < 0 {
			// Slicing with a negative length would panic inside this
			// goroutine and crash the whole server; drop the fragment.
			continue
		}
		_ = h.conn.SendMessage(buf[:n])
	}
}

// udpLoop receives QUIC datagrams from the client and dispatches each one to
// handleUDPMessage, stopping at the first ReceiveMessage error. ServeHTTP
// starts it at most once per connection (guarded by udpOnce).
func (h *h3sHandler) udpLoop() {
	for msg, err := h.conn.ReceiveMessage(); err == nil; msg, err = h.conn.ReceiveMessage() {
		h.handleUDPMessage(msg)
	}
}

// handleUDPMessage handles the client -> remote direction for one datagram:
// it parses the datagram as a protocol.UDPMessage and hands it to the session
// manager. Datagrams that fail to parse are silently dropped.
func (h *h3sHandler) handleUDPMessage(msg []byte) {
	if udpMsg, err := protocol.ParseUDPMessage(msg); err == nil {
		h.udpSM.Feed(udpMsg)
	}
}

// masqHandler serves requests that are not (successful) authentication
// requests. It delegates to config.MasqHandler when one is configured;
// otherwise every request gets a plain 404 Not Found.
func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {
	if h.config.MasqHandler == nil {
		// Return 404 for everything
		http.NotFound(w, r)
		return
	}
	h.config.MasqHandler.ServeHTTP(w, r)
}