package integration_tests

import (
	"io"
	"net"
	"sync/atomic"
	"testing"

	"github.com/apernet/hysteria/core/client"
	"github.com/apernet/hysteria/core/server"
)

// testTrafficLogger is the TrafficLogger handed to the server under test.
// It keeps running totals of the byte counts the server reports through Log
// and, once blocking is switched on with SetBlock, makes Log return false so
// the server drops the client. Tx and Rx are only accessed through sync/atomic
// and Block is a typed atomic, so all methods may be called concurrently.
type testTrafficLogger struct {
	Tx, Rx uint64
	Block  atomic.Bool
}

// Log adds tx and rx to the running totals (id is ignored) and reports
// whether the client may continue: true normally, false while blocking is on.
func (tl *testTrafficLogger) Log(id string, tx, rx uint64) bool {
	atomic.AddUint64(&tl.Tx, tx)
	atomic.AddUint64(&tl.Rx, rx)
	permitted := !tl.Block.Load()
	return permitted
}

// Get returns the totals accumulated since creation or the last Reset.
func (tl *testTrafficLogger) Get() (tx, rx uint64) {
	tx = atomic.LoadUint64(&tl.Tx)
	rx = atomic.LoadUint64(&tl.Rx)
	return tx, rx
}

// SetBlock turns blocking on or off for subsequent Log calls.
func (tl *testTrafficLogger) SetBlock(block bool) {
	tl.Block.Store(block)
}

// Reset zeroes both byte counters. The blocking flag is left untouched.
func (tl *testTrafficLogger) Reset() {
	for _, counter := range [...]*uint64{&tl.Tx, &tl.Rx} {
		atomic.StoreUint64(counter, 0)
	}
}

// TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly.
// More specifically, it tests that the server correctly logs traffic in both directions
// (for both TCP and UDP), and that it disconnects clients when the traffic logger returns false.
//
// Every listener binds to 127.0.0.1 with port 0 so the OS assigns a free port; hard-coded
// ports make the test fail spuriously whenever something else on the host already holds them.
func TestServerTrafficLogger(t *testing.T) {
	tl := &testTrafficLogger{}

	// Create server
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		t.Fatal("error creating server:", err)
	}
	// Read back the port the OS actually picked.
	udpAddr := udpConn.LocalAddr().(*net.UDPAddr)
	s, err := server.NewServer(&server.Config{
		TLSConfig: serverTLSConfig(),
		Conn:      udpConn,
		Authenticator: &pwAuthenticator{
			Password: "password",
			ID:       "nobody",
		},
		TrafficLogger: tl,
	})
	if err != nil {
		t.Fatal("error creating server:", err)
	}
	defer s.Close()
	go s.Serve()

	// Create TCP double echo server.
	// Double echo makes the returned byte count differ from the sent one, so a
	// logger that mixes up the two directions is caught.
	echoListener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		t.Fatal("error creating TCP echo server:", err)
	}
	echoTCPAddr := echoListener.Addr().(*net.TCPAddr)
	tEchoServer := &tcpDoubleEchoServer{Listener: echoListener}
	defer tEchoServer.Close()
	go tEchoServer.Serve()

	// Create client
	c, err := client.NewClient(&client.Config{
		ServerAddr: udpAddr,
		Auth:       "password",
		TLSConfig:  client.TLSConfig{InsecureSkipVerify: true},
	})
	if err != nil {
		t.Fatal("error creating client:", err)
	}
	defer c.Close()

	// Dial TCP
	tConn, err := c.DialTCP(echoTCPAddr.String())
	if err != nil {
		t.Fatal("error dialing TCP:", err)
	}
	defer tConn.Close()

	// Send and receive TCP data
	sData := []byte("1234")
	_, err = tConn.Write(sData)
	if err != nil {
		t.Fatal("error writing to TCP:", err)
	}
	rData := make([]byte, len(sData)*2)
	_, err = io.ReadFull(tConn, rData)
	if err != nil {
		t.Fatal("error reading from TCP:", err)
	}
	expected := string(sData) + string(sData)
	if string(rData) != expected {
		t.Fatalf("expected %q, got %q", expected, string(rData))
	}

	// Check traffic logger: Tx is what the client sent, Rx is the doubled echo.
	tx, rx := tl.Get()
	if tx != uint64(len(sData)) || rx != uint64(len(rData)) {
		t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(rData), tx, rx)
	}
	tl.Reset()

	// Create UDP double echo server
	echoConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
	if err != nil {
		t.Fatal("error creating UDP echo server:", err)
	}
	echoUDPAddr := echoConn.LocalAddr().(*net.UDPAddr)
	uEchoServer := &udpDoubleEchoServer{Conn: echoConn}
	defer uEchoServer.Close()
	go uEchoServer.Serve()

	// Listen UDP
	uConn, err := c.ListenUDP()
	if err != nil {
		t.Fatal("error listening UDP:", err)
	}
	defer uConn.Close()

	// Send UDP data; the double echo server answers with two copies of the datagram.
	sData = []byte("gucci gang")
	err = uConn.Send(sData, echoUDPAddr.String())
	if err != nil {
		t.Fatal("error sending UDP:", err)
	}
	for i := 0; i < 2; i++ {
		// Distinct names so the TCP rData above is not shadowed.
		pkt, from, err := uConn.Receive()
		if err != nil {
			t.Fatal("error receiving UDP:", err)
		}
		if string(pkt) != string(sData) {
			t.Fatalf("expected %q, got %q", string(sData), string(pkt))
		}
		if from != echoUDPAddr.String() {
			t.Fatalf("expected %q, got %q", echoUDPAddr.String(), from)
		}
	}

	// Check traffic logger: one datagram out, two datagrams back.
	tx, rx = tl.Get()
	if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) {
		t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx)
	}

	// Check the disconnect client functionality: from now on the logger returns false.
	tl.SetBlock(true)

	// Send TCP data again. With the logger rejecting traffic, the server is
	// expected to drop the connection, so nothing should be echoed back.
	sData = []byte("1234")
	_, err = tConn.Write(sData)
	if err != nil {
		t.Fatal("error writing to TCP:", err)
	}
	// The read side should see the connection close without any data.
	// io.Copy treats EOF as success, so a clean close yields a nil error.
	n, err := io.Copy(io.Discard, tConn)
	if n != 0 || err != nil {
		t.Fatal("expected 0 bytes read and nil error, got", n, err)
	}
}