package socks5

import (
	"encoding/binary"
	"errors"
	"fmt"
	"strconv"

	"github.com/HyNetwork/hysteria/pkg/acl"
	"github.com/HyNetwork/hysteria/pkg/core"
	"github.com/HyNetwork/hysteria/pkg/transport"
	"github.com/HyNetwork/hysteria/pkg/utils"
)

import (
	"net"
	"time"

	"github.com/txthinking/socks5"
)

// udpBufferSize is the size, in bytes, of the buffers allocated for reading
// UDP datagrams in udpServer.
const udpBufferSize = 65535

var (
	// ErrUnsupportedCmd is returned by handle when the request command is
	// neither CONNECT nor UDP ASSOCIATE, or when it is UDP ASSOCIATE and UDP
	// has been disabled on the server.
	ErrUnsupportedCmd = errors.New("unsupported command")
	// ErrUserPassAuth is returned by negotiate when AuthFunc rejects the
	// username/password pair supplied by the client.
	ErrUserPassAuth   = errors.New("invalid username or password")
)

// Server is a SOCKS5 server that forwards client traffic either through a
// Hysteria client or directly through a local transport, depending on the
// optional ACL engine.
type Server struct {
	// HyClient is used to dial TCP connections and open UDP sessions for
	// traffic that is proxied.
	HyClient   *core.Client
	// Transport is used for direct and hijacked connections, for resolving
	// hijack targets, and for the local UDP relay socket.
	Transport  *transport.ClientTransport
	// AuthFunc validates username/password credentials. It is called only
	// when Method is the username/password method.
	AuthFunc   func(username, password string) bool
	// Method is the SOCKS5 authentication method the server requires.
	Method     byte
	// TCPAddr is the address the TCP listener binds to; its IP and zone are
	// also used for the per-association UDP sockets.
	TCPAddr    *net.TCPAddr
	// TCPTimeout, when non-zero, is applied as a deadline on accepted
	// connections and passed to the TCP pipe helper.
	TCPTimeout time.Duration
	// ACLEngine decides the action for each destination. When nil, all
	// traffic is proxied through HyClient.
	ACLEngine  *acl.Engine
	// DisableUDP makes the server reject UDP ASSOCIATE requests.
	DisableUDP bool

	// TCPRequestFunc is called for every CONNECT request after the ACL
	// decision has been made.
	TCPRequestFunc   func(addr net.Addr, reqAddr string, action acl.Action, arg string)
	// TCPErrorFunc is called when handling of a CONNECT request ends; err
	// carries the error that ended it, which may be nil.
	TCPErrorFunc     func(addr net.Addr, reqAddr string, err error)
	// UDPAssociateFunc is called when a UDP ASSOCIATE request starts.
	UDPAssociateFunc func(addr net.Addr)
	// UDPErrorFunc is called when a UDP association ends; err carries the
	// error that ended it.
	UDPErrorFunc     func(addr net.Addr, err error)

	// tcpListener is the listener created by ListenAndServe.
	tcpListener *net.TCPListener
}

// NewServer builds a Server that will listen on addr.
//
// addr is resolved as a TCP address and a resolution failure is returned as
// the error. When authFunc is non-nil the server requires username/password
// authentication; otherwise it requires the "no authentication" method. The
// remaining arguments are stored verbatim in the corresponding Server fields.
func NewServer(hyClient *core.Client, transport *transport.ClientTransport, addr string,
	authFunc func(username, password string) bool, tcpTimeout time.Duration,
	aclEngine *acl.Engine, disableUDP bool,
	tcpReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string),
	tcpErrorFunc func(addr net.Addr, reqAddr string, err error),
	udpAssocFunc func(addr net.Addr), udpErrorFunc func(addr net.Addr, err error),
) (*Server, error) {
	listenAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return nil, err
	}
	srv := &Server{
		HyClient:         hyClient,
		Transport:        transport,
		AuthFunc:         authFunc,
		Method:           socks5.MethodNone,
		TCPAddr:          listenAddr,
		TCPTimeout:       tcpTimeout,
		ACLEngine:        aclEngine,
		DisableUDP:       disableUDP,
		TCPRequestFunc:   tcpReqFunc,
		TCPErrorFunc:     tcpErrorFunc,
		UDPAssociateFunc: udpAssocFunc,
		UDPErrorFunc:     udpErrorFunc,
	}
	// Credentials are only demanded when a validator was supplied.
	if authFunc != nil {
		srv.Method = socks5.MethodUsernamePassword
	}
	return srv, nil
}

