From ee8558f2fb2512a1a554b5ff125bf30f18ec1fe3 Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 25 Apr 2020 22:56:49 -0700 Subject: [PATCH] ACL engine & tests --- pkg/acl/engine.go | 48 +++++++++++++++++- pkg/acl/engine_test.go | 107 +++++++++++++++++++++++++++++++++++++++++ pkg/acl/entry.go | 22 ++++++++- 3 files changed, 174 insertions(+), 3 deletions(-) create mode 100644 pkg/acl/engine_test.go diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go index d90c638..d37da65 100644 --- a/pkg/acl/engine.go +++ b/pkg/acl/engine.go @@ -1,5 +1,51 @@ package acl +import ( + "bufio" + "net" + "os" + "strings" +) + type Engine struct { - Entries []Entry + DefaultAction Action + Entries []Entry +} + +func LoadFromFile(filename string) (*Engine, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + scanner := bufio.NewScanner(f) + entries := make([]Entry, 0, 1024) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if len(line) == 0 || strings.HasPrefix(line, "#") { + // Ignore empty lines & comments + continue + } + entry, err := ParseEntry(line) + if err != nil { + return nil, err + } + entries = append(entries, entry) + } + return &Engine{ + DefaultAction: ActionProxy, + Entries: entries, + }, nil +} + +func (e *Engine) Lookup(domain string, ip net.IP) (Action, string) { + if len(domain) == 0 && ip == nil { + return e.DefaultAction, "" + } + for _, entry := range e.Entries { + if entry.Match(domain, ip) { + return entry.Action, entry.ActionArg + } + } + return e.DefaultAction, "" } diff --git a/pkg/acl/engine_test.go b/pkg/acl/engine_test.go new file mode 100644 index 0000000..a43fa09 --- /dev/null +++ b/pkg/acl/engine_test.go @@ -0,0 +1,107 @@ +package acl + +import ( + "net" + "testing" +) + +func TestEngine_Lookup(t *testing.T) { + e := &Engine{ + DefaultAction: ActionDirect, + Entries: []Entry{ + { + Net: nil, + Domain: "google.com", + Suffix: false, + All: false, + Action: ActionProxy, + ActionArg: "", + }, + { + Net: nil, + Domain: "evil.corp", + Suffix: true, + All: false, + Action: ActionHijack, + ActionArg: "good.org", + }, + { + Net: &net.IPNet{ + IP: net.ParseIP("10.0.0.0"), + Mask: net.CIDRMask(8, 32), + }, + Domain: "", + Suffix: false, + All: false, + Action: ActionProxy, + ActionArg: "", + }, + { + Net: nil, + Domain: "", + Suffix: false, + All: true, + Action: ActionBlock, + ActionArg: "", + }, + }, + } + type args struct { + domain string + ip net.IP + } + tests := []struct { + name string + args args + want Action + want1 string + }{ + { + name: "domain direct", + args: args{"google.com", nil}, + want: ActionProxy, + want1: "", + }, + { + name: "domain suffix 1", + args: args{"evil.corp", nil}, + want: ActionHijack, + want1: "good.org", + }, + { + name: "domain suffix 2", + args: args{"notevil.corp", nil}, + want: ActionBlock, + want1: "", + }, + { + name: "domain suffix 3", + args: args{"im.real.evil.corp", nil}, + want: ActionHijack, + want1: "good.org", + }, + { + name: "ip match", + args: args{"", net.ParseIP("10.2.3.4")}, + want: ActionProxy, + want1: "", + }, + { + name: "ip mismatch", + args: args{"", net.ParseIP("100.5.6.0")}, + want: ActionBlock, + want1: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := e.Lookup(tt.args.domain, tt.args.ip) + if got != tt.want { + t.Errorf("Lookup() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Lookup() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/pkg/acl/entry.go b/pkg/acl/entry.go index f489574..171cd78 100644 --- a/pkg/acl/entry.go +++ b/pkg/acl/entry.go @@ -25,6 +25,24 @@ type Entry struct { ActionArg string } +func (e Entry) Match(domain string, ip net.IP) bool { + if e.All { + return true + } + if e.Net != nil && ip != nil { + return e.Net.Contains(ip) + } + if len(e.Domain) > 0 && len(domain) > 0 { + ld := strings.ToLower(domain) + if e.Suffix { + return e.Domain == ld || strings.HasSuffix(ld, "."+e.Domain) + } else { + return e.Domain == ld + } + } + return false +} + // Format: action cond_type cond arg // Examples: // proxy domain-suffix google.com @@ -75,12 +93,12 @@ func parseCond(typ, cond string) (*net.IPNet, string, bool, bool, error) { if len(cond) == 0 { return nil, "", false, false, errors.New("empty domain") } - return nil, cond, false, false, nil + return nil, strings.ToLower(cond), false, false, nil case "domain-suffix": if len(cond) == 0 { return nil, "", false, false, errors.New("empty domain suffix") } - return nil, cond, true, false, nil + return nil, strings.ToLower(cond), true, false, nil case "cidr": _, ipNet, err := net.ParseCIDR(cond) if err != nil {