diff --git a/cmd/proxy_client.go b/cmd/client.go similarity index 65% rename from cmd/proxy_client.go rename to cmd/client.go index ca8d37d..ff6afa7 100644 --- a/cmd/proxy_client.go +++ b/cmd/client.go @@ -3,11 +3,6 @@ package main import ( "crypto/tls" "crypto/x509" - "io/ioutil" - "net" - "net/http" - "time" - "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/congestion" "github.com/sirupsen/logrus" @@ -17,42 +12,38 @@ import ( hyHTTP "github.com/tobyxdd/hysteria/pkg/http" "github.com/tobyxdd/hysteria/pkg/obfs" "github.com/tobyxdd/hysteria/pkg/socks5" + "io/ioutil" + "net" + "net/http" + "time" ) -func proxyClient(args []string) { - var config proxyClientConfig - err := loadConfig(&config, args) - if err != nil { - logrus.WithField("error", err).Fatal("Unable to load configuration") - } - if err := config.Check(); err != nil { - logrus.WithField("error", err).Fatal("Configuration error") - } - logrus.WithField("config", config.String()).Info("Configuration loaded") - +func client(config *clientConfig) { + logrus.WithField("config", config.String()).Info("Client configuration loaded") + // TLS tlsConfig := &tls.Config{ InsecureSkipVerify: config.Insecure, - NextProtos: []string{proxyTLSProtocol}, + NextProtos: []string{tlsProtocolName}, MinVersion: tls.VersionTLS13, } // Load CA - if len(config.CustomCAFile) > 0 { - bs, err := ioutil.ReadFile(config.CustomCAFile) + if len(config.CustomCA) > 0 { + bs, err := ioutil.ReadFile(config.CustomCA) if err != nil { logrus.WithFields(logrus.Fields{ "error": err, - "file": config.CustomCAFile, - }).Fatal("Unable to load CA file") + "file": config.CustomCA, + }).Fatal("Failed to load CA") } cp := x509.NewCertPool() if !cp.AppendCertsFromPEM(bs) { logrus.WithFields(logrus.Fields{ - "file": config.CustomCAFile, - }).Fatal("Unable to parse CA file") + "file": config.CustomCA, + }).Fatal("Failed to parse CA") } tlsConfig.RootCAs = cp } - + // QUIC config quicConfig := &quic.Config{ MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow, @@ -64,46 +55,54 @@ func proxyClient(args []string) { if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow } - + // Auth + var auth []byte + if len(config.Auth) > 0 { + auth = config.Auth + } else { + auth = []byte(config.AuthString) + } + // Obfuscator var obfuscator core.Obfuscator if len(config.Obfs) > 0 { obfuscator = obfs.XORObfuscator(config.Obfs) } - + // ACL var aclEngine *acl.Engine - if len(config.ACLFile) > 0 { - aclEngine, err = acl.LoadFromFile(config.ACLFile) + if len(config.ACL) > 0 { + var err error + aclEngine, err = acl.LoadFromFile(config.ACL) if err != nil { logrus.WithFields(logrus.Fields{ "error": err, - "file": config.ACLFile, - }).Fatal("Unable to parse ACL") + "file": config.ACL, + }).Fatal("Failed to parse ACL") } } - - client, err := core.NewClient(config.ServerAddr, config.Username, config.Password, tlsConfig, quicConfig, + // Client + client, err := core.NewClient(config.Server, auth, tlsConfig, quicConfig, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) }, obfuscator) if err != nil { - logrus.WithField("error", err).Fatal("Client initialization failed") + logrus.WithField("error", err).Fatal("Failed to initialize client") } defer client.Close() - logrus.WithField("addr", config.ServerAddr).Info("Connected") + logrus.WithField("addr", config.Server).Info("Connected") + // Local errChan := make(chan error) - - if len(config.SOCKS5Addr) > 0 { + if len(config.SOCKS5.Listen) > 0 { go func() { var authFunc func(user, password string) bool - if config.SOCKS5User != "" && config.SOCKS5Password != "" { + if config.SOCKS5.User != "" && config.SOCKS5.Password != "" { authFunc = func(user, password string) bool { - return config.SOCKS5User == user && config.SOCKS5Password == password + return config.SOCKS5.User == user && config.SOCKS5.Password == password } } - socks5server, err := socks5.NewServer(client, config.SOCKS5Addr, authFunc, config.SOCKS5Timeout, aclEngine, - config.SOCKS5DisableUDP, + socks5server, err := socks5.NewServer(client, config.SOCKS5.Listen, authFunc, config.SOCKS5.Timeout, aclEngine, + config.SOCKS5.DisableUDP, func(addr net.Addr, reqAddr string, action acl.Action, arg string) { logrus.WithFields(logrus.Fields{ "action": actionToString(action, arg), @@ -144,22 +143,22 @@ func proxyClient(args []string) { }).Debug("SOCKS5 UDP tunnel closed") }) if err != nil { - logrus.WithField("error", err).Fatal("SOCKS5 server initialization failed") + logrus.WithField("error", err).Fatal("Failed to initialize SOCKS5 server") } - logrus.WithField("addr", config.SOCKS5Addr).Info("SOCKS5 server up and running") + logrus.WithField("addr", config.SOCKS5.Listen).Info("SOCKS5 server up and running") errChan <- socks5server.ListenAndServe() }() } - if len(config.HTTPAddr) > 0 { + if len(config.HTTP.Listen) > 0 { go func() { var authFunc func(user, password string) bool - if config.HTTPUser != "" && config.HTTPPassword != "" { + if config.HTTP.User != "" && config.HTTP.Password != "" { authFunc = func(user, password string) bool { - return config.HTTPUser == user && config.HTTPPassword == password + return config.HTTP.User == user && config.HTTP.Password == password } } - proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTPTimeout)*time.Second, aclEngine, + proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTP.Timeout)*time.Second, aclEngine, func(reqAddr string, action acl.Action, arg string) { logrus.WithFields(logrus.Fields{ "action": actionToString(action, arg), @@ -168,14 +167,14 @@ func proxyClient(args []string) { }, authFunc) if err != nil { - logrus.WithField("error", err).Fatal("HTTP server initialization failed") + logrus.WithField("error", err).Fatal("Failed to initialize HTTP server") } - if config.HTTPSCert != "" && config.HTTPSKey != "" { - logrus.WithField("addr", config.HTTPAddr).Info("HTTPS server up and running") - errChan <- http.ListenAndServeTLS(config.HTTPAddr, config.HTTPSCert, config.HTTPSKey, proxy) + if config.HTTP.Cert != "" && config.HTTP.Key != "" { + logrus.WithField("addr", config.HTTP.Listen).Info("HTTPS server up and running") + errChan <- http.ListenAndServeTLS(config.HTTP.Listen, config.HTTP.Cert, config.HTTP.Key, proxy) } else { - logrus.WithField("addr", config.HTTPAddr).Info("HTTP server up and running") - errChan <- http.ListenAndServe(config.HTTPAddr, proxy) + logrus.WithField("addr", config.HTTP.Listen).Info("HTTP server up and running") + errChan <- http.ListenAndServe(config.HTTP.Listen, proxy) } }() } diff --git a/cmd/config.go b/cmd/config.go index 365bdca..226b888 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -1,86 +1,127 @@ package main import ( - "encoding/json" - "flag" - "io/ioutil" - "os" - "reflect" - "strings" - "time" + "errors" + "fmt" ) const ( - mbpsToBps = 125000 - dialTimeout = 10 * time.Second + mbpsToBps = 125000 DefaultMaxReceiveStreamFlowControlWindow = 33554432 DefaultMaxReceiveConnectionFlowControlWindow = 67108864 - DefaultMaxIncomingStreams = 200 + DefaultMaxIncomingStreams = 1024 + + tlsProtocolName = "hysteria" ) -func loadConfig(cfg interface{}, args []string) error { - cfgVal := reflect.ValueOf(cfg).Elem() - fs := flag.NewFlagSet("", flag.ContinueOnError) - fsValMap := make(map[reflect.Value]interface{}, cfgVal.NumField()) - for i := 0; i < cfgVal.NumField(); i++ { - structField := cfgVal.Type().Field(i) - tag := structField.Tag - switch structField.Type.Kind() { - case reflect.String: - fsValMap[cfgVal.Field(i)] = - fs.String(jsonTagToFlagName(tag.Get("json")), "", tag.Get("desc")) - case reflect.Int: - fsValMap[cfgVal.Field(i)] = - fs.Int(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc")) - case reflect.Uint64: - fsValMap[cfgVal.Field(i)] = - fs.Uint64(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc")) - case reflect.Bool: - var bf optionalBoolFlag - fs.Var(&bf, jsonTagToFlagName(tag.Get("json")), tag.Get("desc")) - fsValMap[cfgVal.Field(i)] = &bf - } +type serverConfig struct { + Listen string `json:"listen"` + CertFile string `json:"cert"` + KeyFile string `json:"key"` + // Optional below + UpMbps int `json:"up_mbps"` + DownMbps int `json:"down_mbps"` + DisableUDP bool `json:"disable_udp"` + ACL string `json:"acl"` + Obfs string `json:"obfs"` + Auth struct { + Mode string `json:"mode"` + Config interface{} `json:"config"` + } `json:"auth"` + ReceiveWindowConn uint64 `json:"recv_window_conn"` + ReceiveWindowClient uint64 `json:"recv_window_client"` + MaxConnClient int `json:"max_conn_client"` +} + +func (c *serverConfig) Check() error { + if len(c.Listen) == 0 { + return errors.New("no listen address") } - configFile := fs.String("config", "", "Configuration file") - // Parse - if err := fs.Parse(args); err != nil { - os.Exit(1) + if len(c.CertFile) == 0 || len(c.KeyFile) == 0 { + return errors.New("TLS cert or key not provided") } - // Put together the config - if len(*configFile) > 0 { - cb, err := ioutil.ReadFile(*configFile) - if err != nil { - return err - } - if err := json.Unmarshal(cb, cfg); err != nil { - return err - } + if c.UpMbps < 0 || c.DownMbps < 0 { + return errors.New("invalid speed") } - // Flags override config from file - for field, val := range fsValMap { - switch v := val.(type) { - case *string: - if len(*v) > 0 { - field.SetString(*v) - } - case *int: - if *v != 0 { - field.SetInt(int64(*v)) - } - case *uint64: - if *v != 0 { - field.SetUint(*v) - } - case *optionalBoolFlag: - if v.Exists { - field.SetBool(v.Value) - } - } + if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || + (c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) { + return errors.New("invalid receive window size") + } + if c.MaxConnClient < 0 { + return errors.New("invalid max connections per client") } return nil } -func jsonTagToFlagName(tag string) string { - return strings.ReplaceAll(tag, "_", "-") +func (c *serverConfig) String() string { + return fmt.Sprintf("%+v", *c) +} + +type clientConfig struct { + Server string `json:"server"` + UpMbps int `json:"up_mbps"` + DownMbps int `json:"down_mbps"` + // Optional below + SOCKS5 struct { + Listen string `json:"listen"` + Timeout int `json:"timeout"` + DisableUDP bool `json:"disable_udp"` + User string `json:"user"` + Password string `json:"password"` + } `json:"socks5"` + HTTP struct { + Listen string `json:"listen"` + Timeout int `json:"timeout"` + User string `json:"user"` + Password string `json:"password"` + Cert string `json:"cert"` + Key string `json:"key"` + } `json:"http"` + Relay struct { + Listen string `json:"listen"` + Remote string `json:"remote"` + Timeout int `json:"timeout"` + } `json:"relay"` + ACL string `json:"acl"` + Obfs string `json:"obfs"` + Auth []byte `json:"auth"` + AuthString string `json:"auth_str"` + Insecure bool `json:"insecure"` + CustomCA string `json:"ca"` + ReceiveWindowConn uint64 `json:"recv_window_conn"` + ReceiveWindow uint64 `json:"recv_window"` +} + +func (c *clientConfig) Check() error { + if len(c.SOCKS5.Listen) == 0 && len(c.HTTP.Listen) == 0 && len(c.Relay.Listen) == 0 { + return errors.New("no SOCKS5, HTTP or relay listen address") + } + if len(c.Relay.Listen) > 0 && len(c.Relay.Remote) == 0 { + return errors.New("no relay remote address") + } + if c.SOCKS5.Timeout != 0 && c.SOCKS5.Timeout <= 4 { + return errors.New("invalid SOCKS5 timeout") + } + if c.HTTP.Timeout != 0 && c.HTTP.Timeout <= 4 { + return errors.New("invalid HTTP timeout") + } + if c.Relay.Timeout != 0 && c.Relay.Timeout <= 4 { + return errors.New("invalid relay timeout") + } + if len(c.Server) == 0 { + return errors.New("no server address") + } + if c.UpMbps <= 0 || c.DownMbps <= 0 { + return errors.New("invalid speed") + } + if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || + (c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) { + return errors.New("invalid receive window size") + } + return nil +} + +func (c *clientConfig) String() string { + return fmt.Sprintf("%+v", *c) } diff --git a/cmd/main.go b/cmd/main.go index 0b4f2c0..1ccc2c4 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,11 +1,13 @@ package main import ( + "flag" "fmt" + "github.com/sirupsen/logrus" + "github.com/yosuke-furukawa/json5/encoding/json5" + "io/ioutil" "os" "strings" - - "github.com/sirupsen/logrus" ) // Injected when compiling @@ -15,12 +17,10 @@ var ( appDate = "Unknown" ) -var modeMap = map[string]func(args []string){ - "relay server": relayServer, - "relay client": relayClient, - "proxy server": proxyServer, - "proxy client": proxyClient, -} +var ( + configPath = flag.String("config", "config.json", "Config file") + showVersion = flag.Bool("version", false, "Show version") +) func init() { logrus.SetOutput(os.Stdout) @@ -50,37 +50,68 @@ func init() { TimestampFormat: tsFormat, }) } + + flag.Parse() } func main() { - if len(os.Args) == 2 && strings.ToLower(strings.TrimSpace(os.Args[1])) == "version" { + if *showVersion { // Print version and quit fmt.Printf("%-10s%s\n", "Version:", appVersion) fmt.Printf("%-10s%s\n", "Commit:", appCommit) fmt.Printf("%-10s%s\n", "Date:", appDate) return } - if len(os.Args) < 3 { - fmt.Println() - fmt.Printf("Usage: %s MODE SUBMODE [OPTIONS]\n\n"+ - "Available mode/submode combinations: "+getModes()+"\n"+ - "Use -h to see the available options for a mode.\n\n", os.Args[0]) - return + cb, err := ioutil.ReadFile(*configPath) + if err != nil { + logrus.WithFields(logrus.Fields{ + "file": *configPath, + "error": err, + }).Fatal("Failed to read configuration") } - modeStr := fmt.Sprintf("%s %s", strings.ToLower(strings.TrimSpace(os.Args[1])), - strings.ToLower(strings.TrimSpace(os.Args[2]))) - f := modeMap[modeStr] - if f != nil { - f(os.Args[3:]) + mode := flag.Arg(0) + if strings.EqualFold(mode, "server") { + // server mode + c, err := parseServerConfig(cb) + if err != nil { + logrus.WithFields(logrus.Fields{ + "file": *configPath, + "error": err, + }).Fatal("Failed to parse server configuration") + } + server(c) + } else if len(mode) == 0 || strings.EqualFold(mode, "client") { + // client mode + c, err := parseClientConfig(cb) + if err != nil { + logrus.WithFields(logrus.Fields{ + "file": *configPath, + "error": err, + }).Fatal("Failed to parse client configuration") + } + client(c) } else { - fmt.Println("Invalid mode:", modeStr) + // invalid + fmt.Println() + fmt.Printf("Usage: %s MODE [OPTIONS]\n\n"+ + "Available modes: server, client\n\n", os.Args[0]) } } -func getModes() string { - modes := make([]string, 0, len(modeMap)) - for mode := range modeMap { - modes = append(modes, mode) +func parseServerConfig(cb []byte) (*serverConfig, error) { + var c serverConfig + err := json5.Unmarshal(cb, &c) + if err != nil { + return nil, err } - return strings.Join(modes, ", ") + return &c, c.Check() +} + +func parseClientConfig(cb []byte) (*clientConfig, error) { + var c clientConfig + err := json5.Unmarshal(cb, &c) + if err != nil { + return nil, err + } + return &c, c.Check() } diff --git a/cmd/proxy_config.go b/cmd/proxy_config.go deleted file mode 100644 index a192698..0000000 --- a/cmd/proxy_config.go +++ /dev/null @@ -1,99 +0,0 @@ -package main - -import ( - "errors" - "fmt" -) - -const proxyTLSProtocol = "hysteria-proxy" - -type proxyClientConfig struct { - SOCKS5Addr string `json:"socks5_addr" desc:"SOCKS5 listen address"` - SOCKS5Timeout int `json:"socks5_timeout" desc:"SOCKS5 connection timeout in seconds"` - SOCKS5DisableUDP bool `json:"socks5_disable_udp" desc:"Disable SOCKS5 UDP support"` - SOCKS5User string `json:"socks5_user" desc:"SOCKS5 auth username"` - SOCKS5Password string `json:"socks5_password" desc:"SOCKS5 auth password"` - HTTPAddr string `json:"http_addr" desc:"HTTP listen address"` - HTTPTimeout int `json:"http_timeout" desc:"HTTP connection timeout in seconds"` - HTTPUser string `json:"http_user" desc:"HTTP basic auth username"` - HTTPPassword string `json:"http_password" desc:"HTTP basic auth password"` - HTTPSCert string `json:"https_cert" desc:"HTTPS certificate file"` - HTTPSKey string `json:"https_key" desc:"HTTPS key file"` - ACLFile string `json:"acl" desc:"Access control list"` - ServerAddr string `json:"server" desc:"Server address"` - Username string `json:"username" desc:"Authentication username"` - Password string `json:"password" desc:"Authentication password"` - Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"` - CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"` - UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"` - DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"` - ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"` - ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"` - Obfs string `json:"obfs" desc:"Obfuscation key"` -} - -func (c *proxyClientConfig) Check() error { - if len(c.SOCKS5Addr) == 0 && len(c.HTTPAddr) == 0 { - return errors.New("no SOCKS5 or HTTP listen address") - } - if c.SOCKS5Timeout != 0 && c.SOCKS5Timeout <= 4 { - return errors.New("invalid SOCKS5 timeout") - } - if c.HTTPTimeout != 0 && c.HTTPTimeout <= 4 { - return errors.New("invalid HTTP timeout") - } - if len(c.ServerAddr) == 0 { - return errors.New("no server address") - } - if c.UpMbps <= 0 || c.DownMbps <= 0 { - return errors.New("invalid speed") - } - if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || - (c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) { - return errors.New("invalid receive window size") - } - return nil -} - -func (c *proxyClientConfig) String() string { - return fmt.Sprintf("%+v", *c) -} - -type proxyServerConfig struct { - ListenAddr string `json:"listen" desc:"Server listen address"` - DisableUDP bool `json:"disable_udp" desc:"Disable UDP support"` - ACLFile string `json:"acl" desc:"Access control list"` - CertFile string `json:"cert" desc:"TLS certificate file"` - KeyFile string `json:"key" desc:"TLS key file"` - AuthFile string `json:"auth" desc:"Authentication file"` - UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"` - DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"` - ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"` - ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"` - MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"` - Obfs string `json:"obfs" desc:"Obfuscation key"` -} - -func (c *proxyServerConfig) Check() error { - if len(c.ListenAddr) == 0 { - return errors.New("no listen address") - } - if len(c.CertFile) == 0 || len(c.KeyFile) == 0 { - return errors.New("TLS cert or key not provided") - } - if c.UpMbps < 0 || c.DownMbps < 0 { - return errors.New("invalid speed") - } - if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || - (c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) { - return errors.New("invalid receive window size") - } - if c.MaxConnClient < 0 { - return errors.New("invalid max connections per client") - } - return nil -} - -func (c *proxyServerConfig) String() string { - return fmt.Sprintf("%+v", *c) -} diff --git a/cmd/proxy_server.go b/cmd/proxy_server.go deleted file mode 100644 index dc46f69..0000000 --- a/cmd/proxy_server.go +++ /dev/null @@ -1,298 +0,0 @@ -package main - -import ( - "bufio" - "crypto/tls" - "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/congestion" - "github.com/sirupsen/logrus" - "github.com/tobyxdd/hysteria/pkg/acl" - hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" - "github.com/tobyxdd/hysteria/pkg/core" - "github.com/tobyxdd/hysteria/pkg/obfs" - "io" - "net" - "os" - "strings" -) - -func proxyServer(args []string) { - var config proxyServerConfig - err := loadConfig(&config, args) - if err != nil { - logrus.WithField("error", err).Fatal("Unable to load configuration") - } - if err := config.Check(); err != nil { - logrus.WithField("error", err).Fatal("Configuration error") - } - logrus.WithField("config", config.String()).Info("Configuration loaded") - // Load cert - cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "cert": config.CertFile, - "key": config.KeyFile, - }).Fatal("Unable to load the certificate") - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{proxyTLSProtocol}, - MinVersion: tls.VersionTLS13, - } - - quicConfig := &quic.Config{ - MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, - MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient, - MaxIncomingStreams: int64(config.MaxConnClient), - KeepAlive: true, - } - if quicConfig.MaxReceiveStreamFlowControlWindow == 0 { - quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow - } - if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { - quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow - } - if quicConfig.MaxIncomingStreams == 0 { - quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams - } - - if len(config.AuthFile) == 0 { - logrus.Warn("No authentication configured, this server can be used by anyone") - } - - var obfuscator core.Obfuscator - if len(config.Obfs) > 0 { - obfuscator = obfs.XORObfuscator(config.Obfs) - } - - var aclEngine *acl.Engine - if len(config.ACLFile) > 0 { - aclEngine, err = acl.LoadFromFile(config.ACLFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "file": config.ACLFile, - }).Fatal("Unable to parse ACL") - } - aclEngine.DefaultAction = acl.ActionDirect - } - - server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig, - uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, - func(refBPS uint64) congestion.CongestionControl { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, - obfuscator, - func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) { - if len(config.AuthFile) == 0 { - logrus.WithFields(logrus.Fields{ - "addr": addr.String(), - "username": username, - "up": sSend / mbpsToBps, - "down": sRecv / mbpsToBps, - }).Info("Client connected") - return core.AuthResultSuccess, "" - } else { - // Need auth - ok, err := checkAuth(config.AuthFile, username, password) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err.Error(), - "addr": addr.String(), - "username": username, - }).Error("Client authentication error") - return core.AuthResultInternalError, "Server auth error" - } - if ok { - logrus.WithFields(logrus.Fields{ - "addr": addr.String(), - "username": username, - "up": sSend / mbpsToBps, - "down": sRecv / mbpsToBps, - }).Info("Client authenticated") - return core.AuthResultSuccess, "" - } else { - logrus.WithFields(logrus.Fields{ - "addr": addr.String(), - "username": username, - "up": sSend / mbpsToBps, - "down": sRecv / mbpsToBps, - }).Info("Client rejected due to invalid credential") - return core.AuthResultInvalidCred, "Invalid credential" - } - } - }, - func(addr net.Addr, username string, err error) { - logrus.WithFields(logrus.Fields{ - "error": err.Error(), - "addr": addr.String(), - "username": username, - }).Info("Client disconnected") - }, - func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) { - packet := reqType == core.ConnectionTypePacket - if packet && config.DisableUDP { - return core.ConnectResultBlocked, "UDP disabled", nil - } - host, port, err := net.SplitHostPort(reqAddr) - if err != nil { - return core.ConnectResultFailed, err.Error(), nil - } - ip := net.ParseIP(host) - if ip != nil { - // IP request, clear host for ACL engine - host = "" - } - action, arg := acl.ActionDirect, "" - if aclEngine != nil { - action, arg = aclEngine.Lookup(host, ip) - } - switch action { - case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - if !packet { - // TCP - logrus.WithFields(logrus.Fields{ - "action": "direct", - "username": username, - "src": addr.String(), - "dst": reqAddr, - }).Debug("New TCP request") - conn, err := net.DialTimeout("tcp", reqAddr, dialTimeout) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "dst": reqAddr, - }).Error("TCP error") - return core.ConnectResultFailed, err.Error(), nil - } - return core.ConnectResultSuccess, "", conn - } else { - // UDP - logrus.WithFields(logrus.Fields{ - "action": "direct", - "username": username, - "src": addr.String(), - "dst": reqAddr, - }).Debug("New UDP request") - conn, err := net.Dial("udp", reqAddr) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "dst": reqAddr, - }).Error("UDP error") - return core.ConnectResultFailed, err.Error(), nil - } - return core.ConnectResultSuccess, "", conn - } - case acl.ActionBlock: - if !packet { - // TCP - logrus.WithFields(logrus.Fields{ - "action": "block", - "username": username, - "src": addr.String(), - "dst": reqAddr, - }).Debug("New TCP request") - return core.ConnectResultBlocked, "blocked by ACL", nil - } else { - // UDP - logrus.WithFields(logrus.Fields{ - "action": "block", - "username": username, - "src": addr.String(), - "dst": reqAddr, - }).Debug("New UDP request") - return core.ConnectResultBlocked, "blocked by ACL", nil - } - case acl.ActionHijack: - hijackAddr := net.JoinHostPort(arg, port) - if !packet { - // TCP - logrus.WithFields(logrus.Fields{ - "action": "hijack", - "username": username, - "src": addr.String(), - "dst": reqAddr, - "rdst": arg, - }).Debug("New TCP request") - conn, err := net.DialTimeout("tcp", hijackAddr, dialTimeout) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "dst": hijackAddr, - }).Error("TCP error") - return core.ConnectResultFailed, err.Error(), nil - } - return core.ConnectResultSuccess, "", conn - } else { - // UDP - logrus.WithFields(logrus.Fields{ - "action": "hijack", - "username": username, - "src": addr.String(), - "dst": reqAddr, - "rdst": arg, - }).Debug("New UDP request") - conn, err := net.Dial("udp", hijackAddr) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "dst": hijackAddr, - }).Error("UDP error") - return core.ConnectResultFailed, err.Error(), nil - } - return core.ConnectResultSuccess, "", conn - } - default: - return core.ConnectResultFailed, "server ACL error", nil - } - }, - func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) { - packet := reqType == core.ConnectionTypePacket - if !packet { - logrus.WithFields(logrus.Fields{ - "error": err, - "username": username, - "src": addr.String(), - "dst": reqAddr, - }).Debug("TCP request closed") - } else { - logrus.WithFields(logrus.Fields{ - "error": err, - "username": username, - "src": addr.String(), - "dst": reqAddr, - }).Debug("UDP request closed") - } - }, - ) - if err != nil { - logrus.WithField("error", err).Fatal("Server initialization failed") - } - defer server.Close() - logrus.WithField("addr", config.ListenAddr).Info("Server up and running") - - err = server.Serve() - logrus.WithField("error", err).Fatal("Server shutdown") -} - -func checkAuth(authFile, username, password string) (bool, error) { - f, err := os.Open(authFile) - if err != nil { - return false, err - } - defer f.Close() - scanner := bufio.NewScanner(f) - for scanner.Scan() { - pair := strings.Fields(scanner.Text()) - if len(pair) != 2 { - // Invalid format - continue - } - if username == pair[0] && password == pair[1] { - return true, nil - } - } - return false, nil -} diff --git a/cmd/relay_client.go b/cmd/relay_client.go deleted file mode 100644 index aa36f7e..0000000 --- a/cmd/relay_client.go +++ /dev/null @@ -1,119 +0,0 @@ -package main - -import ( - "crypto/tls" - "crypto/x509" - "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/congestion" - "github.com/sirupsen/logrus" - hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" - "github.com/tobyxdd/hysteria/pkg/core" - "github.com/tobyxdd/hysteria/pkg/obfs" - "github.com/tobyxdd/hysteria/pkg/utils" - "io/ioutil" - "net" - "os/user" -) - -func relayClient(args []string) { - var config relayClientConfig - err := loadConfig(&config, args) - if err != nil { - logrus.WithField("error", err).Fatal("Unable to load configuration") - } - if err := config.Check(); err != nil { - logrus.WithField("error", err).Fatal("Configuration error") - } - if len(config.Name) == 0 { - usr, err := user.Current() - if err == nil { - config.Name = usr.Name - } - } - logrus.WithField("config", config.String()).Info("Configuration loaded") - - tlsConfig := &tls.Config{ - InsecureSkipVerify: config.Insecure, - NextProtos: []string{relayTLSProtocol}, - MinVersion: tls.VersionTLS13, - } - // Load CA - if len(config.CustomCAFile) > 0 { - bs, err := ioutil.ReadFile(config.CustomCAFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "file": config.CustomCAFile, - }).Fatal("Unable to load CA file") - } - cp := x509.NewCertPool() - if !cp.AppendCertsFromPEM(bs) { - logrus.WithFields(logrus.Fields{ - "file": config.CustomCAFile, - }).Fatal("Unable to parse CA file") - } - tlsConfig.RootCAs = cp - } - - quicConfig := &quic.Config{ - MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, - MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow, - KeepAlive: true, - } - if quicConfig.MaxReceiveStreamFlowControlWindow == 0 { - quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow - } - if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { - quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow - } - - var obfuscator core.Obfuscator - if len(config.Obfs) > 0 { - obfuscator = obfs.XORObfuscator(config.Obfs) - } - - client, err := core.NewClient(config.ServerAddr, config.Name, "", tlsConfig, quicConfig, - uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, - func(refBPS uint64) congestion.CongestionControl { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, obfuscator) - if err != nil { - logrus.WithField("error", err).Fatal("Client initialization failed") - } - defer client.Close() - logrus.WithField("addr", config.ServerAddr).Info("Connected") - - listener, err := net.Listen("tcp", config.ListenAddr) - if err != nil { - logrus.WithField("error", err).Fatal("TCP listen failed") - } - defer listener.Close() - logrus.WithField("addr", listener.Addr().String()).Info("TCP server listening") - - for { - conn, err := listener.Accept() - if err != nil { - logrus.WithField("error", err).Fatal("TCP accept failed") - } - go relayClientHandleConn(conn, client) - } -} - -func relayClientHandleConn(conn net.Conn, client *core.Client) { - logrus.WithField("src", conn.RemoteAddr().String()).Debug("New connection") - var closeErr error - defer func() { - _ = conn.Close() - logrus.WithFields(logrus.Fields{ - "error": closeErr, - "src": conn.RemoteAddr().String(), - }).Debug("Connection closed") - }() - rwc, err := client.Dial(false, "") - if err != nil { - closeErr = err - return - } - defer rwc.Close() - closeErr = utils.PipePair(conn, rwc, nil, nil) -} diff --git a/cmd/relay_config.go b/cmd/relay_config.go deleted file mode 100644 index 6dd1060..0000000 --- a/cmd/relay_config.go +++ /dev/null @@ -1,82 +0,0 @@ -package main - -import ( - "errors" - "fmt" -) - -const relayTLSProtocol = "hysteria-relay" - -type relayClientConfig struct { - ListenAddr string `json:"listen" desc:"TCP listen address"` - ServerAddr string `json:"server" desc:"Server address"` - Name string `json:"name" desc:"Client name presented to the server"` - Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"` - CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"` - UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"` - DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"` - ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"` - ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"` - Obfs string `json:"obfs" desc:"Obfuscation key"` -} - -func (c *relayClientConfig) Check() error { - if len(c.ListenAddr) == 0 { - return errors.New("no listen address") - } - if len(c.ServerAddr) == 0 { - return errors.New("no server address") - } - if c.UpMbps <= 0 || c.DownMbps <= 0 { - return errors.New("invalid speed") - } - if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || - (c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) { - return errors.New("invalid receive window size") - } - return nil -} - -func (c *relayClientConfig) String() string { - return fmt.Sprintf("%+v", *c) -} - -type relayServerConfig struct { - ListenAddr string `json:"listen" desc:"Server listen address"` - RemoteAddr string `json:"remote" desc:"Remote relay address"` - CertFile string `json:"cert" desc:"TLS certificate file"` - KeyFile string `json:"key" desc:"TLS key file"` - UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"` - DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"` - ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"` - ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"` - MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"` - Obfs string `json:"obfs" desc:"Obfuscation key"` -} - -func (c *relayServerConfig) Check() error { - if len(c.ListenAddr) == 0 { - return errors.New("no listen address") - } - if len(c.RemoteAddr) == 0 { - return errors.New("no remote address") - } - if len(c.CertFile) == 0 || len(c.KeyFile) == 0 { - return errors.New("TLS cert or key not provided") - } - if c.UpMbps < 0 || c.DownMbps < 0 { - return errors.New("invalid speed") - } - if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || - (c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) { - return errors.New("invalid receive window size") - } - if c.MaxConnClient < 0 { - return errors.New("invalid max connections per client") - } - return nil -} - -func (c *relayServerConfig) String() string { - return fmt.Sprintf("%+v", *c) -} diff --git a/cmd/relay_server.go b/cmd/relay_server.go deleted file mode 100644 index d82a321..0000000 --- a/cmd/relay_server.go +++ /dev/null @@ -1,121 +0,0 @@ -package main - -import ( - "crypto/tls" - "github.com/lucas-clemente/quic-go" - "github.com/lucas-clemente/quic-go/congestion" - "github.com/sirupsen/logrus" - hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" - "github.com/tobyxdd/hysteria/pkg/core" - "github.com/tobyxdd/hysteria/pkg/obfs" - "io" - "net" -) - -func relayServer(args []string) { - var config relayServerConfig - err := loadConfig(&config, args) - if err != nil { - logrus.WithField("error", err).Fatal("Unable to load configuration") - } - if err := config.Check(); err != nil { - logrus.WithField("error", err).Fatal("Configuration error") - } - logrus.WithField("config", config.String()).Info("Configuration loaded") - // Load cert - cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "cert": config.CertFile, - "key": config.KeyFile, - }).Fatal("Unable to load the certificate") - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{relayTLSProtocol}, - MinVersion: tls.VersionTLS13, - } - - quicConfig := &quic.Config{ - MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, - MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient, - MaxIncomingStreams: int64(config.MaxConnClient), - KeepAlive: true, - } - if quicConfig.MaxReceiveStreamFlowControlWindow == 0 { - quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow - } - if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { - quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow - } - if quicConfig.MaxIncomingStreams == 0 { - quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams - } - - var obfuscator core.Obfuscator - if len(config.Obfs) > 0 { - obfuscator = obfs.XORObfuscator(config.Obfs) - } - - server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig, - uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, - func(refBPS uint64) congestion.CongestionControl { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, - obfuscator, - func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) { - // No authentication logic in relay, just log username and speed - logrus.WithFields(logrus.Fields{ - "addr": addr.String(), - "username": username, - "up": sSend / mbpsToBps, - "down": sRecv / mbpsToBps, - }).Info("Client connected") - return core.AuthResultSuccess, "" - }, - func(addr net.Addr, username string, err error) { - logrus.WithFields(logrus.Fields{ - "error": err.Error(), - "addr": addr.String(), - "username": username, - }).Info("Client disconnected") - }, - func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) { - packet := reqType == core.ConnectionTypePacket - logrus.WithFields(logrus.Fields{ - "username": username, - "src": addr.String(), - "id": id, - }).Debug("New stream") - if packet { - return core.ConnectResultBlocked, "unsupported", nil - } - conn, err := net.DialTimeout("tcp", config.RemoteAddr, dialTimeout) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - "dst": config.RemoteAddr, - }).Error("TCP error") - return core.ConnectResultFailed, err.Error(), nil - } - return core.ConnectResultSuccess, "", conn - }, - func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) { - logrus.WithFields(logrus.Fields{ - "error": err, - "username": username, - "src": addr.String(), - "id": id, - }).Debug("Stream closed") - }, - ) - if err != nil { - logrus.WithField("error", err).Fatal("Server initialization failed") - } - defer server.Close() - logrus.WithField("addr", config.ListenAddr).Info("Server up and running") - - err = server.Serve() - logrus.WithField("error", err).Fatal("Server shutdown") -} diff --git a/cmd/server.go b/cmd/server.go new file mode 100644 index 0000000..dab89a7 --- /dev/null +++ b/cmd/server.go @@ -0,0 +1,99 @@ +package main + +import ( + "crypto/tls" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/congestion" + "github.com/sirupsen/logrus" + "github.com/tobyxdd/hysteria/pkg/acl" + hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" + "github.com/tobyxdd/hysteria/pkg/core" + "github.com/tobyxdd/hysteria/pkg/obfs" + "net" + "strings" +) + +func server(config *serverConfig) { + logrus.WithField("config", config.String()).Info("Server configuration loaded") + // Load cert + cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) + if err != nil { + logrus.WithFields(logrus.Fields{ + "error": err, + "cert": config.CertFile, + "key": config.KeyFile, + }).Fatal("Failed to load the certificate") + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{tlsProtocolName}, + MinVersion: tls.VersionTLS13, + } + // QUIC config + quicConfig := &quic.Config{ + MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, + MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient, + MaxIncomingStreams: int64(config.MaxConnClient), + KeepAlive: true, + } + if quicConfig.MaxReceiveStreamFlowControlWindow == 0 { + quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow + } + if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { + quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow + } + if quicConfig.MaxIncomingStreams == 0 { + quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams + } + // Auth + var authFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) + if len(config.Auth.Mode) == 0 || strings.EqualFold(config.Auth.Mode, "none") { + logrus.Warn("No authentication configured") + authFunc = func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { + return true, "Welcome" + } + } else { + // TODO + logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode") + } + // Obfuscator + var obfuscator core.Obfuscator + if len(config.Obfs) > 0 { + obfuscator = obfs.XORObfuscator(config.Obfs) + } + // ACL + var aclEngine *acl.Engine + if len(config.ACL) > 0 { + aclEngine, err = acl.LoadFromFile(config.ACL) + if err != nil { + logrus.WithFields(logrus.Fields{ + "error": err, + "file": config.ACL, + }).Fatal("Failed to parse ACL") + } + aclEngine.DefaultAction = acl.ActionDirect + } + // Server + server, err := core.NewServer(config.Listen, tlsConfig, quicConfig, + uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, + func(refBPS uint64) congestion.CongestionControl { + return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) + }, aclEngine, obfuscator, authFunc, func(addr net.Addr, auth []byte, udp bool, reqAddr string) { + if !udp { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + "dst": reqAddr, + }).Debug("New TCP request") + } else { + // TODO + } + }) + if err != nil { + logrus.WithField("error", err).Fatal("Failed to initialize server") + } + defer server.Close() + logrus.WithField("addr", config.Listen).Info("Server up and running") + + err = server.Serve() + logrus.WithField("error", err).Fatal("Server shutdown") +} diff --git a/cmd/utils.go b/cmd/utils.go deleted file mode 100644 index 17fca0f..0000000 --- a/cmd/utils.go +++ /dev/null @@ -1,28 +0,0 @@ -package main - -import ( - "strconv" -) - -type optionalBoolFlag struct { - Exists bool - Value bool -} - -func (flag *optionalBoolFlag) String() string { - return strconv.FormatBool(flag.Value) -} - -func (flag *optionalBoolFlag) Set(s string) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - flag.Exists = true - flag.Value = v - return nil -} - -func (flag *optionalBoolFlag) IsBoolFlag() bool { - return true -} diff --git a/go.mod b/go.mod index 006c896..d739627 100644 --- a/go.mod +++ b/go.mod @@ -8,11 +8,13 @@ require ( github.com/golang/protobuf v1.4.2 github.com/hashicorp/golang-lru v0.5.4 github.com/lucas-clemente/quic-go v0.19.3 + github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/sirupsen/logrus v1.6.0 github.com/txthinking/runnergroup v0.0.0-20200327135940-540a793bb997 // indirect github.com/txthinking/socks5 v0.0.0-20200327133705-caf148ab5e9d github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 // indirect + github.com/yosuke-furukawa/json5 v0.1.1 ) replace github.com/lucas-clemente/quic-go => github.com/tobyxdd/quic-go v0.19.4-0.20210127052624-0ecb862c82b5 diff --git a/go.sum b/go.sum index c74ef3d..3b78eef 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc= +github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg= github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc= @@ -166,6 +168,8 @@ github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 h1:ngJOce33YJJT1PFTfC github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9/go.mod h1:WgqbSEmUYSjEV3B1qmee/PpP2NYEz4bL9/+mF1ma+s4= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/yosuke-furukawa/json5 v0.1.1 h1:0F9mNwTvOuDNH243hoPqvf+dxa5QsKnZzU20uNsh3ZI= +github.com/yosuke-furukawa/json5 v0.1.1/go.mod h1:sw49aWDqNdRJ6DYUtIQiaA3xyj2IL9tjeNYmX2ixwcU= go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE= diff --git a/pkg/core/client.go b/pkg/core/client.go index b590adc..adee661 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -6,44 +6,45 @@ import ( "errors" "fmt" "github.com/lucas-clemente/quic-go" - "github.com/tobyxdd/hysteria/pkg/core/pb" - "github.com/tobyxdd/hysteria/pkg/utils" + "github.com/lucas-clemente/quic-go/congestion" + "github.com/lunixbochs/struc" "net" "sync" - "sync/atomic" + "time" ) var ( ErrClosed = errors.New("client closed") ) -type Client struct { - inboundBytes, outboundBytes uint64 // atomic +type CongestionFactory func(refBPS uint64) congestion.CongestionControl - reconnectMutex sync.Mutex - closed bool - quicSession quic.Session - serverAddr string - username, password string - tlsConfig *tls.Config - quicConfig *quic.Config - sendBPS, recvBPS uint64 - congestionFactory CongestionFactory - obfuscator Obfuscator +type Client struct { + serverAddr string + sendBPS, recvBPS uint64 + auth []byte + congestionFactory CongestionFactory + obfuscator Obfuscator + + tlsConfig *tls.Config + quicConfig *quic.Config + + quicSession quic.Session + reconnectMutex sync.Mutex + closed bool } -func NewClient(serverAddr string, username string, password string, tlsConfig *tls.Config, quicConfig *quic.Config, +func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) { c := &Client{ serverAddr: serverAddr, - username: username, - password: password, - tlsConfig: tlsConfig, - quicConfig: quicConfig, sendBPS: sendBPS, recvBPS: recvBPS, + auth: auth, congestionFactory: congestionFactory, obfuscator: obfuscator, + tlsConfig: tlsConfig, + quicConfig: quicConfig, } if err := c.connectToServer(); err != nil { return nil, err @@ -51,58 +52,6 @@ func NewClient(serverAddr string, username string, password string, tlsConfig *t return c, nil } -func (c *Client) Dial(packet bool, addr string) (net.Conn, error) { - stream, localAddr, remoteAddr, err := c.openStreamWithReconnect() - if err != nil { - return nil, err - } - // Send request - req := &pb.ClientConnectRequest{Address: addr} - if packet { - req.Type = pb.ConnectionType_Packet - } else { - req.Type = pb.ConnectionType_Stream - } - err = writeClientConnectRequest(stream, req) - if err != nil { - _ = stream.Close() - return nil, err - } - // Read response - resp, err := readServerConnectResponse(stream) - if err != nil { - _ = stream.Close() - return nil, err - } - if resp.Result != pb.ConnectResult_CONN_SUCCESS { - _ = stream.Close() - return nil, fmt.Errorf("server rejected the connection %s (msg: %s)", - resp.Result.String(), resp.Message) - } - connWrap := &utils.QUICStreamWrapperConn{ - Orig: stream, - PseudoLocalAddr: localAddr, - PseudoRemoteAddr: remoteAddr, - } - if packet { - return &utils.PacketWrapperConn{Orig: connWrap}, nil - } else { - return connWrap, nil - } -} - -func (c *Client) Stats() (uint64, uint64) { - return atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes) -} - -func (c *Client) Close() error { - c.reconnectMutex.Lock() - defer c.reconnectMutex.Unlock() - err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic") - c.closed = true - return err -} - func (c *Client) connectToServer() error { serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) if err != nil { @@ -124,51 +73,50 @@ func (c *Client) connectToServer() error { return err } // Control stream - ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) - ctlStream, err := qs.OpenStreamSync(ctx) + ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout) + stream, err := qs.OpenStreamSync(ctx) ctxCancel() if err != nil { - _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") + _ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error") return err } - result, msg, err := c.handleControlStream(qs, ctlStream) + ok, msg, err := c.handleControlStream(qs, stream) if err != nil { - _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") + _ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error") return err } - if result != pb.AuthResult_AUTH_SUCCESS { - _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure") - return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg) + if !ok { + _ = qs.CloseWithError(closeErrorCodeAuth, "auth error") + return fmt.Errorf("auth error: %s", msg) } // All good c.quicSession = qs return nil } -func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (pb.AuthResult, string, error) { - err := writeClientAuthRequest(stream, &pb.ClientAuthRequest{ - Credential: &pb.Credential{ - Username: c.username, - Password: c.password, - }, - Speed: &pb.Speed{ - SendBps: c.sendBPS, - ReceiveBps: c.recvBPS, +func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, string, error) { + // Send client hello + err := struc.Pack(stream, &clientHello{ + Rate: transmissionRate{ + SendBPS: c.sendBPS, + RecvBPS: c.recvBPS, }, + Auth: c.auth, }) if err != nil { - return 0, "", err + return false, "", err } - // Response - resp, err := readServerAuthResponse(stream) + // Receive server hello + var sh serverHello + err = struc.Unpack(stream, &sh) if err != nil { - return 0, "", err + return false, "", err } // Set the congestion accordingly - if resp.Result == pb.AuthResult_AUTH_SUCCESS && c.congestionFactory != nil { - qs.SetCongestionControl(c.congestionFactory(resp.Speed.ReceiveBps)) + if sh.OK && c.congestionFactory != nil { + qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS)) } - return resp.Result, resp.Message, nil + return true, sh.Message, nil } func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) { @@ -196,3 +144,81 @@ func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, err stream, err = c.quicSession.OpenStream() return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err } + +func (c *Client) DialTCP(addr string) (net.Conn, error) { + stream, localAddr, remoteAddr, err := c.openStreamWithReconnect() + if err != nil { + return nil, err + } + // Send request + err = struc.Pack(stream, &clientRequest{ + UDP: false, + Address: addr, + }) + if err != nil { + _ = stream.Close() + return nil, err + } + // Read response + var sr serverResponse + err = struc.Unpack(stream, &sr) + if err != nil { + _ = stream.Close() + return nil, err + } + if !sr.OK { + _ = stream.Close() + return nil, fmt.Errorf("connection rejected: %s", sr.Message) + } + return &quicConn{ + Orig: stream, + PseudoLocalAddr: localAddr, + PseudoRemoteAddr: remoteAddr, + }, nil +} + +func (c *Client) Close() error { + c.reconnectMutex.Lock() + defer c.reconnectMutex.Unlock() + err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "") + c.closed = true + return err +} + +type quicConn struct { + Orig quic.Stream + PseudoLocalAddr net.Addr + PseudoRemoteAddr net.Addr +} + +func (w *quicConn) Read(b []byte) (n int, err error) { + return w.Orig.Read(b) +} + +func (w *quicConn) Write(b []byte) (n int, err error) { + return w.Orig.Write(b) +} + +func (w *quicConn) Close() error { + return w.Orig.Close() +} + +func (w *quicConn) LocalAddr() net.Addr { + return w.PseudoLocalAddr +} + +func (w *quicConn) RemoteAddr() net.Addr { + return w.PseudoRemoteAddr +} + +func (w *quicConn) SetDeadline(t time.Time) error { + return w.Orig.SetDeadline(t) +} + +func (w *quicConn) SetReadDeadline(t time.Time) error { + return w.Orig.SetReadDeadline(t) +} + +func (w *quicConn) SetWriteDeadline(t time.Time) error { + return w.Orig.SetWriteDeadline(t) +} diff --git a/pkg/core/control.go b/pkg/core/control.go deleted file mode 100644 index 81619a5..0000000 --- a/pkg/core/control.go +++ /dev/null @@ -1,104 +0,0 @@ -package core - -import ( - "encoding/binary" - "github.com/golang/protobuf/proto" - "github.com/tobyxdd/hysteria/pkg/core/pb" - "io" -) - -const ( - closeErrorCodeGeneric = 0 - closeErrorCodeProtocolFailure = 1 -) - -func readDataBlock(r io.Reader) ([]byte, error) { - var sz uint32 - if err := binary.Read(r, controlProtocolEndian, &sz); err != nil { - return nil, err - } - buf := make([]byte, sz) - _, err := io.ReadFull(r, buf) - return buf, err -} - -func writeDataBlock(w io.Writer, data []byte) error { - sz := uint32(len(data)) - if err := binary.Write(w, controlProtocolEndian, &sz); err != nil { - return err - } - _, err := w.Write(data) - return err -} - -func readClientAuthRequest(r io.Reader) (*pb.ClientAuthRequest, error) { - bs, err := readDataBlock(r) - if err != nil { - return nil, err - } - var req pb.ClientAuthRequest - err = proto.Unmarshal(bs, &req) - return &req, err -} - -func writeClientAuthRequest(w io.Writer, req *pb.ClientAuthRequest) error { - bs, err := proto.Marshal(req) - if err != nil { - return err - } - return writeDataBlock(w, bs) -} - -func readServerAuthResponse(r io.Reader) (*pb.ServerAuthResponse, error) { - bs, err := readDataBlock(r) - if err != nil { - return nil, err - } - var resp pb.ServerAuthResponse - err = proto.Unmarshal(bs, &resp) - return &resp, err -} - -func writeServerAuthResponse(w io.Writer, resp *pb.ServerAuthResponse) error { - bs, err := proto.Marshal(resp) - if err != nil { - return err - } - return writeDataBlock(w, bs) -} - -func readClientConnectRequest(r io.Reader) (*pb.ClientConnectRequest, error) { - bs, err := readDataBlock(r) - if err != nil { - return nil, err - } - var req pb.ClientConnectRequest - err = proto.Unmarshal(bs, &req) - return &req, err -} - -func writeClientConnectRequest(w io.Writer, req *pb.ClientConnectRequest) error { - bs, err := proto.Marshal(req) - if err != nil { - return err - } - return writeDataBlock(w, bs) -} - -func readServerConnectResponse(r io.Reader) (*pb.ServerConnectResponse, error) { - bs, err := readDataBlock(r) - if err != nil { - return nil, err - } - var resp pb.ServerConnectResponse - err = proto.Unmarshal(bs, &resp) - return &resp, err -} - -func writeServerConnectResponse(w io.Writer, resp *pb.ServerConnectResponse) error { - bs, err := proto.Marshal(resp) - if err != nil { - return err - } - return writeDataBlock(w, bs) -} diff --git a/pkg/core/pb/control.pb.go b/pkg/core/pb/control.pb.go deleted file mode 100644 index 47318f8..0000000 --- a/pkg/core/pb/control.pb.go +++ /dev/null @@ -1,440 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// source: control.proto - -package pb - -import ( - fmt "fmt" - proto "github.com/golang/protobuf/proto" - math "math" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package - -type AuthResult int32 - -const ( - AuthResult_AUTH_SUCCESS AuthResult = 0 - AuthResult_AUTH_INVALID_CRED AuthResult = 1 - AuthResult_AUTH_INTERNAL_ERROR AuthResult = 2 -) - -var AuthResult_name = map[int32]string{ - 0: "AUTH_SUCCESS", - 1: "AUTH_INVALID_CRED", - 2: "AUTH_INTERNAL_ERROR", -} - -var AuthResult_value = map[string]int32{ - "AUTH_SUCCESS": 0, - "AUTH_INVALID_CRED": 1, - "AUTH_INTERNAL_ERROR": 2, -} - -func (x AuthResult) String() string { - return proto.EnumName(AuthResult_name, int32(x)) -} - -func (AuthResult) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{0} -} - -type ConnectionType int32 - -const ( - ConnectionType_Stream ConnectionType = 0 - ConnectionType_Packet ConnectionType = 1 -) - -var ConnectionType_name = map[int32]string{ - 0: "Stream", - 1: "Packet", -} - -var ConnectionType_value = map[string]int32{ - "Stream": 0, - "Packet": 1, -} - -func (x ConnectionType) String() string { - return proto.EnumName(ConnectionType_name, int32(x)) -} - -func (ConnectionType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{1} -} - -type ConnectResult int32 - -const ( - ConnectResult_CONN_SUCCESS ConnectResult = 0 - ConnectResult_CONN_FAILED ConnectResult = 1 - ConnectResult_CONN_BLOCKED ConnectResult = 2 -) - -var ConnectResult_name = map[int32]string{ - 0: "CONN_SUCCESS", - 1: "CONN_FAILED", - 2: "CONN_BLOCKED", -} - -var ConnectResult_value = map[string]int32{ - "CONN_SUCCESS": 0, - "CONN_FAILED": 1, - "CONN_BLOCKED": 2, -} - -func (x ConnectResult) String() string { - return proto.EnumName(ConnectResult_name, int32(x)) -} - -func (ConnectResult) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{2} -} - -type Speed struct { - SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"` - ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Speed) Reset() { *m = Speed{} } -func (m *Speed) String() string { return proto.CompactTextString(m) } -func (*Speed) ProtoMessage() {} -func (*Speed) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{0} -} - -func (m *Speed) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Speed.Unmarshal(m, b) -} -func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Speed.Marshal(b, m, deterministic) -} -func (m *Speed) XXX_Merge(src proto.Message) { - xxx_messageInfo_Speed.Merge(m, src) -} -func (m *Speed) XXX_Size() int { - return xxx_messageInfo_Speed.Size(m) -} -func (m *Speed) XXX_DiscardUnknown() { - xxx_messageInfo_Speed.DiscardUnknown(m) -} - -var xxx_messageInfo_Speed proto.InternalMessageInfo - -func (m *Speed) GetSendBps() uint64 { - if m != nil { - return m.SendBps - } - return 0 -} - -func (m *Speed) GetReceiveBps() uint64 { - if m != nil { - return m.ReceiveBps - } - return 0 -} - -type Credential struct { - Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` - Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Credential) Reset() { *m = Credential{} } -func (m *Credential) String() string { return proto.CompactTextString(m) } -func (*Credential) ProtoMessage() {} -func (*Credential) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{1} -} - -func (m *Credential) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Credential.Unmarshal(m, b) -} -func (m *Credential) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Credential.Marshal(b, m, deterministic) -} -func (m *Credential) XXX_Merge(src proto.Message) { - xxx_messageInfo_Credential.Merge(m, src) -} -func (m *Credential) XXX_Size() int { - return xxx_messageInfo_Credential.Size(m) -} -func (m *Credential) XXX_DiscardUnknown() { - xxx_messageInfo_Credential.DiscardUnknown(m) -} - -var xxx_messageInfo_Credential proto.InternalMessageInfo - -func (m *Credential) GetUsername() string { - if m != nil { - return m.Username - } - return "" -} - -func (m *Credential) GetPassword() string { - if m != nil { - return m.Password - } - return "" -} - -type ClientAuthRequest struct { - Credential *Credential `protobuf:"bytes,1,opt,name=credential,proto3" json:"credential,omitempty"` - Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ClientAuthRequest) Reset() { *m = ClientAuthRequest{} } -func (m *ClientAuthRequest) String() string { return proto.CompactTextString(m) } -func (*ClientAuthRequest) ProtoMessage() {} -func (*ClientAuthRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{2} -} - -func (m *ClientAuthRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ClientAuthRequest.Unmarshal(m, b) -} -func (m *ClientAuthRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ClientAuthRequest.Marshal(b, m, deterministic) -} -func (m *ClientAuthRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_ClientAuthRequest.Merge(m, src) -} -func (m *ClientAuthRequest) XXX_Size() int { - return xxx_messageInfo_ClientAuthRequest.Size(m) -} -func (m *ClientAuthRequest) XXX_DiscardUnknown() { - xxx_messageInfo_ClientAuthRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_ClientAuthRequest proto.InternalMessageInfo - -func (m *ClientAuthRequest) GetCredential() *Credential { - if m != nil { - return m.Credential - } - return nil -} - -func (m *ClientAuthRequest) GetSpeed() *Speed { - if m != nil { - return m.Speed - } - return nil -} - -type ServerAuthResponse struct { - Result AuthResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.AuthResult" json:"result,omitempty"` - Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` - Speed *Speed `protobuf:"bytes,3,opt,name=speed,proto3" json:"speed,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ServerAuthResponse) Reset() { *m = ServerAuthResponse{} } -func (m *ServerAuthResponse) String() string { return proto.CompactTextString(m) } -func (*ServerAuthResponse) ProtoMessage() {} -func (*ServerAuthResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{3} -} - -func (m *ServerAuthResponse) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ServerAuthResponse.Unmarshal(m, b) -} -func (m *ServerAuthResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ServerAuthResponse.Marshal(b, m, deterministic) -} -func (m *ServerAuthResponse) XXX_Merge(src proto.Message) { - xxx_messageInfo_ServerAuthResponse.Merge(m, src) -} -func (m *ServerAuthResponse) XXX_Size() int { - return xxx_messageInfo_ServerAuthResponse.Size(m) -} -func (m *ServerAuthResponse) XXX_DiscardUnknown() { - xxx_messageInfo_ServerAuthResponse.DiscardUnknown(m) -} - -var xxx_messageInfo_ServerAuthResponse proto.InternalMessageInfo - -func (m *ServerAuthResponse) GetResult() AuthResult { - if m != nil { - return m.Result - } - return AuthResult_AUTH_SUCCESS -} - -func (m *ServerAuthResponse) GetMessage() string { - if m != nil { - return m.Message - } - return "" -} - -func (m *ServerAuthResponse) GetSpeed() *Speed { - if m != nil { - return m.Speed - } - return nil -} - -type ClientConnectRequest struct { - Type ConnectionType `protobuf:"varint,1,opt,name=type,proto3,enum=core.ConnectionType" json:"type,omitempty"` - Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ClientConnectRequest) Reset() { *m = ClientConnectRequest{} } -func (m *ClientConnectRequest) String() string { return proto.CompactTextString(m) } -func (*ClientConnectRequest) ProtoMessage() {} -func (*ClientConnectRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{4} -} - -func (m *ClientConnectRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ClientConnectRequest.Unmarshal(m, b) -} -func (m *ClientConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ClientConnectRequest.Marshal(b, m, deterministic) -} -func (m *ClientConnectRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_ClientConnectRequest.Merge(m, src) -} -func (m *ClientConnectRequest) XXX_Size() int { - return xxx_messageInfo_ClientConnectRequest.Size(m) -} -func (m *ClientConnectRequest) XXX_DiscardUnknown() { - xxx_messageInfo_ClientConnectRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_ClientConnectRequest proto.InternalMessageInfo - -func (m *ClientConnectRequest) GetType() ConnectionType { - if m != nil { - return m.Type - } - return ConnectionType_Stream -} - -func (m *ClientConnectRequest) GetAddress() string { - if m != nil { - return m.Address - } - return "" -} - -type ServerConnectResponse struct { - Result ConnectResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.ConnectResult" json:"result,omitempty"` - Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ServerConnectResponse) Reset() { *m = ServerConnectResponse{} } -func (m *ServerConnectResponse) String() string { return proto.CompactTextString(m) } -func (*ServerConnectResponse) ProtoMessage() {} -func (*ServerConnectResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{5} -} - -func (m *ServerConnectResponse) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ServerConnectResponse.Unmarshal(m, b) -} -func (m *ServerConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ServerConnectResponse.Marshal(b, m, deterministic) -} -func (m *ServerConnectResponse) XXX_Merge(src proto.Message) { - xxx_messageInfo_ServerConnectResponse.Merge(m, src) -} -func (m *ServerConnectResponse) XXX_Size() int { - return xxx_messageInfo_ServerConnectResponse.Size(m) -} -func (m *ServerConnectResponse) XXX_DiscardUnknown() { - xxx_messageInfo_ServerConnectResponse.DiscardUnknown(m) -} - -var xxx_messageInfo_ServerConnectResponse proto.InternalMessageInfo - -func (m *ServerConnectResponse) GetResult() ConnectResult { - if m != nil { - return m.Result - } - return ConnectResult_CONN_SUCCESS -} - -func (m *ServerConnectResponse) GetMessage() string { - if m != nil { - return m.Message - } - return "" -} - -func init() { - proto.RegisterEnum("core.AuthResult", AuthResult_name, AuthResult_value) - proto.RegisterEnum("core.ConnectionType", ConnectionType_name, ConnectionType_value) - proto.RegisterEnum("core.ConnectResult", ConnectResult_name, ConnectResult_value) - proto.RegisterType((*Speed)(nil), "core.Speed") - proto.RegisterType((*Credential)(nil), "core.Credential") - proto.RegisterType((*ClientAuthRequest)(nil), "core.ClientAuthRequest") - proto.RegisterType((*ServerAuthResponse)(nil), "core.ServerAuthResponse") - proto.RegisterType((*ClientConnectRequest)(nil), "core.ClientConnectRequest") - proto.RegisterType((*ServerConnectResponse)(nil), "core.ServerConnectResponse") -} - -func init() { - proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d) -} - -var fileDescriptor_0c5120591600887d = []byte{ - // 434 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xd1, 0x6e, 0xd3, 0x30, - 0x14, 0x86, 0xd7, 0xd2, 0x75, 0xdb, 0x09, 0x1b, 0x99, 0xb7, 0x89, 0xc1, 0x0d, 0x90, 0xab, 0xaa, - 0x48, 0x15, 0x1a, 0x4f, 0x90, 0x3a, 0x41, 0x54, 0x54, 0x29, 0x72, 0x3a, 0x2e, 0xb8, 0xa0, 0xca, - 0x92, 0x23, 0x56, 0x91, 0xda, 0xc6, 0x76, 0x86, 0x26, 0x5e, 0x1e, 0xc5, 0x71, 0xd2, 0x16, 0x09, - 0x89, 0xbb, 0x9e, 0x73, 0x7e, 0xfd, 0x9f, 0xbf, 0x2a, 0x70, 0x9a, 0x0b, 0x6e, 0x94, 0x28, 0x27, - 0x52, 0x09, 0x23, 0xc8, 0x20, 0x17, 0x0a, 0x03, 0x0a, 0x87, 0xa9, 0x44, 0x2c, 0xc8, 0x0b, 0x38, - 0xd6, 0xc8, 0x8b, 0xd5, 0x9d, 0xd4, 0xd7, 0xbd, 0xd7, 0xbd, 0xd1, 0x80, 0x1d, 0xd5, 0xf3, 0x54, - 0x6a, 0xf2, 0x0a, 0x3c, 0x85, 0x39, 0xae, 0x1f, 0xd0, 0x5e, 0xfb, 0xf6, 0x0a, 0x6e, 0x35, 0x95, - 0x3a, 0x88, 0x00, 0xa8, 0xc2, 0x02, 0xb9, 0x59, 0x67, 0x25, 0x79, 0x09, 0xc7, 0x95, 0x46, 0xc5, - 0xb3, 0x0d, 0xda, 0xa6, 0x13, 0xd6, 0xcd, 0xf5, 0x4d, 0x66, 0x5a, 0xff, 0x12, 0xaa, 0xb0, 0x3d, - 0x27, 0xac, 0x9b, 0x83, 0x7b, 0x38, 0xa7, 0xe5, 0x1a, 0xb9, 0x09, 0x2b, 0x73, 0xcf, 0xf0, 0x67, - 0x85, 0xda, 0x90, 0x77, 0x00, 0x79, 0x57, 0x6d, 0xeb, 0xbc, 0x1b, 0x7f, 0x52, 0x3f, 0x7d, 0xb2, - 0x45, 0xb2, 0x9d, 0x0c, 0x79, 0x03, 0x87, 0xba, 0x36, 0xb2, 0xfd, 0xde, 0x8d, 0xd7, 0x84, 0xad, - 0x24, 0x6b, 0x2e, 0xc1, 0x6f, 0x20, 0x29, 0xaa, 0x07, 0x54, 0x0d, 0x49, 0x4b, 0xc1, 0x35, 0x92, - 0x11, 0x0c, 0x15, 0xea, 0xaa, 0x34, 0x16, 0x73, 0xd6, 0x62, 0x5c, 0xa6, 0x2a, 0x0d, 0x73, 0x77, - 0x72, 0x0d, 0x47, 0x1b, 0xd4, 0x3a, 0xfb, 0x8e, 0x4e, 0xa2, 0x1d, 0xb7, 0xf0, 0x27, 0xff, 0x84, - 0x7f, 0x85, 0xcb, 0x46, 0x93, 0x0a, 0xce, 0x31, 0x37, 0xad, 0xe9, 0x08, 0x06, 0xe6, 0x51, 0xa2, - 0x83, 0x5f, 0x3a, 0xc7, 0x26, 0xb3, 0x16, 0x7c, 0xf9, 0x28, 0x91, 0xd9, 0x44, 0x8d, 0xcf, 0x8a, - 0x42, 0xa1, 0xd6, 0x2d, 0xde, 0x8d, 0xc1, 0x37, 0xb8, 0x6a, 0xc4, 0xba, 0x6e, 0xe7, 0xf6, 0xf6, - 0x2f, 0xb7, 0x8b, 0xbd, 0xfa, 0xff, 0xd5, 0x1b, 0x27, 0x00, 0xdb, 0xbf, 0x83, 0xf8, 0xf0, 0x34, - 0xbc, 0x5d, 0x7e, 0x5c, 0xa5, 0xb7, 0x94, 0xc6, 0x69, 0xea, 0x1f, 0x90, 0x2b, 0x38, 0xb7, 0x9b, - 0x59, 0xf2, 0x25, 0x9c, 0xcf, 0xa2, 0x15, 0x65, 0x71, 0xe4, 0xf7, 0xc8, 0x73, 0xb8, 0x70, 0xeb, - 0x65, 0xcc, 0x92, 0x70, 0xbe, 0x8a, 0x19, 0x5b, 0x30, 0xbf, 0x3f, 0x1e, 0xc1, 0xd9, 0xbe, 0x21, - 0x01, 0x18, 0xa6, 0x46, 0x61, 0xb6, 0xf1, 0x0f, 0xea, 0xdf, 0x9f, 0xb3, 0xfc, 0x07, 0x1a, 0xbf, - 0x37, 0x8e, 0xe0, 0x74, 0xef, 0xb1, 0x35, 0x9c, 0x2e, 0x92, 0x64, 0x07, 0xfe, 0x0c, 0x3c, 0xbb, - 0xf9, 0x10, 0xce, 0xe6, 0x16, 0xdb, 0x46, 0xa6, 0xf3, 0x05, 0xfd, 0x14, 0x47, 0x7e, 0xff, 0x6e, - 0x68, 0x3f, 0xfd, 0xf7, 0x7f, 0x02, 0x00, 0x00, 0xff, 0xff, 0xd2, 0x2f, 0x8d, 0x6f, 0x0b, 0x03, - 0x00, 0x00, -} diff --git a/pkg/core/pb/control.proto b/pkg/core/pb/control.proto deleted file mode 100644 index a72eda3..0000000 --- a/pkg/core/pb/control.proto +++ /dev/null @@ -1,50 +0,0 @@ -syntax = "proto3"; -package core; - -message Speed { - uint64 send_bps = 1; - uint64 receive_bps = 2; -} - -message Credential { - string username = 1; - string password = 2; -} - -enum AuthResult { - AUTH_SUCCESS = 0; - AUTH_INVALID_CRED = 1; - AUTH_INTERNAL_ERROR = 2; -} - -message ClientAuthRequest { - Credential credential = 1; - Speed speed = 2; -} - -message ServerAuthResponse { - AuthResult result = 1; - string message = 2; - Speed speed = 3; -} - -enum ConnectionType { - Stream = 0; - Packet = 1; -} - -enum ConnectResult { - CONN_SUCCESS = 0; - CONN_FAILED = 1; - CONN_BLOCKED = 2; -} - -message ClientConnectRequest { - ConnectionType type = 1; - string address = 2; -} - -message ServerConnectResponse { - ConnectResult result = 1; - string message = 2; -} \ No newline at end of file diff --git a/pkg/core/pb/protogen.go b/pkg/core/pb/protogen.go deleted file mode 100644 index 871b9ce..0000000 --- a/pkg/core/pb/protogen.go +++ /dev/null @@ -1,3 +0,0 @@ -package pb - -//go:generate protoc --go_out=. control.proto diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go new file mode 100644 index 0000000..83e9b8c --- /dev/null +++ b/pkg/core/protocol.go @@ -0,0 +1,45 @@ +package core + +import ( + "time" +) + +const ( + protocolVersion = uint8(1) + protocolTimeout = 10 * time.Second + + closeErrorCodeGeneric = 0 + closeErrorCodeProtocol = 1 + closeErrorCodeAuth = 2 +) + +type transmissionRate struct { + SendBPS uint64 + RecvBPS uint64 +} + +type clientHello struct { + Rate transmissionRate + AuthLen uint16 `struc:"sizeof=Auth"` + Auth []byte +} + +type serverHello struct { + OK bool + Rate transmissionRate + MessageLen uint16 `struc:"sizeof=Message"` + Message string +} + +type clientRequest struct { + UDP bool + AddressLen uint16 `struc:"sizeof=Address"` + Address string +} + +type serverResponse struct { + OK bool + UDPSessionID uint32 + MessageLen uint16 `struc:"sizeof=Message"` + Message string +} diff --git a/pkg/core/server.go b/pkg/core/server.go index 2a28270..da06b6f 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -4,61 +4,32 @@ import ( "context" "crypto/tls" "errors" - "fmt" "github.com/lucas-clemente/quic-go" - "github.com/tobyxdd/hysteria/pkg/core/pb" + "github.com/lunixbochs/struc" + "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/utils" - "io" "net" - "sync/atomic" + "time" ) -type AuthResult int32 -type ConnectionType int32 -type ConnectResult int32 +const dialTimeout = 10 * time.Second -const ( - AuthResultSuccess AuthResult = iota - AuthResultInvalidCred - AuthResultInternalError -) - -const ( - ConnectionTypeStream ConnectionType = iota - ConnectionTypePacket -) - -const ( - ConnectResultSuccess ConnectResult = iota - ConnectResultFailed - ConnectResultBlocked -) - -type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string) -type ClientDisconnectedFunc func(addr net.Addr, username string, err error) -type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser) -type RequestClosedFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string, err error) +type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) +type RequestFunc func(addr net.Addr, auth []byte, udp bool, reqAddr string) type Server struct { - inboundBytes, outboundBytes uint64 // atomic + sendBPS, recvBPS uint64 + congestionFactory CongestionFactory + authFunc AuthFunc + requestFunc RequestFunc + aclEngine *acl.Engine - listener quic.Listener - sendBPS, recvBPS uint64 - - congestionFactory CongestionFactory - clientAuthFunc ClientAuthFunc - clientDisconnectedFunc ClientDisconnectedFunc - handleRequestFunc HandleRequestFunc - requestClosedFunc RequestClosedFunc + listener quic.Listener } func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, - sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, - obfuscator Obfuscator, - clientAuthFunc ClientAuthFunc, - clientDisconnectedFunc ClientDisconnectedFunc, - handleRequestFunc HandleRequestFunc, - requestClosedFunc RequestClosedFunc) (*Server, error) { + sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine, + obfuscator Obfuscator, authFunc AuthFunc, requestFunc RequestFunc) (*Server, error) { packetConn, err := net.ListenPacket("udp", addr) if err != nil { return nil, err @@ -75,14 +46,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, return nil, err } s := &Server{ - listener: listener, - sendBPS: sendBPS, - recvBPS: recvBPS, - congestionFactory: congestionFactory, - clientAuthFunc: clientAuthFunc, - clientDisconnectedFunc: clientDisconnectedFunc, - handleRequestFunc: handleRequestFunc, - requestClosedFunc: requestClosedFunc, + listener: listener, + sendBPS: sendBPS, + recvBPS: recvBPS, + congestionFactory: congestionFactory, + authFunc: authFunc, + requestFunc: requestFunc, + aclEngine: aclEngine, } return s, nil } @@ -97,128 +67,156 @@ func (s *Server) Serve() error { } } -func (s *Server) Stats() (uint64, uint64) { - return atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes) -} - func (s *Server) Close() error { return s.listener.Close() } func (s *Server) handleClient(cs quic.Session) { // Expect the client to create a control stream to send its own information - ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) - ctlStream, err := cs.AcceptStream(ctx) + ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout) + stream, err := cs.AcceptStream(ctx) ctxCancel() if err != nil { - _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") + _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") return } // Handle the control stream - username, ok, err := s.handleControlStream(cs, ctlStream) + auth, ok, err := s.handleControlStream(cs, stream) if err != nil { - _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") + _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") return } if !ok { - _ = cs.CloseWithError(closeErrorCodeGeneric, "authentication failure") + _ = cs.CloseWithError(closeErrorCodeAuth, "auth error") return } // Start accepting streams - var closeErr error for { stream, err := cs.AcceptStream(context.Background()) if err != nil { - closeErr = err break } - go s.handleStream(cs.LocalAddr(), cs.RemoteAddr(), username, stream) + go s.handleStream(cs.RemoteAddr(), auth, stream) } - s.clientDisconnectedFunc(cs.RemoteAddr(), username, closeErr) - _ = cs.CloseWithError(closeErrorCodeGeneric, "generic") + _ = cs.CloseWithError(closeErrorCodeGeneric, "") } // Auth & negotiate speed -func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (string, bool, error) { - req, err := readClientAuthRequest(stream) +func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) { + var ch clientHello + err := struc.Unpack(stream, &ch) if err != nil { - return "", false, err + return nil, false, err } // Speed - if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 { - return "", false, errors.New("incorrect speed provided by the client") + if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 { + return nil, false, errors.New("invalid rate from client") } - serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps + serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS if s.sendBPS > 0 && serverSendBPS > s.sendBPS { serverSendBPS = s.sendBPS } - if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS { - serverReceiveBPS = s.recvBPS + if s.recvBPS > 0 && serverRecvBPS > s.recvBPS { + serverRecvBPS = s.recvBPS } // Auth - if req.Credential == nil { - return "", false, errors.New("incorrect credential provided by the client") - } - authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password, - serverSendBPS, serverReceiveBPS) + ok, msg := s.authFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) // Response - err = writeServerAuthResponse(stream, &pb.ServerAuthResponse{ - Result: pb.AuthResult(authResult), - Message: msg, - Speed: &pb.Speed{ - SendBps: serverSendBPS, - ReceiveBps: serverReceiveBPS, + err = struc.Pack(stream, &serverHello{ + OK: ok, + Rate: transmissionRate{ + SendBPS: serverSendBPS, + RecvBPS: serverRecvBPS, }, + Message: msg, }) if err != nil { - return "", false, err + return nil, false, err } // Set the congestion accordingly - if authResult == AuthResultSuccess && s.congestionFactory != nil { + if ok && s.congestionFactory != nil { cs.SetCongestionControl(s.congestionFactory(serverSendBPS)) } - return req.Credential.Username, authResult == AuthResultSuccess, nil + return ch.Auth, ok, nil } -func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username string, stream quic.Stream) { +func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) { defer stream.Close() // Read request - req, err := readClientConnectRequest(stream) + var req clientRequest + err := struc.Unpack(stream, &req) if err != nil { return } - // Create connection with the handler - result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address) - defer func() { - if conn != nil { - _ = conn.Close() + s.requestFunc(remoteAddr, auth, req.UDP, req.Address) + if !req.UDP { + // TCP connection + s.handleTCP(stream, req.Address) + } else { + // UDP connection + // TODO + } +} + +func (s *Server) handleTCP(stream quic.Stream, reqAddr string) { + host, port, err := net.SplitHostPort(reqAddr) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "invalid address", + }) + return + } + ip := net.ParseIP(host) + if ip != nil { + // IP request, clear host for ACL engine + host = "" + } + action, arg := acl.ActionDirect, "" + if s.aclEngine != nil { + action, arg = s.aclEngine.Lookup(host, ip) + } + + var conn net.Conn // Connection to be piped + switch action { + case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side + conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: err.Error(), + }) + return } - }() - // Send response - err = writeServerConnectResponse(stream, &pb.ServerConnectResponse{ - Result: pb.ConnectResult(result), - Message: msg, + case acl.ActionBlock: + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "blocked by ACL", + }) + return + case acl.ActionHijack: + hijackAddr := net.JoinHostPort(arg, port) + conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: err.Error(), + }) + return + } + default: + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "ACL error", + }) + return + } + // So far so good if we reach here + err = struc.Pack(stream, &serverResponse{ + OK: true, }) if err != nil { - s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err) return } - if result != ConnectResultSuccess { - s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, - fmt.Errorf("handler returned an unsuccessful state %d (msg: %s)", result, msg)) - return - } - switch req.Type { - case pb.ConnectionType_Stream: - err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes) - case pb.ConnectionType_Packet: - err = utils.PipePair(&utils.PacketWrapperConn{Orig: &utils.QUICStreamWrapperConn{ - Orig: stream, - PseudoLocalAddr: localAddr, - PseudoRemoteAddr: remoteAddr, - }}, conn, &s.outboundBytes, &s.inboundBytes) - default: - err = fmt.Errorf("unsupported connection type %s", req.Type.String()) - } - s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err) + _ = utils.Pipe2Way(stream, conn) } diff --git a/pkg/core/types.go b/pkg/core/types.go deleted file mode 100644 index 841da1f..0000000 --- a/pkg/core/types.go +++ /dev/null @@ -1,13 +0,0 @@ -package core - -import ( - "encoding/binary" - "github.com/lucas-clemente/quic-go/congestion" - "time" -) - -const controlStreamTimeout = 10 * time.Second - -var controlProtocolEndian = binary.BigEndian - -type CongestionFactory func(refBPS uint64) congestion.CongestionControl diff --git a/pkg/http/server.go b/pkg/http/server.go index 8d82d03..a9c0bbc 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -42,9 +42,9 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng case acl.ActionDirect: return net.Dial(network, addr) case acl.ActionProxy: - return hyClient.Dial(false, addr) + return hyClient.DialTCP(addr) case acl.ActionBlock: - return nil, errors.New("blocked in ACL") + return nil, errors.New("blocked by ACL") case acl.ActionHijack: return net.Dial(network, net.JoinHostPort(arg, port)) default: @@ -53,7 +53,6 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng }, IdleConnTimeout: idleTimeout, // TODO: Disable HTTP2 support? ref: https://github.com/elazarl/goproxy/issues/361 - //TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper), } proxy.ConnectDial = nil if basicAuthFunc != nil { diff --git a/pkg/socks5/server.go b/pkg/socks5/server.go index bedd867..603bdc4 100644 --- a/pkg/socks5/server.go +++ b/pkg/socks5/server.go @@ -156,12 +156,8 @@ func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error { return s.handleTCP(c, r) } else if r.Cmd == socks5.CmdUDP { // UDP - if !s.DisableUDP { - return s.handleUDP(c, r) - } else { - _ = sendReply(c, socks5.RepCommandNotSupported) - return ErrUnsupportedCmd - } + _ = sendReply(c, socks5.RepCommandNotSupported) + return ErrUnsupportedCmd } else { _ = sendReply(c, socks5.RepCommandNotSupported) return ErrUnsupportedCmd @@ -193,7 +189,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { closeErr = pipePair(c, rc, s.TCPDeadline) return nil case acl.ActionProxy: - rc, err := s.HyClient.Dial(false, addr) + rc, err := s.HyClient.DialTCP(addr) if err != nil { _ = sendReply(c, socks5.RepHostUnreachable) closeErr = err @@ -225,132 +221,6 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { } } -func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error { - s.NewUDPAssociateFunc(c.RemoteAddr()) - var closeErr error - defer func() { - s.UDPAssociateClosedFunc(c.RemoteAddr(), closeErr) - }() - udpConn, err := net.ListenUDP("udp", &net.UDPAddr{ - IP: s.TCPAddr.IP, - Zone: s.TCPAddr.Zone, - }) - if err != nil { - _ = sendReply(c, socks5.RepServerFailure) - closeErr = err - return err - } - defer udpConn.Close() - // Send UDP server addr to the client - atyp, addr, port, err := socks5.ParseAddress(udpConn.LocalAddr().String()) - if err != nil { - _ = sendReply(c, socks5.RepServerFailure) - closeErr = err - return err - } - _, _ = socks5.NewReply(socks5.RepSuccess, atyp, addr, port).WriteTo(c) - // Let UDP server do its job, we hold the TCP connection here - go s.udpServer(udpConn) - buf := make([]byte, 1024) - for { - if s.TCPDeadline != 0 { - _ = c.SetDeadline(time.Now().Add(time.Duration(s.TCPDeadline) * time.Second)) - } - _, err := c.Read(buf) - if err != nil { - closeErr = err - break - } - } - // As the TCP connection closes, so does the UDP listener - return nil -} - -func (s *Server) udpServer(c *net.UDPConn) { - var clientAddr *net.UDPAddr - remoteMap := make(map[string]io.ReadWriteCloser) // Remote addr <-> Remote conn - buf := make([]byte, utils.PipeBufferSize) - var closeErr error - - for { - n, caddr, err := c.ReadFromUDP(buf) - if err != nil { - closeErr = err - break - } - d, err := socks5.NewDatagramFromBytes(buf[:n]) - if err != nil || d.Frag != 0 { - // Ignore bad packets - continue - } - if clientAddr == nil { - // Whoever sends the first valid packet is our client :P - clientAddr = caddr - } else if caddr.String() != clientAddr.String() { - // We already have a client and you're not it! - continue - } - domain, ip, port, addr := parseDatagramRequestAddress(d) - rc := remoteMap[addr] - if rc == nil { - // Need a new entry - action, arg := acl.ActionProxy, "" - if s.ACLEngine != nil { - action, arg = s.ACLEngine.Lookup(domain, ip) - } - s.NewUDPTunnelFunc(clientAddr, addr, action, arg) - // Handle according to the action - switch action { - case acl.ActionDirect: - rc, err = net.Dial("udp", addr) - if err != nil { - s.UDPTunnelClosedFunc(clientAddr, addr, err) - continue - } - // The other direction - go udpReversePipe(clientAddr, c, rc) - remoteMap[addr] = rc - case acl.ActionProxy: - rc, err = s.HyClient.Dial(true, addr) - if err != nil { - s.UDPTunnelClosedFunc(clientAddr, addr, err) - continue - } - // The other direction - go udpReversePipe(clientAddr, c, rc) - remoteMap[addr] = rc - case acl.ActionBlock: - s.UDPTunnelClosedFunc(clientAddr, addr, errors.New("blocked in ACL")) - continue - case acl.ActionHijack: - rc, err = net.Dial("udp", net.JoinHostPort(arg, port)) - if err != nil { - s.UDPTunnelClosedFunc(clientAddr, addr, err) - continue - } - // The other direction - go udpReversePipe(clientAddr, c, rc) - remoteMap[addr] = rc - default: - s.UDPTunnelClosedFunc(clientAddr, addr, fmt.Errorf("unknown action %d", action)) - continue - } - } - _, err = rc.Write(d.Data) - if err != nil { - // The connection is no longer valid, close & remove from map - _ = rc.Close() - delete(remoteMap, addr) - s.UDPTunnelClosedFunc(clientAddr, addr, err) - } - } - // Close all remote connections - for raddr, rc := range remoteMap { - _ = rc.Close() - s.UDPTunnelClosedFunc(clientAddr, raddr, closeErr) - } -} - func sendReply(conn *net.TCPConn, rep byte) error { p := socks5.NewReply(rep, socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) _, err := p.WriteTo(conn) @@ -367,24 +237,15 @@ func parseRequestAddress(r *socks5.Request) (domain string, ip net.IP, port stri } } -func parseDatagramRequestAddress(r *socks5.Datagram) (domain string, ip net.IP, port string, addr string) { - p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort))) - if r.Atyp == socks5.ATYPDomain { - d := string(r.DstAddr[1:]) - return d, nil, p, net.JoinHostPort(d, p) - } else { - return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p) - } -} - func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, deadline int) error { + deadlineDuration := time.Duration(deadline) * time.Second errChan := make(chan error, 2) // TCP to stream go func() { buf := make([]byte, utils.PipeBufferSize) for { if deadline != 0 { - _ = conn.SetDeadline(time.Now().Add(time.Duration(deadline) * time.Second)) + _ = conn.SetDeadline(time.Now().Add(deadlineDuration)) } rn, err := conn.Read(buf) if rn > 0 { @@ -402,22 +263,7 @@ func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, deadline int) error }() // Stream to TCP go func() { - errChan <- utils.Pipe(stream, conn, nil) + errChan <- utils.Pipe(stream, conn) }() return <-errChan } - -func udpReversePipe(clientAddr *net.UDPAddr, c *net.UDPConn, rc io.ReadWriteCloser) { - buf := make([]byte, utils.PipeBufferSize) - for { - n, err := rc.Read(buf) - if err != nil { - break - } - d := socks5.NewDatagram(socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}, buf[:n]) - _, err = c.WriteTo(d.Bytes(), clientAddr) - if err != nil { - break - } - } -} diff --git a/pkg/utils/conn_wrappers.go b/pkg/utils/conn_wrappers.go deleted file mode 100644 index 6f745e4..0000000 --- a/pkg/utils/conn_wrappers.go +++ /dev/null @@ -1,96 +0,0 @@ -package utils - -import ( - "encoding/binary" - "fmt" - "github.com/lucas-clemente/quic-go" - "io" - "net" - "time" -) - -type PacketWrapperConn struct { - Orig net.Conn -} - -func (w *PacketWrapperConn) Read(b []byte) (n int, err error) { - var sz uint32 - if err := binary.Read(w.Orig, binary.BigEndian, &sz); err != nil { - return 0, err - } - if int(sz) <= len(b) { - return io.ReadFull(w.Orig, b[:sz]) - } else { - return 0, fmt.Errorf("the buffer is too small to hold %d bytes of packet data", sz) - } -} - -func (w *PacketWrapperConn) Write(b []byte) (n int, err error) { - sz := uint32(len(b)) - if err := binary.Write(w.Orig, binary.BigEndian, &sz); err != nil { - return 0, err - } - return w.Orig.Write(b) -} - -func (w *PacketWrapperConn) Close() error { - return w.Orig.Close() -} - -func (w *PacketWrapperConn) LocalAddr() net.Addr { - return w.Orig.LocalAddr() -} - -func (w *PacketWrapperConn) RemoteAddr() net.Addr { - return w.Orig.RemoteAddr() -} - -func (w *PacketWrapperConn) SetDeadline(t time.Time) error { - return w.Orig.SetDeadline(t) -} - -func (w *PacketWrapperConn) SetReadDeadline(t time.Time) error { - return w.Orig.SetReadDeadline(t) -} - -func (w *PacketWrapperConn) SetWriteDeadline(t time.Time) error { - return w.Orig.SetWriteDeadline(t) -} - -type QUICStreamWrapperConn struct { - Orig quic.Stream - PseudoLocalAddr net.Addr - PseudoRemoteAddr net.Addr -} - -func (w *QUICStreamWrapperConn) Read(b []byte) (n int, err error) { - return w.Orig.Read(b) -} - -func (w *QUICStreamWrapperConn) Write(b []byte) (n int, err error) { - return w.Orig.Write(b) -} - -func (w *QUICStreamWrapperConn) Close() error { - return w.Orig.Close() -} - -func (w *QUICStreamWrapperConn) LocalAddr() net.Addr { - return w.PseudoLocalAddr -} - -func (w *QUICStreamWrapperConn) RemoteAddr() net.Addr { - return w.PseudoRemoteAddr -} - -func (w *QUICStreamWrapperConn) SetDeadline(t time.Time) error { - return w.Orig.SetDeadline(t) -} - -func (w *QUICStreamWrapperConn) SetReadDeadline(t time.Time) error { - return w.Orig.SetReadDeadline(t) -} - -func (w *QUICStreamWrapperConn) SetWriteDeadline(t time.Time) error { - return w.Orig.SetWriteDeadline(t) -} diff --git a/pkg/utils/pipe.go b/pkg/utils/pipe.go index 5fa16a0..47318f9 100644 --- a/pkg/utils/pipe.go +++ b/pkg/utils/pipe.go @@ -2,20 +2,16 @@ package utils import ( "io" - "sync/atomic" ) const PipeBufferSize = 65536 -func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error { +func Pipe(src, dst io.ReadWriter) error { buf := make([]byte, PipeBufferSize) for { rn, err := src.Read(buf) if rn > 0 { - wn, err := dst.Write(buf[:rn]) - if atomicCounter != nil { - atomic.AddUint64(atomicCounter, uint64(wn)) - } + _, err := dst.Write(buf[:rn]) if err != nil { return err } @@ -26,13 +22,13 @@ func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error { } } -func PipePair(rw1, rw2 io.ReadWriter, rw1WriteCounter, rw2WriteCounter *uint64) error { +func Pipe2Way(rw1, rw2 io.ReadWriter) error { errChan := make(chan error, 2) go func() { - errChan <- Pipe(rw2, rw1, rw1WriteCounter) + errChan <- Pipe(rw2, rw1) }() go func() { - errChan <- Pipe(rw1, rw2, rw2WriteCounter) + errChan <- Pipe(rw1, rw2) }() // We only need the first error return <-errChan