// negotiate performs the SOCKS5 method negotiation on c and, when the server
// requires it, the username/password sub-negotiation.
//
// It returns nil once the client is allowed to send its request. It returns
// an error when reading or writing a negotiation message fails, when the
// client does not offer the method the server requires (after replying with
// "no acceptable methods"), or ErrUserPassAuth when AuthFunc rejects the
// supplied credentials (after replying with a failure status).
func (s *Server) negotiate(c *net.TCPConn) error {
	rq, err := socks5.NewNegotiationRequestFrom(c)
	if err != nil {
		return err
	}
	offered := false
	for _, m := range rq.Methods {
		if m == s.Method {
			offered = true
			break
		}
	}
	if !offered {
		rp := socks5.NewNegotiationReply(socks5.MethodUnsupportAll)
		if _, err := rp.WriteTo(c); err != nil {
			return err
		}
		// Negotiation has failed: stop here instead of also sending a
		// second reply that selects a method the client never offered.
		return errors.New("no acceptable authentication method")
	}
	rp := socks5.NewNegotiationReply(s.Method)
	if _, err := rp.WriteTo(c); err != nil {
		return err
	}

	if s.Method == socks5.MethodUsernamePassword {
		urq, err := socks5.NewUserPassNegotiationRequestFrom(c)
		if err != nil {
			return err
		}
		if !s.AuthFunc(string(urq.Uname), string(urq.Passwd)) {
			urp := socks5.NewUserPassNegotiationReply(socks5.UserPassStatusFailure)
			if _, err := urp.WriteTo(c); err != nil {
				return err
			}
			return ErrUserPassAuth
		}
		urp := socks5.NewUserPassNegotiationReply(socks5.UserPassStatusSuccess)
		if _, err := urp.WriteTo(c); err != nil {
			return err
		}
	}
	return nil
}

// ListenAndServe opens the TCP listener on s.TCPAddr and serves SOCKS5
// clients until accepting fails, returning the listen or accept error.
//
// Each accepted connection is handled in its own goroutine: the optional
// TCP deadline is applied, the negotiation is run, the request is read and
// then dispatched to handle. The connection is closed when handling ends.
func (s *Server) ListenAndServe() error {
	ln, err := net.ListenTCP("tcp", s.TCPAddr)
	s.tcpListener = ln
	if err != nil {
		return err
	}
	defer ln.Close()

	// serve runs the whole lifecycle of a single client connection.
	serve := func(conn *net.TCPConn) {
		defer conn.Close()
		if s.TCPTimeout != 0 && conn.SetDeadline(time.Now().Add(s.TCPTimeout)) != nil {
			return
		}
		if s.negotiate(conn) != nil {
			return
		}
		req, err := socks5.NewRequestFrom(conn)
		if err != nil {
			return
		}
		// Failures are reported through the server callbacks inside handle.
		_ = s.handle(conn, req)
	}

	for {
		conn, err := s.tcpListener.AcceptTCP()
		if err != nil {
			return err
		}
		go serve(conn)
	}
}

