package tproxy

import (
	"errors"
	"github.com/LiamHaworth/go-tproxy"
	"github.com/tobyxdd/hysteria/pkg/acl"
	"github.com/tobyxdd/hysteria/pkg/core"
	"github.com/tobyxdd/hysteria/pkg/transport"
	"github.com/tobyxdd/hysteria/pkg/utils"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"time"
)

// udpBufferSize is the size of every datagram read buffer: 65535 bytes, the
// largest length a 16-bit UDP length field can express.
const udpBufferSize = 1<<16 - 1

// ErrTimeout is reported through ErrorFunc when a UDP session is closed
// because it saw no traffic for the configured timeout.
var ErrTimeout = errors.New("inactivity timeout")

// UDPTProxy is a transparent (TPROXY) UDP proxy. Datagrams intercepted on
// ListenAddr are forwarded through HyClient or, when ACLEngine is set, routed
// according to the ACL: proxied, sent directly, hijacked to another host, or
// dropped.
type UDPTProxy struct {
	// HyClient dials the per-session proxied UDP connections.
	HyClient   *core.Client
	// Transport resolves addresses and opens local sockets for traffic that
	// bypasses HyClient (direct and hijacked packets).
	Transport  transport.Transport
	// ListenAddr is the address the TPROXY socket is bound to.
	ListenAddr *net.UDPAddr
	// Timeout is the idle period after which a session is torn down; traffic
	// in either direction pushes it forward. NewUDPTProxy substitutes one
	// minute for a zero value.
	Timeout    time.Duration
	// ACLEngine is optional; when nil, every packet goes through HyClient.
	ACLEngine  *acl.Engine

	// ConnFunc is called with the source address when a new session starts.
	ConnFunc  func(addr net.Addr)
	// ErrorFunc is called with the source address when a session fails to
	// start, or with ErrTimeout when it expires.
	ErrorFunc func(addr net.Addr, err error)
}

// NewUDPTProxy creates a UDPTProxy that will listen on listen, resolved with
// trans. A non-positive timeout selects the default idle timeout of one
// minute. aclEngine may be nil, in which case all traffic is sent through
// hyClient. connFunc and errorFunc become the proxy's ConnFunc and ErrorFunc.
// The only error returned is one from resolving listen.
func NewUDPTProxy(hyClient *core.Client, trans transport.Transport, listen string, timeout time.Duration,
	aclEngine *acl.Engine,
	connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPTProxy, error) {
	uAddr, err := trans.LocalResolveUDPAddr(listen)
	if err != nil {
		return nil, err
	}
	// A negative timeout would put every session's deadline in the past and
	// expire it right after its first packet, so treat it like zero.
	if timeout <= 0 {
		timeout = time.Minute
	}
	return &UDPTProxy{
		HyClient:   hyClient,
		Transport:  trans,
		ListenAddr: uAddr,
		Timeout:    timeout,
		ACLEngine:  aclEngine,
		ConnFunc:   connFunc,
		ErrorFunc:  errorFunc,
	}, nil
}

// connEntry is the state of one UDP session. ListenAndServe keys these by
// client source address.
type connEntry struct {
	// HyConn carries proxied packets for this session.
	HyConn    core.UDPConn
	// LocalConn sends direct/hijacked packets and receives their replies; it
	// is nil when no ACL engine is configured.
	LocalConn *net.UDPConn
	// Deadline holds a time.Time after which the session is considered idle.
	// It is written by the forwarding path and both reply relays and read by
	// the expiry goroutine, hence atomic.Value.
	Deadline  atomic.Value
}

// sendPacket forwards one client datagram addressed to dstAddr on behalf of
// the session in entry. With an ACL engine (and a local socket), the ACL
// decides the route. The packet can be proxied through entry.HyConn, sent
// directly via entry.LocalConn to the resolved IP, or sent via entry.LocalConn
// to a hijack host on the same port. Otherwise the packet is dropped.
// Without an ACL engine everything is proxied. Every call refreshes the
// session's idle deadline, even when the packet is dropped.
func (r *UDPTProxy) sendPacket(entry *connEntry, dstAddr *net.UDPAddr, data []byte) error {
	entry.Deadline.Store(time.Now().Add(r.Timeout))

	target := dstAddr.String()
	host, port, err := utils.SplitHostPort(target)
	if err != nil {
		return err
	}

	// Nothing to route around the proxy with: send everything through it.
	if r.ACLEngine == nil || entry.LocalConn == nil {
		return entry.HyConn.WriteTo(data, target)
	}

	// A resolution failure only matters for the route that needs the
	// resolved IP (direct); other routes ignore it.
	action, arg, ipAddr, resErr := r.ACLEngine.ResolveAndMatch(host)
	switch action {
	case acl.ActionProxy:
		return entry.HyConn.WriteTo(data, target)
	case acl.ActionDirect:
		if resErr != nil {
			return resErr
		}
		direct := &net.UDPAddr{IP: ipAddr.IP, Port: int(port), Zone: ipAddr.Zone}
		_, err = entry.LocalConn.WriteToUDP(data, direct)
		return err
	case acl.ActionHijack:
		hijackTo, err := r.Transport.LocalResolveUDPAddr(net.JoinHostPort(arg, strconv.Itoa(int(port))))
		if err != nil {
			return err
		}
		_, err = entry.LocalConn.WriteToUDP(data, hijackTo)
		return err
	default:
		// acl.ActionBlock, or any action not handled above: drop silently.
		return nil
	}
}

