package forwarding

import (
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/apernet/hysteria/core/client"
)

const (
	// udpBufferSize is the length of the single read buffer Serve allocates
	// and reuses for every datagram read from the local listener.
	udpBufferSize = 1 << 12

	// defaultTimeout is the per-session idle timeout used when
	// UDPTunnel.Timeout is zero or negative.
	defaultTimeout = 300 * time.Second
)

// UDPTunnel forwards UDP datagrams received on a local net.PacketConn to a
// fixed remote address through a Hysteria client, keeping one Hysteria UDP
// session per local source address. Replies from the remote are written back
// to the originating source address.
type UDPTunnel struct {
	// HyClient opens a new Hysteria UDP connection (ListenUDP) for each new
	// local source address.
	HyClient    client.Client
	// Remote is the destination passed to HyUDPConn.Send for every forwarded
	// datagram (presumably "host:port" — confirm against client.HyUDPConn).
	Remote      string
	// Timeout is how long a session may stay idle before it is closed.
	// Serve substitutes defaultTimeout when this is zero or negative.
	Timeout     time.Duration
	// EventLogger is an optional event sink; handle skips its calls when it is nil.
	EventLogger UDPEventLogger
}

// UDPEventLogger receives per-session events from UDPTunnel.
//
// Methods may be invoked from different goroutines: Connect is called from the
// Serve loop, while idle-timeout notifications come from a session's cleanup
// goroutine. Implementations should therefore be safe for concurrent use.
type UDPEventLogger interface {
	// Connect is called when the first datagram from addr arrives, before a
	// Hysteria UDP connection is opened for it.
	Connect(addr net.Addr)
	// Error is called with a non-nil err when opening the session for addr
	// fails. It is also called with err == nil when addr's session is closed
	// after being idle for longer than the tunnel's timeout.
	Error(addr net.Addr, err error)
}

// sessionEntry is the state of one forwarding session in sessionManager.
type sessionEntry struct {
	// HyConn is the Hysteria UDP connection carrying this session's traffic.
	HyConn   client.HyUDPConn
	// Deadline holds a time.Time: the instant after which the session is
	// considered idle. It is an atomic.Value so Renew can update it while
	// holding only the manager's read lock.
	Deadline atomic.Value
}

// sessionManager tracks one sessionEntry per client address and closes
// sessions that have not been renewed within Timeout.
type sessionManager struct {
	// SessionMap maps net.Addr.String() of the client to its session.
	// Guarded by Mutex.
	SessionMap  map[string]*sessionEntry
	// Timeout is the idle duration added to time.Now() when a session is
	// created or renewed.
	Timeout     time.Duration
	// TimeoutFunc, if non-nil, is called with the client address after an
	// expired session has been removed from SessionMap and its connection closed.
	TimeoutFunc func(addr net.Addr)
	// Mutex guards SessionMap. Readers that only touch an entry's atomic
	// Deadline take the read lock.
	Mutex       sync.RWMutex
}

// New registers hyConn as the session for addr and starts a goroutine that
// expires the session once its deadline (now + sm.Timeout, pushed forward by
// Renew) has passed.
//
// On expiry the goroutine removes the entry from SessionMap (only if it is
// still the entry registered for addr), closes hyConn, and then invokes
// TimeoutFunc with addr when TimeoutFunc is non-nil.
func (sm *sessionManager) New(addr net.Addr, hyConn client.HyUDPConn) {
	key := addr.String()
	entry := &sessionEntry{HyConn: hyConn}
	entry.Deadline.Store(time.Now().Add(sm.Timeout))

	// Publish the entry before the cleanup goroutine exists. Otherwise, with a
	// non-positive Timeout, the goroutine could run its delete before this
	// insert and leave an already-closed connection in the map forever.
	sm.Mutex.Lock()
	sm.SessionMap[key] = entry
	sm.Mutex.Unlock()

	// Timeout cleanup routine.
	go func() {
		for {
			ttl := time.Until(entry.Deadline.Load().(time.Time))
			if ttl > 0 {
				// Renew may have extended the deadline while we slept; re-check.
				time.Sleep(ttl)
				continue
			}

			// Inactive for too long: unregister and close the session.
			sm.Mutex.Lock()
			// Only evict our own entry; a newer session stored under the same
			// key must not be removed by this (older) goroutine.
			if sm.SessionMap[key] == entry {
				delete(sm.SessionMap, key)
			}
			sm.Mutex.Unlock()

			// Best-effort: there is no caller to report a close failure to.
			_ = hyConn.Close()
			if sm.TimeoutFunc != nil {
				sm.TimeoutFunc(addr)
			}
			return
		}
	}()
}

