diff --git a/go.mod b/go.mod index 938e527..c6631fd 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,8 @@ require github.com/golang/protobuf v1.3.1 require ( github.com/lucas-clemente/quic-go v0.15.2 - github.com/patrickmn/go-cache v2.1.0+incompatible - github.com/txthinking/runnergroup v0.0.0-20200327135940-540a793bb997 + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect + github.com/txthinking/runnergroup v0.0.0-20200327135940-540a793bb997 // indirect github.com/txthinking/socks5 v0.0.0-20200327133705-caf148ab5e9d github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 // indirect ) diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go new file mode 100644 index 0000000..d90c638 --- /dev/null +++ b/pkg/acl/engine.go @@ -0,0 +1,5 @@ +package acl + +type Engine struct { + Entries []Entry +} diff --git a/pkg/acl/entry.go b/pkg/acl/entry.go new file mode 100644 index 0000000..f489574 --- /dev/null +++ b/pkg/acl/entry.go @@ -0,0 +1,111 @@ +package acl + +import ( + "errors" + "fmt" + "net" + "strings" +) + +type Action byte + +const ( + ActionDirect = Action(iota) + ActionProxy + ActionBlock + ActionHijack +) + +type Entry struct { + Net *net.IPNet + Domain string + Suffix bool + All bool + Action Action + ActionArg string +} + +// Format: action cond_type cond arg +// Examples: +// proxy domain-suffix google.com +// block ip 8.8.8.8 +// hijack cidr 192.168.1.1/24 127.0.0.1 +func ParseEntry(s string) (Entry, error) { + fields := strings.Fields(s) + if len(fields) < 2 { + return Entry{}, fmt.Errorf("expecting at least 2 fields, got %d", len(fields)) + } + args := fields[1:] + if len(args) == 1 { + // Make sure there are at least 2 args + args = append(args, "") + } + ipNet, domain, suffix, all, err := parseCond(args[0], args[1]) + if err != nil { + return Entry{}, err + } + e := Entry{ + Net: ipNet, + Domain: domain, + Suffix: suffix, + All: all, + } + switch strings.ToLower(fields[0]) { + case "direct": + e.Action = ActionDirect + case "proxy": + e.Action = ActionProxy + case "block": + e.Action = ActionBlock + case "hijack": + if len(args) < 3 { + return Entry{}, fmt.Errorf("no hijack destination for %s %s", args[0], args[1]) + } + e.Action = ActionHijack + e.ActionArg = args[2] + default: + return Entry{}, fmt.Errorf("invalid action %s", fields[0]) + } + return e, nil +} + +func parseCond(typ, cond string) (*net.IPNet, string, bool, bool, error) { + switch strings.ToLower(typ) { + case "domain": + if len(cond) == 0 { + return nil, "", false, false, errors.New("empty domain") + } + return nil, cond, false, false, nil + case "domain-suffix": + if len(cond) == 0 { + return nil, "", false, false, errors.New("empty domain suffix") + } + return nil, cond, true, false, nil + case "cidr": + _, ipNet, err := net.ParseCIDR(cond) + if err != nil { + return nil, "", false, false, err + } + return ipNet, "", false, false, nil + case "ip": + ip := net.ParseIP(cond) + if ip == nil { + return nil, "", false, false, fmt.Errorf("invalid ip %s", cond) + } + if ip.To4() != nil { + return &net.IPNet{ + IP: ip, + Mask: net.CIDRMask(32, 32), + }, "", false, false, nil + } else { + return &net.IPNet{ + IP: ip, + Mask: net.CIDRMask(128, 128), + }, "", false, false, nil + } + case "all": + return nil, "", false, true, nil + default: + return nil, "", false, false, fmt.Errorf("invalid condition type %s", typ) + } +} diff --git a/pkg/acl/entry_test.go b/pkg/acl/entry_test.go new file mode 100644 index 0000000..326046b --- /dev/null +++ b/pkg/acl/entry_test.go @@ -0,0 +1,55 @@ +package acl + +import ( + "net" + "reflect" + "testing" +) + +func TestParseEntry(t *testing.T) { + _, ok4ipnet, _ := net.ParseCIDR("8.8.8.0/24") + + type args struct { + s string + } + tests := []struct { + name string + args args + want Entry + wantErr bool + }{ + {name: "empty", args: args{""}, want: Entry{}, wantErr: true}, + {name: "ok 1", args: args{"direct domain-suffix google.com"}, + want: Entry{nil, "google.com", true, false, ActionDirect, ""}, + wantErr: false}, + {name: "ok 2", args: args{"proxy ip 8.8.8.8"}, + want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)}, + "", false, false, ActionProxy, ""}, wantErr: false}, + {name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"}, + want: Entry{nil, "mad.bad", false, false, ActionHijack, "127.0.0.1"}, + wantErr: false}, + {name: "ok 4", args: args{"block cidr 8.8.8.0/24"}, + want: Entry{ok4ipnet, "", false, false, ActionBlock, ""}, + wantErr: false}, + {name: "ok 5", args: args{"block all"}, + want: Entry{nil, "", false, true, ActionBlock, ""}, + wantErr: false}, + {name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true}, + {name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true}, + {name: "invalid 3", args: args{"hijack ip 1.1.1.1"}, want: Entry{}, wantErr: true}, + {name: "invalid 4", args: args{"direct cidr"}, want: Entry{}, wantErr: true}, + {name: "invalid 5", args: args{"oxy ip 8.8.8.8"}, want: Entry{}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseEntry(tt.args.s) + if (err != nil) != tt.wantErr { + t.Errorf("ParseEntry() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseEntry() got = %v, want %v", got, tt.want) + } + }) + } +}