// ListenAndServe binds a TPROXY UDP socket to r.ListenAddr and relays
// datagrams until a read from that socket fails; that read error is returned.
//
// Datagrams are grouped into sessions by source address. The first datagram
// from an unknown source opens a session via startSession. Later ones are
// forwarded with sendPacket. Forwarding is best effort: a datagram that cannot
// be sent is dropped.
func (r *UDPTProxy) ListenAndServe() error {
	conn, err := tproxy.ListenUDP("udp", r.ListenAddr)
	if err != nil {
		return err
	}
	defer conn.Close()
	// Source address (String form) -> session.
	connMap := make(map[string]*connEntry)
	var connMapMutex sync.RWMutex
	// buf is reused for every datagram. NOTE(review): this assumes sendPacket
	// (and HyConn.WriteTo) does not retain the slice after returning; confirm.
	buf := make([]byte, udpBufferSize)
	for {
		n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf)
		if n > 0 {
			key := srcAddr.String()
			connMapMutex.RLock()
			entry := connMap[key]
			connMapMutex.RUnlock()
			if entry != nil {
				// Best effort: a datagram that can't be forwarded is dropped.
				_ = r.sendPacket(entry, dstAddr, buf[:n])
			} else {
				r.startSession(conn, srcAddr, dstAddr, buf[:n], key, connMap, &connMapMutex)
			}
		}
		// Checked on every iteration, including ones where session setup
		// failed, so a read error is never skipped.
		if err != nil {
			return err
		}
	}
}

// startSession opens a session for a datagram from a source that has none.
// It dials a HyClient UDP connection. When an ACL engine is configured, it
// also opens a local UDP socket. It then forwards the first datagram,
// registers the session in connMap under key, and starts the reply relays
// and the idle-expiry goroutine. Setup failures are reported via ErrorFunc,
// and anything already opened is closed.
func (r *UDPTProxy) startSession(conn *net.UDPConn, srcAddr, dstAddr *net.UDPAddr, first []byte,
	key string, connMap map[string]*connEntry, connMapMutex *sync.RWMutex) {
	r.ConnFunc(srcAddr)
	hyConn, err := r.HyClient.DialUDP()
	if err != nil {
		r.ErrorFunc(srcAddr, err)
		return
	}
	var localConn *net.UDPConn
	if r.ACLEngine != nil {
		localConn, err = r.Transport.LocalListenUDP(nil)
		if err != nil {
			// Close the already-dialed proxy connection so it isn't leaked.
			_ = hyConn.Close()
			r.ErrorFunc(srcAddr, err)
			return
		}
	}
	entry := &connEntry{HyConn: hyConn, LocalConn: localConn}
	// sendPacket also stores the first Deadline, which the expiry goroutine
	// below requires to be present before it starts.
	_ = r.sendPacket(entry, dstAddr, first)
	connMapMutex.Lock()
	connMap[key] = entry
	connMapMutex.Unlock()

	// touch marks the session as active, pushing its idle deadline forward.
	touch := func() { entry.Deadline.Store(time.Now().Add(r.Timeout)) }

	// NOTE(review): replies are written from the listening socket, so their
	// source is ListenAddr rather than the client's original destination;
	// clients using connected UDP sockets may reject them. Confirm whether
	// spoofing dstAddr as the reply source is wanted.

	// Proxied replies -> source, until hyConn is closed.
	go func() {
		for {
			bs, _, err := hyConn.ReadFrom()
			if err != nil {
				return
			}
			touch()
			_, _ = conn.WriteToUDP(bs, srcAddr)
		}
	}()
	// Direct/hijacked replies -> source, until localConn is closed.
	if localConn != nil {
		go func() {
			buf := make([]byte, udpBufferSize)
			for {
				n, _, err := localConn.ReadFrom(buf)
				if n > 0 {
					touch()
					_, _ = conn.WriteToUDP(buf[:n], srcAddr)
				}
				if err != nil {
					return
				}
			}
		}()
	}
	// Idle expiry: sleep until the deadline, re-checking afterwards because
	// traffic may have pushed it forward in the meantime.
	go func() {
		for {
			if ttl := time.Until(entry.Deadline.Load().(time.Time)); ttl > 0 {
				time.Sleep(ttl)
				continue
			}
			// Unregister first so new datagrams from this source open a fresh
			// session, then release the sockets outside the lock.
			connMapMutex.Lock()
			delete(connMap, key)
			connMapMutex.Unlock()
			_ = hyConn.Close()
			if localConn != nil {
				_ = localConn.Close()
			}
			r.ErrorFunc(srcAddr, ErrTimeout)
			return
		}
	}()
}