// handle dispatches a parsed SOCKS5 request by command: CONNECT goes to
// handleTCP and UDP ASSOCIATE goes to handleUDP unless UDP is disabled.
// Anything else is answered with "command not supported" and
// ErrUnsupportedCmd is returned.
func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error {
	switch {
	case r.Cmd == socks5.CmdConnect:
		// TCP
		return s.handleTCP(c, r)
	case r.Cmd == socks5.CmdUDP && !s.DisableUDP:
		// UDP
		return s.handleUDP(c, r)
	default:
		// Unknown command, or UDP ASSOCIATE while UDP is disabled.
		_ = sendReply(c, socks5.RepCommandNotSupported)
		return ErrUnsupportedCmd
	}
}

// handleTCP serves a CONNECT request on c.
//
// The destination is run through the ACL engine when one is configured
// (otherwise the action is "proxy"), TCPRequestFunc is notified, and the
// connection is then handled according to the action: dialed directly,
// dialed through HyClient, blocked, or dialed to the hijack target. On
// success the two connections are piped until one side ends.
//
// TCPErrorFunc is always called on return with the error that ended the
// request. The returned error is non-nil only when resolving or dialing the
// destination failed.
func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
	host, port, addr := parseRequestAddress(r)
	var (
		action = acl.ActionProxy
		arg    = ""
		ipAddr *net.IPAddr
		resErr error
	)
	if s.ACLEngine != nil {
		// A failed resolution is not always fatal: the request may still be
		// sent through HyClient.
		action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, false)
	}
	s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg)

	var closeErr error
	defer func() {
		s.TCPErrorFunc(c.RemoteAddr(), addr, closeErr)
	}()
	// fail answers "host unreachable", records err for the callback and
	// hands it back so callers can return it directly.
	fail := func(err error) error {
		_ = sendReply(c, socks5.RepHostUnreachable)
		closeErr = err
		return err
	}

	// Handle according to the action
	switch action {
	case acl.ActionDirect:
		if resErr != nil {
			return fail(resErr)
		}
		rc, err := s.Transport.DialTCP(&net.TCPAddr{
			IP:   ipAddr.IP,
			Port: int(port),
			Zone: ipAddr.Zone,
		})
		if err != nil {
			return fail(err)
		}
		defer rc.Close()
		_ = sendReply(c, socks5.RepSuccess)
		closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout)
	case acl.ActionProxy:
		rc, err := s.HyClient.DialTCP(addr)
		if err != nil {
			return fail(err)
		}
		defer rc.Close()
		_ = sendReply(c, socks5.RepSuccess)
		closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout)
	case acl.ActionBlock:
		_ = sendReply(c, socks5.RepHostUnreachable)
		closeErr = errors.New("blocked in ACL")
	case acl.ActionHijack:
		target, err := s.Transport.ResolveIPAddr(arg)
		if err != nil {
			return fail(err)
		}
		rc, err := s.Transport.DialTCP(&net.TCPAddr{
			IP:   target.IP,
			Port: int(port),
			Zone: target.Zone,
		})
		if err != nil {
			return fail(err)
		}
		defer rc.Close()
		_ = sendReply(c, socks5.RepSuccess)
		closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout)
	default:
		_ = sendReply(c, socks5.RepServerFailure)
		closeErr = fmt.Errorf("unknown action %d", action)
	}
	return nil
}

