package acl

import (
	"errors"
	"net"
	"strings"
	"testing"

	lru "github.com/hashicorp/golang-lru/v2"
)

// TestEngine_ResolveAndMatch runs Engine.ResolveAndMatch against a fixed rule
// list and checks the returned action, action argument and error presence.
//
// The rule list is, in order:
//
//  1. google.com (exact domain, TCP port 443) -> proxy
//  2. evil.corp (domain suffix)               -> hijack to good.org
//  3. 10.0.0.0/8                              -> proxy
//  4. everything else                         -> block
//
// The resolver stub fails for any host containing "evil.corp" and defers to
// net.ResolveIPAddr otherwise, so the google.com cases need working DNS.
func TestEngine_ResolveAndMatch(t *testing.T) {
	cache, err := lru.NewARC[cacheKey, cacheValue](entryCacheSize)
	if err != nil {
		// Without a cache the engine under test is not usable; stop here
		// rather than failing later with a less obvious error.
		t.Fatalf("lru.NewARC() error = %v", err)
	}
	e := &Engine{
		DefaultAction: ActionDirect,
		Entries: []Entry{
			{
				Action:    ActionProxy,
				ActionArg: "",
				Matcher: &domainMatcher{
					matcherBase: matcherBase{
						Protocol: ProtocolTCP,
						Port:     443,
					},
					Domain: "google.com",
					Suffix: false,
				},
			},
			{
				Action:    ActionHijack,
				ActionArg: "good.org",
				Matcher: &domainMatcher{
					matcherBase: matcherBase{},
					Domain:      "evil.corp",
					Suffix:      true,
				},
			},
			{
				Action:    ActionProxy,
				ActionArg: "",
				Matcher: &netMatcher{
					matcherBase: matcherBase{},
					Net: &net.IPNet{
						IP:   net.ParseIP("10.0.0.0"),
						Mask: net.CIDRMask(8, 32),
					},
				},
			},
			{
				// Catch-all rule: anything not matched above is blocked.
				Action:    ActionBlock,
				ActionArg: "",
				Matcher:   &allMatcher{},
			},
		},
		Cache: cache,
		ResolveIPAddr: func(s string) (*net.IPAddr, error) {
			// Simulate a resolution failure for the evil.corp family so the
			// cases below can check what is returned alongside an error.
			if strings.Contains(s, "evil.corp") {
				return nil, errors.New("resolve error")
			}
			return net.ResolveIPAddr("ip", s)
		},
	}
	tests := []struct {
		name       string
		host       string
		port       uint16
		isUDP      bool
		wantAction Action
		wantArg    string
		wantErr    bool
	}{
		{
			name:       "domain proxy",
			host:       "google.com",
			port:       443,
			isUDP:      false,
			wantAction: ActionProxy,
			wantArg:    "",
		},
		{
			// Same domain but a different port: the first rule must not
			// apply, so the request ends up at the catch-all block rule.
			name:       "domain block",
			host:       "google.com",
			port:       80,
			isUDP:      false,
			wantAction: ActionBlock,
			wantArg:    "",
		},
		{
			name:       "domain suffix 1",
			host:       "evil.corp",
			port:       8899,
			isUDP:      true,
			wantAction: ActionHijack,
			wantArg:    "good.org",
			wantErr:    true,
		},
		{
			// "notevil.corp" contains "evil.corp", so the resolver stub
			// fails, but the suffix rule is expected not to match it.
			name:       "domain suffix 2",
			host:       "notevil.corp",
			port:       22,
			isUDP:      false,
			wantAction: ActionBlock,
			wantArg:    "",
			wantErr:    true,
		},
		{
			name:       "domain suffix 3",
			host:       "im.real.evil.corp",
			port:       443,
			isUDP:      true,
			wantAction: ActionHijack,
			wantArg:    "good.org",
			wantErr:    true,
		},
		{
			name:       "ip match",
			host:       "10.2.3.4",
			port:       80,
			isUDP:      false,
			wantAction: ActionProxy,
			wantArg:    "",
		},
		{
			name:       "ip mismatch",
			host:       "100.5.6.0",
			port:       1234,
			isUDP:      false,
			wantAction: ActionBlock,
			wantArg:    "",
		},
		{
			// Repeats the first case; presumably answered from Cache this
			// time (the cache internals are not visible from this test).
			name:       "domain proxy cache",
			host:       "google.com",
			port:       443,
			isUDP:      false,
			wantAction: ActionProxy,
			wantArg:    "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotAction, gotArg, _, _, err := e.ResolveAndMatch(tt.host, tt.port, tt.isUDP)
			if (err != nil) != tt.wantErr {
				t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// Action and argument are checked even when an error is expected.
			if gotAction != tt.wantAction {
				t.Errorf("ResolveAndMatch() gotAction = %v, wantAction %v", gotAction, tt.wantAction)
			}
			if gotArg != tt.wantArg {
				t.Errorf("ResolveAndMatch() gotArg = %v, wantArg %v", gotArg, tt.wantArg)
			}
		})
	}
}