// Get looks up the session registered for addr and returns its Hysteria UDP
// connection, or nil when addr has no session.
func (sm *sessionManager) Get(addr net.Addr) client.HyUDPConn {
	sm.Mutex.RLock()
	defer sm.Mutex.RUnlock()

	entry, found := sm.SessionMap[addr.String()]
	if !found {
		return nil
	}
	return entry.HyConn
}

// Renew moves the idle deadline of addr's session to now + sm.Timeout.
// It does nothing when addr has no session.
func (sm *sessionManager) Renew(addr net.Addr) {
	// A read lock suffices: the map itself is not modified, and the deadline
	// is written through an atomic.Value.
	sm.Mutex.RLock()
	defer sm.Mutex.RUnlock()

	entry, found := sm.SessionMap[addr.String()]
	if !found {
		return
	}
	entry.Deadline.Store(time.Now().Add(sm.Timeout))
}

// Serve reads datagrams from listener and forwards each one to t.Remote over
// a Hysteria UDP session dedicated to the datagram's source address. Sessions
// are created on the first datagram from an address and closed after
// t.Timeout of inactivity (defaultTimeout when t.Timeout <= 0).
//
// Serve blocks until listener.ReadFrom fails and returns that error unchanged.
func (t *UDPTunnel) Serve(listener net.PacketConn) error {
	sm := &sessionManager{
		SessionMap: make(map[string]*sessionEntry),
		Timeout:    t.Timeout,
	}
	if sm.Timeout <= 0 {
		sm.Timeout = defaultTimeout
	}
	// Idle expiry is reported as Error with a nil error. EventLogger is
	// optional (handle checks it before every use), so it must be checked here
	// too: calling a method on a nil interface would panic in the session's
	// cleanup goroutine and take down the whole process.
	sm.TimeoutFunc = func(addr net.Addr) {
		if t.EventLogger != nil {
			t.EventLogger.Error(addr, nil)
		}
	}

	// buf is reused for every datagram; handle forwards it synchronously.
	// NOTE(review): assumes HyUDPConn.Send does not retain data after it
	// returns — confirm against the client implementation.
	buf := make([]byte, udpBufferSize)
	for {
		n, addr, err := listener.ReadFrom(buf)
		if err != nil {
			return err
		}
		t.handle(listener, sm, addr, buf[:n])
	}
}

// handle forwards one datagram that arrived on l from addr.
//
// If addr already has a session, data is sent to t.Remote over it and the
// session's idle deadline is renewed. Otherwise handle:
//   - reports Connect to the EventLogger, if one is set;
//   - opens a new Hysteria UDP connection and registers it with sm;
//   - sends the datagram over it;
//   - starts a goroutine that copies replies from the remote back to addr via
//     l, renewing the session after each successful write.
//
// The reply goroutine exits on the first Receive or WriteTo error (presumably
// including when the idle-timeout routine closes the connection).
// Send errors are deliberately dropped: forwarding over UDP is best-effort.
func (t *UDPTunnel) handle(l net.PacketConn, sm *sessionManager, addr net.Addr, data []byte) {
	if existing := sm.Get(addr); existing != nil {
		// Existing session.
		_ = existing.Send(data, t.Remote)
		sm.Renew(addr)
		return
	}

	// New session.
	if t.EventLogger != nil {
		t.EventLogger.Connect(addr)
	}
	conn, err := t.HyClient.ListenUDP()
	if err != nil {
		if t.EventLogger != nil {
			t.EventLogger.Error(addr, err)
		}
		return
	}
	sm.New(addr, conn)
	_ = conn.Send(data, t.Remote)

	// Local <- Remote relay for this session.
	go func() {
		for {
			reply, _, recvErr := conn.Receive()
			if recvErr != nil {
				return
			}
			if _, writeErr := l.WriteTo(reply, addr); writeErr != nil {
				return
			}
			sm.Renew(addr)
		}
	}()
}