// handleUDP serves a UDP ASSOCIATE request on c.
//
// It opens a UDP socket for the client on the server's IP, a local relay
// socket when an ACL engine is configured, and a HyClient UDP session. It
// then tells the client where to send datagrams (the local IP of the TCP
// connection plus the new UDP port), starts udpServer, and blocks reading
// from c until that read fails. All sockets and the session are closed on
// return.
//
// UDPAssociateFunc is called at the start and UDPErrorFunc on return. The
// returned error is non-nil only when setting up the association failed.
func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error {
	s.UDPAssociateFunc(c.RemoteAddr())
	var closeErr error
	defer func() {
		s.UDPErrorFunc(c.RemoteAddr(), closeErr)
	}()
	// abort answers "server failure", records err for the callback and
	// hands it back so callers can return it directly.
	abort := func(err error) error {
		_ = sendReply(c, socks5.RepServerFailure)
		closeErr = err
		return err
	}

	// Start local UDP server
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: s.TCPAddr.IP, Zone: s.TCPAddr.Zone})
	if err != nil {
		return abort(err)
	}
	defer udpConn.Close()

	// Local UDP relay conn for ACL Direct
	var localRelayConn *net.UDPConn
	if s.ACLEngine != nil {
		if localRelayConn, err = s.Transport.ListenUDP(); err != nil {
			return abort(err)
		}
		defer localRelayConn.Close()
	}

	// HyClient UDP session
	hyUDP, err := s.HyClient.DialUDP()
	if err != nil {
		return abort(err)
	}
	defer hyUDP.Close()

	// Send UDP server addr to the client
	// Same IP as TCP but a different port
	localIP := c.LocalAddr().(*net.TCPAddr).IP
	var atyp byte
	var addr, port []byte
	switch v4, v6 := localIP.To4(), localIP.To16(); {
	case v4 != nil:
		atyp, addr = socks5.ATYPIPv4, v4
	case v6 != nil:
		atyp, addr = socks5.ATYPIPv6, v6
	default:
		return abort(errors.New("invalid local addr"))
	}
	port = make([]byte, 2)
	binary.BigEndian.PutUint16(port, uint16(udpConn.LocalAddr().(*net.UDPAddr).Port))
	_, _ = socks5.NewReply(socks5.RepSuccess, atyp, addr, port).WriteTo(c)

	// Let UDP server do its job, we hold the TCP connection here
	go s.udpServer(udpConn, localRelayConn, hyUDP)
	if s.TCPTimeout != 0 {
		// Disable TCP timeout for UDP holder
		_ = c.SetDeadline(time.Time{})
	}
	// Keep reading (and discarding) until the TCP side fails; that error is
	// what gets reported through UDPErrorFunc.
	scratch := make([]byte, 1024)
	for closeErr == nil {
		_, closeErr = c.Read(scratch)
	}
	// As the TCP connection closes, so does the UDP server & HyClient session
	return nil
}

// udpServer relays SOCKS5 UDP datagrams between the client and the remote
// side until reading from clientConn fails.
//
// clientConn is the socket the SOCKS5 client sends datagrams to.
// localRelayConn, when non-nil, carries traffic that the ACL routes directly
// or to a hijack target. hyUDP carries traffic that is proxied.
//
// Datagrams that fail to parse or are fragmented are dropped. The sender of
// the first valid datagram becomes the client for the rest of the
// association; datagrams from any other address are dropped. Reply
// goroutines that copy data back to the client are started when that first
// datagram arrives and stop when their source read fails.
func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn, hyUDP core.UDPConn) {
	var clientAddr *net.UDPAddr
	buf := make([]byte, udpBufferSize)
	// Local to remote
	for {
		n, cAddr, err := clientConn.ReadFromUDP(buf)
		if err != nil {
			break
		}
		d, err := socks5.NewDatagramFromBytes(buf[:n])
		if err != nil || d.Frag != 0 {
			// Ignore bad packets
			continue
		}
		if clientAddr == nil {
			// Whoever sends the first valid packet is our client
			clientAddr = cAddr
			// Start remote to local
			go func() {
				for {
					bs, from, err := hyUDP.ReadFrom()
					if err != nil {
						break
					}
					atyp, addr, port, err := socks5.ParseAddress(from)
					if err != nil {
						continue
					}
					reply := socks5.NewDatagram(atyp, addr, port, bs)
					_, _ = clientConn.WriteToUDP(reply.Bytes(), clientAddr)
				}
			}()
			if localRelayConn != nil {
				go func() {
					relayBuf := make([]byte, udpBufferSize)
					for {
						n, from, err := localRelayConn.ReadFrom(relayBuf)
						if n > 0 {
							atyp, addr, port, err := socks5.ParseAddress(from.String())
							if err != nil {
								continue
							}
							reply := socks5.NewDatagram(atyp, addr, port, relayBuf[:n])
							_, _ = clientConn.WriteToUDP(reply.Bytes(), clientAddr)
						}
						if err != nil {
							break
						}
					}
				}()
			}
		} else if cAddr.String() != clientAddr.String() {
			// Not our client, bye
			continue
		}
		host, port, addr := parseDatagramRequestAddress(d)
		action, arg := acl.ActionProxy, ""
		var ipAddr *net.IPAddr
		var resErr error
		if s.ACLEngine != nil && localRelayConn != nil {
			action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, true)
			// Doesn't always matter if the resolution fails, as we may send it through HyClient
		}
		// Handle according to the action
		switch action {
		case acl.ActionDirect:
			if resErr != nil {
				// Drop only this datagram. Returning here would end the
				// relay for every later datagram of the association because
				// a single destination failed to resolve.
				continue
			}
			_, _ = localRelayConn.WriteToUDP(d.Data, &net.UDPAddr{
				IP:   ipAddr.IP,
				Port: int(port),
				Zone: ipAddr.Zone,
			})
		case acl.ActionProxy:
			_ = hyUDP.WriteTo(d.Data, addr)
		case acl.ActionBlock:
			// Do nothing
		case acl.ActionHijack:
			hijackIPAddr, err := s.Transport.ResolveIPAddr(arg)
			if err == nil {
				_, _ = localRelayConn.WriteToUDP(d.Data, &net.UDPAddr{
					IP:   hijackIPAddr.IP,
					Port: int(port),
					Zone: hijackIPAddr.Zone,
				})
			}
		default:
			// Do nothing
		}
	}
}

// sendReply writes a SOCKS5 reply carrying the status code rep to conn. The
// bound address in the reply is always the all-zero IPv4 address with port
// 0. It returns the write error, if any.
func sendReply(conn *net.TCPConn, rep byte) error {
	zeroIP := []byte{0x00, 0x00, 0x00, 0x00}
	zeroPort := []byte{0x00, 0x00}
	_, err := socks5.NewReply(rep, socks5.ATYPIPv4, zeroIP, zeroPort).WriteTo(conn)
	return err
}

// parseRequestAddress extracts the destination of a SOCKS5 request.
//
// It returns the host (a domain name or the textual form of the IP), the
// port decoded from the big-endian DstPort bytes, and the two joined as
// "host:port".
func parseRequestAddress(r *socks5.Request) (host string, port uint16, addr string) {
	port = binary.BigEndian.Uint16(r.DstPort)
	if r.Atyp == socks5.ATYPDomain {
		// The first byte of DstAddr is skipped; presumably it is the length
		// prefix kept by the socks5 library (confirm against that package).
		host = string(r.DstAddr[1:])
	} else {
		host = net.IP(r.DstAddr).String()
	}
	addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
	return host, port, addr
}

// parseDatagramRequestAddress extracts the destination of a SOCKS5 UDP
// datagram.
//
// It returns the host (a domain name or the textual form of the IP), the
// port decoded from the big-endian DstPort bytes, and the two joined as
// "host:port".
func parseDatagramRequestAddress(r *socks5.Datagram) (host string, port uint16, addr string) {
	port = binary.BigEndian.Uint16(r.DstPort)
	if r.Atyp == socks5.ATYPDomain {
		// The first byte of DstAddr is skipped; presumably it is the length
		// prefix kept by the socks5 library (confirm against that package).
		host = string(r.DstAddr[1:])
	} else {
		host = net.IP(r.DstAddr).String()
	}
	addr = net.JoinHostPort(host, strconv.Itoa(int(port)))
	return host, port, addr
}