From a424a17af3c2f6a732a5a163fd97fc7afd81ccae Mon Sep 17 00:00:00 2001 From: Toby Date: Mon, 20 Apr 2020 16:53:13 -0700 Subject: [PATCH] Tons of refactoring --- .gitignore | 3 +- cmd/forwarder/client.go | 169 --------- cmd/forwarder/config.go | 76 ---- cmd/forwarder/server.go | 204 ----------- cmd/relay/client.go | 118 ++++++ cmd/relay/config.go | 150 ++++++++ cmd/{forwarder => relay}/flags.go | 14 +- cmd/{forwarder => relay}/main.go | 0 cmd/relay/server.go | 84 +++++ internal/core/client.go | 175 +++++++++ internal/core/control.go | 103 ++++++ internal/core/control.pb.go | 439 +++++++++++++++++++++++ internal/core/control.proto | 50 +++ internal/{forwarder => core}/protogen.go | 2 +- internal/core/server.go | 199 ++++++++++ internal/core/types.go | 13 + internal/forwarder/client.go | 203 ----------- internal/forwarder/control.go | 67 ---- internal/forwarder/control.pb.go | 206 ----------- internal/forwarder/control.proto | 19 - internal/forwarder/params.go | 10 - internal/forwarder/server.go | 176 --------- internal/forwarder/types.go | 21 -- internal/utils/packet_readwritecloser.go | 35 ++ internal/utils/pipe.go | 6 +- pkg/core/interface.go | 71 ++++ pkg/forwarder/client.go | 70 ---- pkg/forwarder/interface.go | 89 ----- pkg/forwarder/params.go | 9 - pkg/forwarder/server.go | 119 ------ 30 files changed, 1444 insertions(+), 1456 deletions(-) delete mode 100644 cmd/forwarder/client.go delete mode 100644 cmd/forwarder/config.go delete mode 100644 cmd/forwarder/server.go create mode 100644 cmd/relay/client.go create mode 100644 cmd/relay/config.go rename cmd/{forwarder => relay}/flags.go (58%) rename cmd/{forwarder => relay}/main.go (100%) create mode 100644 cmd/relay/server.go create mode 100644 internal/core/client.go create mode 100644 internal/core/control.go create mode 100644 internal/core/control.pb.go create mode 100644 internal/core/control.proto rename internal/{forwarder => core}/protogen.go (72%) create mode 100644 internal/core/server.go create mode 100644 internal/core/types.go delete mode 100644 internal/forwarder/client.go delete mode 100644 internal/forwarder/control.go delete mode 100644 internal/forwarder/control.pb.go delete mode 100644 internal/forwarder/control.proto delete mode 100644 internal/forwarder/params.go delete mode 100644 internal/forwarder/server.go delete mode 100644 internal/forwarder/types.go create mode 100644 internal/utils/packet_readwritecloser.go create mode 100644 pkg/core/interface.go delete mode 100644 pkg/forwarder/client.go delete mode 100644 pkg/forwarder/interface.go delete mode 100644 pkg/forwarder/params.go delete mode 100644 pkg/forwarder/server.go diff --git a/.gitignore b/.gitignore index ba0b0dd..e9e0ccf 100644 --- a/.gitignore +++ b/.gitignore @@ -179,5 +179,4 @@ $RECYCLE.BIN/ # End of https://www.gitignore.io/api/go,linux,macos,windows,intellij+all -cmd/forwarder/*.json -cmd/forwarder/forwarder +cmd/relay/*.json diff --git a/cmd/forwarder/client.go b/cmd/forwarder/client.go deleted file mode 100644 index 0e07e75..0000000 --- a/cmd/forwarder/client.go +++ /dev/null @@ -1,169 +0,0 @@ -package main - -import ( - "crypto/tls" - "crypto/x509" - "encoding/json" - "flag" - "fmt" - "github.com/lucas-clemente/quic-go/congestion" - hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" - "github.com/tobyxdd/hysteria/pkg/forwarder" - "io/ioutil" - "log" - "net" - "os" - "os/user" -) - -func loadCmdClientConfig(args []string) (CmdClientConfig, error) { - fs := flag.NewFlagSet("client", flag.ContinueOnError) - // Config file - configFile := fs.String("config", "", "Configuration file path") - // Listen - listen := fs.String("listen", "", "TCP listen address") - // Server - server := fs.String("server", "", "Server address") - // Name - name := fs.String("name", "", "Client name presented to the server") - // Insecure - var insecure optionalBoolFlag - fs.Var(&insecure, "insecure", "Ignore TLS certificate errors") - // Custom CA - customCAFile := fs.String("ca", "", "Specify a trusted CA file") - // Up Mbps - upMbps := fs.Int("up-mbps", 0, "Upload speed in Mbps") - // Down Mbps - downMbps := fs.Int("down-mbps", 0, "Download speed in Mbps") - // Receive window conn - recvWindowConn := fs.Uint64("recv-window-conn", 0, "Max receive window size per connection") - // Receive window - recvWindow := fs.Uint64("recv-window", 0, "Max receive window size") - // Parse - if err := fs.Parse(args); err != nil { - os.Exit(1) - } - // Put together the config - var config CmdClientConfig - // Load from file first - if len(*configFile) > 0 { - cb, err := ioutil.ReadFile(*configFile) - if err != nil { - return CmdClientConfig{}, err - } - if err := json.Unmarshal(cb, &config); err != nil { - return CmdClientConfig{}, err - } - } - // Then CLI options can override config - if len(*listen) > 0 { - config.ListenAddr = *listen - } - if len(*server) > 0 { - config.ServerAddr = *server - } - if len(*name) > 0 { - config.Name = *name - } - if insecure.Exists { - config.Insecure = insecure.Value - } - if len(*customCAFile) > 0 { - config.CustomCAFile = *customCAFile - } - if *upMbps != 0 { - config.UpMbps = *upMbps - } - if *downMbps != 0 { - config.DownMbps = *downMbps - } - if *recvWindowConn != 0 { - config.ReceiveWindowConn = *recvWindowConn - } - if *recvWindow != 0 { - config.ReceiveWindow = *recvWindow - } - return config, nil -} - -func client(args []string) { - config, err := loadCmdClientConfig(args) - if err != nil { - log.Fatalln("Unable to load configuration:", err.Error()) - } - if err := config.Check(); err != nil { - log.Fatalln("Configuration error:", err.Error()) - } - if len(config.Name) == 0 { - usr, err := user.Current() - if err == nil { - config.Name = usr.Name - } - } - fmt.Printf("Configuration loaded: %+v\n", config) - - tlsConfig := &tls.Config{ - NextProtos: []string{forwarder.TLSAppProtocol}, - MinVersion: tls.VersionTLS13, - } - - // Load CA - if len(config.CustomCAFile) > 0 { - bs, err := ioutil.ReadFile(config.CustomCAFile) - if err != nil { - log.Fatalln("Unable to load CA file:", err) - } - cp := x509.NewCertPool() - if !cp.AppendCertsFromPEM(bs) { - log.Fatalln("Unable to parse CA file", config.CustomCAFile) - } - tlsConfig.RootCAs = cp - } - - logChan := make(chan string, 4) - - go func() { - _, err = forwarder.NewClient(config.ListenAddr, config.ServerAddr, forwarder.ClientConfig{ - Name: config.Name, - TLSConfig: tlsConfig, - Speed: &forwarder.Speed{ - SendBPS: uint64(config.UpMbps) * mbpsToBps, - ReceiveBPS: uint64(config.DownMbps) * mbpsToBps, - }, - MaxReceiveWindowPerConnection: config.ReceiveWindowConn, - MaxReceiveWindow: config.ReceiveWindow, - CongestionFactory: func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, - }, forwarder.ClientCallbacks{ - ServerConnectedCallback: func(addr net.Addr, banner string, cSend uint64, cRecv uint64) { - logChan <- fmt.Sprintf("Connected to server %s, negotiated speed in Mbps: Up %d / Down %d", - addr.String(), cSend/mbpsToBps, cRecv/mbpsToBps) - logChan <- fmt.Sprintf("Server banner: [%s]", banner) - }, - ServerErrorCallback: func(err error) { - logChan <- fmt.Sprintf("Error connecting to the server: %s", err.Error()) - }, - NewTCPConnectionCallback: func(addr net.Addr) { - logChan <- fmt.Sprintf("New connection: %s", addr.String()) - }, - TCPConnectionClosedCallback: func(addr net.Addr, err error) { - logChan <- fmt.Sprintf("Connection %s closed: %s", addr.String(), err.Error()) - }, - }) - if err != nil { - log.Fatalln("Client startup failure:", err) - } else { - log.Println("The client is now up and running :)") - } - }() - - for { - logStr := <-logChan - if len(logStr) == 0 { - break - } - log.Println(logStr) - } - -} diff --git a/cmd/forwarder/config.go b/cmd/forwarder/config.go deleted file mode 100644 index e4193a1..0000000 --- a/cmd/forwarder/config.go +++ /dev/null @@ -1,76 +0,0 @@ -package main - -import ( - "errors" - "fmt" -) - -type CmdClientConfig struct { - ListenAddr string `json:"listen"` - ServerAddr string `json:"server"` - Name string `json:"name"` - Insecure bool `json:"insecure"` - CustomCAFile string `json:"ca"` - UpMbps int `json:"up_mbps"` - DownMbps int `json:"down_mbps"` - ReceiveWindowConn uint64 `json:"recv_window_conn"` - ReceiveWindow uint64 `json:"recv_window"` -} - -func (c *CmdClientConfig) Check() error { - if len(c.ListenAddr) == 0 { - return errors.New("no listen address") - } - if len(c.ServerAddr) == 0 { - return errors.New("no server address") - } - if c.UpMbps <= 0 || c.DownMbps <= 0 { - return errors.New("invalid speed") - } - if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || - (c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) { - return errors.New("invalid receive window size") - } - return nil -} - -type ForwardEntry struct { - ListenAddr string `json:"listen"` - RemoteAddr string `json:"remote"` -} - -func (e *ForwardEntry) String() string { - return fmt.Sprintf("%s <-> %s", e.ListenAddr, e.RemoteAddr) -} - -type CmdServerConfig struct { - Entries []ForwardEntry `json:"entries"` - Banner string `json:"banner"` - CertFile string `json:"cert"` - KeyFile string `json:"key"` - UpMbps int `json:"up_mbps"` - DownMbps int `json:"down_mbps"` - ReceiveWindowConn uint64 `json:"recv_window_conn"` - ReceiveWindowClient uint64 `json:"recv_window_client"` - MaxConnClient int `json:"max_conn_client"` -} - -func (c *CmdServerConfig) Check() error { - if len(c.Entries) == 0 { - return errors.New("no entries") - } - if len(c.CertFile) == 0 || len(c.KeyFile) == 0 { - return errors.New("TLS cert or key not provided") - } - if c.UpMbps < 0 || c.DownMbps < 0 { - return errors.New("invalid speed") - } - if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || - (c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) { - return errors.New("invalid receive window size") - } - if c.MaxConnClient < 0 { - return errors.New("invalid max connections per client") - } - return nil -} diff --git a/cmd/forwarder/server.go b/cmd/forwarder/server.go deleted file mode 100644 index 25a94da..0000000 --- a/cmd/forwarder/server.go +++ /dev/null @@ -1,204 +0,0 @@ -package main - -import ( - "crypto/tls" - "encoding/json" - "flag" - "fmt" - "github.com/lucas-clemente/quic-go/congestion" - hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" - "github.com/tobyxdd/hysteria/pkg/forwarder" - "io/ioutil" - "log" - "net" - "os" - "strings" -) - -const mbpsToBps = 125000 - -func loadCmdServerConfig(args []string) (CmdServerConfig, error) { - fs := flag.NewFlagSet("server", flag.ContinueOnError) - // Config file - configFile := fs.String("config", "", "Configuration file path") - // Entries - var entries stringSliceFlag - fs.Var(&entries, "entry", "Add a forwarding entry. Separate the listen address and the remote address with a comma. You can add this option multiple times. Example: localhost:444,google.com:443") - // Banner - banner := fs.String("banner", "", "A banner to present to clients") - // Cert file - certFile := fs.String("cert", "", "TLS certificate file") - // Key file - keyFile := fs.String("key", "", "TLS key file") - // Up Mbps - upMbps := fs.Int("up-mbps", 0, "Max upload speed per client in Mbps") - // Down Mbps - downMbps := fs.Int("down-mbps", 0, "Max download speed per client in Mbps") - // Receive window conn - recvWindowConn := fs.Uint64("recv-window-conn", 0, "Max receive window size per connection") - // Receive window client - recvWindowClient := fs.Uint64("recv-window-client", 0, "Max receive window size per client") - // Max conn client - maxConnClient := fs.Int("max-conn-client", 0, "Max simultaneous connections allowed per client") - // Parse - if err := fs.Parse(args); err != nil { - os.Exit(1) - } - // Put together the config - var config CmdServerConfig - // Load from file first - if len(*configFile) > 0 { - cb, err := ioutil.ReadFile(*configFile) - if err != nil { - return CmdServerConfig{}, err - } - if err := json.Unmarshal(cb, &config); err != nil { - return CmdServerConfig{}, err - } - } - // Then CLI options can override config - if len(entries) > 0 { - fe, err := flagToEntries(entries) - if err != nil { - return CmdServerConfig{}, err - } - config.Entries = append(config.Entries, fe...) - } - if len(*banner) > 0 { - config.Banner = *banner - } - if len(*certFile) > 0 { - config.CertFile = *certFile - } - if len(*keyFile) > 0 { - config.KeyFile = *keyFile - } - if *upMbps != 0 { - config.UpMbps = *upMbps - } - if *downMbps != 0 { - config.DownMbps = *downMbps - } - if *recvWindowConn != 0 { - config.ReceiveWindowConn = *recvWindowConn - } - if *recvWindowClient != 0 { - config.ReceiveWindowClient = *recvWindowClient - } - if *maxConnClient != 0 { - config.MaxConnClient = *maxConnClient - } - return config, nil -} - -func flagToEntries(f stringSliceFlag) ([]ForwardEntry, error) { - out := make([]ForwardEntry, len(f)) - for i, entry := range f { - es := strings.Split(entry, ",") - if len(es) != 2 { - return nil, fmt.Errorf("incorrect entry syntax: %s", entry) - } - out[i] = ForwardEntry{ - ListenAddr: es[0], - RemoteAddr: es[1], - } - } - return out, nil -} - -func server(args []string) { - config, err := loadCmdServerConfig(args) - if err != nil { - log.Fatalln("Unable to load configuration:", err.Error()) - } - if err := config.Check(); err != nil { - log.Fatalln("Configuration error:", err.Error()) - } - fmt.Printf("Configuration loaded: %+v\n", config) - // Load cert - cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) - if err != nil { - log.Fatalln("Unable to load the certificate:", err) - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{forwarder.TLSAppProtocol}, - MinVersion: tls.VersionTLS13, - } - - logChan := make(chan string, 4) - - go func() { - server := forwarder.NewServer(forwarder.ServerConfig{ - BannerMessage: config.Banner, - TLSConfig: tlsConfig, - MaxSpeedPerClient: &forwarder.Speed{ - SendBPS: uint64(config.UpMbps) * mbpsToBps, - ReceiveBPS: uint64(config.DownMbps) * mbpsToBps, - }, - MaxReceiveWindowPerConnection: config.ReceiveWindowConn, - MaxReceiveWindowPerClient: config.ReceiveWindowClient, - MaxConnectionPerClient: config.MaxConnClient, - CongestionFactory: func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos { - return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, - }, forwarder.ServerCallbacks{ - ClientConnectedCallback: func(listenAddr string, clientAddr net.Addr, name string, sSend uint64, sRecv uint64) { - if len(name) > 0 { - logChan <- fmt.Sprintf("[%s] Client %s (%s) connected, negotiated speed in Mbps: Up %d / Down %d", - listenAddr, clientAddr.String(), name, sSend/mbpsToBps, sRecv/mbpsToBps) - } else { - logChan <- fmt.Sprintf("[%s] Client %s connected, negotiated speed in Mbps: Up %d / Down %d", - listenAddr, clientAddr.String(), sSend/mbpsToBps, sRecv/mbpsToBps) - } - }, - ClientDisconnectedCallback: func(listenAddr string, clientAddr net.Addr, name string, err error) { - if len(name) > 0 { - logChan <- fmt.Sprintf("[%s] Client %s (%s) disconnected: %s", - listenAddr, clientAddr.String(), name, err.Error()) - } else { - logChan <- fmt.Sprintf("[%s] Client %s disconnected: %s", - listenAddr, clientAddr.String(), err.Error()) - } - }, - ClientNewStreamCallback: func(listenAddr string, clientAddr net.Addr, name string, id int) { - if len(name) > 0 { - logChan <- fmt.Sprintf("[%s] Client %s (%s) opened stream ID %d", - listenAddr, clientAddr.String(), name, id) - } else { - logChan <- fmt.Sprintf("[%s] Client %s opened stream ID %d", - listenAddr, clientAddr.String(), id) - } - }, - ClientStreamClosedCallback: func(listenAddr string, clientAddr net.Addr, name string, id int, err error) { - if len(name) > 0 { - logChan <- fmt.Sprintf("[%s] Client %s (%s) closed stream ID %d: %s", - listenAddr, clientAddr.String(), name, id, err.Error()) - } else { - logChan <- fmt.Sprintf("[%s] Client %s closed stream ID %d: %s", - listenAddr, clientAddr.String(), id, err.Error()) - } - }, - TCPErrorCallback: func(listenAddr string, remoteAddr string, err error) { - logChan <- fmt.Sprintf("[%s] TCP error when connecting to %s: %s", - listenAddr, remoteAddr, err.Error()) - }, - }) - for _, entry := range config.Entries { - log.Println("Starting", entry.String(), "...") - if err := server.Add(entry.ListenAddr, entry.RemoteAddr); err != nil { - log.Fatalln(err) - } - } - log.Println("The server is now up and running :)") - }() - - for { - logStr := <-logChan - if len(logStr) == 0 { - break - } - log.Println(logStr) - } - -} diff --git a/cmd/relay/client.go b/cmd/relay/client.go new file mode 100644 index 0000000..8253261 --- /dev/null +++ b/cmd/relay/client.go @@ -0,0 +1,118 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/congestion" + "github.com/tobyxdd/hysteria/internal/utils" + hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" + "github.com/tobyxdd/hysteria/pkg/core" + "io" + "io/ioutil" + "log" + "net" + "os/user" +) + +func client(args []string) { + var config cmdClientConfig + err := loadConfig(&config, args) + if err != nil { + log.Fatalln("Unable to load configuration:", err) + } + if err := config.Check(); err != nil { + log.Fatalln("Configuration error:", err) + } + if len(config.Name) == 0 { + usr, err := user.Current() + if err == nil { + config.Name = usr.Name + } + } + log.Printf("Configuration loaded: %+v\n", config) + + tlsConfig := &tls.Config{ + NextProtos: []string{TLSAppProtocol}, + MinVersion: tls.VersionTLS13, + } + // Load CA + if len(config.CustomCAFile) > 0 { + bs, err := ioutil.ReadFile(config.CustomCAFile) + if err != nil { + log.Fatalln("Unable to load CA file:", err) + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(bs) { + log.Fatalln("Unable to parse CA file", config.CustomCAFile) + } + tlsConfig.RootCAs = cp + } + + quicConfig := &quic.Config{ + MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, + MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow, + KeepAlive: true, + } + if quicConfig.MaxReceiveStreamFlowControlWindow == 0 { + quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow + } + if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { + quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow + } + + client, err := core.NewClient(config.ServerAddr, config.Name, "", tlsConfig, quicConfig, + uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, + func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos { + return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) + }) + if err != nil { + log.Fatalln("Client initialization failed:", err) + } + defer client.Close() + log.Println("Client initialization complete, connected to", config.ServerAddr) + + listener, err := net.Listen("tcp", config.ListenAddr) + if err != nil { + log.Fatalln("TCP listen failed:", err) + } + defer listener.Close() + log.Println("TCP listening on", listener.Addr().String()) + + for { + conn, err := listener.Accept() + if err != nil { + log.Fatalln("TCP accept failed:", err) + } + go clientHandleConn(conn, client) + } +} + +func clientHandleConn(conn net.Conn, client core.Client) { + log.Println("New TCP connection from", conn.RemoteAddr().String()) + var closeErr error + defer func() { + _ = conn.Close() + log.Println("TCP connection from", conn.RemoteAddr().String(), "closed", closeErr) + }() + rwc, err := client.Dial(false, "") + if err != nil { + closeErr = err + return + } + defer rwc.Close() + closeErr = pipePair(conn, rwc) +} + +func pipePair(rw1, rw2 io.ReadWriter) error { + // Pipes + errChan := make(chan error, 2) + go func() { + errChan <- utils.Pipe(rw2, rw1, nil) + }() + go func() { + errChan <- utils.Pipe(rw1, rw2, nil) + }() + // We only need the first error + return <-errChan +} diff --git a/cmd/relay/config.go b/cmd/relay/config.go new file mode 100644 index 0000000..1ec7fc4 --- /dev/null +++ b/cmd/relay/config.go @@ -0,0 +1,150 @@ +package main + +import ( + "encoding/json" + "errors" + "flag" + "io/ioutil" + "os" + "reflect" + "strings" +) + +const ( + mbpsToBps = 125000 + + TLSAppProtocol = "hysteria-relay" + + DefaultMaxReceiveStreamFlowControlWindow = 33554432 + DefaultMaxReceiveConnectionFlowControlWindow = 67108864 +) + +type cmdClientConfig struct { + ListenAddr string `json:"listen" desc:"TCP listen address"` + ServerAddr string `json:"server" desc:"Server address"` + Name string `json:"name" desc:"Client name presented to the server"` + Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"` + CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"` + UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"` + DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"` + ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"` + ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"` +} + +func (c *cmdClientConfig) Check() error { + if len(c.ListenAddr) == 0 { + return errors.New("no listen address") + } + if len(c.ServerAddr) == 0 { + return errors.New("no server address") + } + if c.UpMbps <= 0 || c.DownMbps <= 0 { + return errors.New("invalid speed") + } + if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || + (c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) { + return errors.New("invalid receive window size") + } + return nil +} + +type cmdServerConfig struct { + ListenAddr string `json:"listen" desc:"Server listen address"` + RemoteAddr string `json:"remote" desc:"Remote relay address"` + CertFile string `json:"cert" desc:"TLS certificate file"` + KeyFile string `json:"key" desc:"TLS key file"` + UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"` + DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"` + ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"` + ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"` + MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"` +} + +func (c *cmdServerConfig) Check() error { + if len(c.ListenAddr) == 0 { + return errors.New("no listen address") + } + if len(c.RemoteAddr) == 0 { + return errors.New("no remote address") + } + if len(c.CertFile) == 0 || len(c.KeyFile) == 0 { + return errors.New("TLS cert or key not provided") + } + if c.UpMbps < 0 || c.DownMbps < 0 { + return errors.New("invalid speed") + } + if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) || + (c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) { + return errors.New("invalid receive window size") + } + if c.MaxConnClient < 0 { + return errors.New("invalid max connections per client") + } + return nil +} + +func loadConfig(cfg interface{}, args []string) error { + cfgVal := reflect.ValueOf(cfg).Elem() + fs := flag.NewFlagSet("", flag.ContinueOnError) + fsValMap := make(map[reflect.Value]interface{}, cfgVal.NumField()) + for i := 0; i < cfgVal.NumField(); i++ { + structField := cfgVal.Type().Field(i) + tag := structField.Tag + switch structField.Type.Kind() { + case reflect.String: + fsValMap[cfgVal.Field(i)] = + fs.String(jsonTagToFlagName(tag.Get("json")), "", tag.Get("desc")) + case reflect.Int: + fsValMap[cfgVal.Field(i)] = + fs.Int(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc")) + case reflect.Uint64: + fsValMap[cfgVal.Field(i)] = + fs.Uint64(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc")) + case reflect.Bool: + var bf optionalBoolFlag + fs.Var(&bf, jsonTagToFlagName(tag.Get("json")), tag.Get("desc")) + fsValMap[cfgVal.Field(i)] = &bf + } + } + configFile := fs.String("config", "", "Configuration file") + // Parse + if err := fs.Parse(args); err != nil { + os.Exit(1) + } + // Put together the config + if len(*configFile) > 0 { + cb, err := ioutil.ReadFile(*configFile) + if err != nil { + return err + } + if err := json.Unmarshal(cb, cfg); err != nil { + return err + } + } + // Flags override config from file + for field, val := range fsValMap { + switch v := val.(type) { + case *string: + if len(*v) > 0 { + field.SetString(*v) + } + case *int: + if *v != 0 { + field.SetInt(int64(*v)) + } + case *uint64: + if *v != 0 { + field.SetUint(*v) + } + case *optionalBoolFlag: + if v.Exists { + field.SetBool(v.Value) + } + } + } + return nil +} + +func jsonTagToFlagName(tag string) string { + return strings.ReplaceAll(tag, "_", "-") +} diff --git a/cmd/forwarder/flags.go b/cmd/relay/flags.go similarity index 58% rename from cmd/forwarder/flags.go rename to cmd/relay/flags.go index 41529b9..17fca0f 100644 --- a/cmd/forwarder/flags.go +++ b/cmd/relay/flags.go @@ -2,7 +2,6 @@ package main import ( "strconv" - "strings" ) type optionalBoolFlag struct { @@ -24,17 +23,6 @@ func (flag *optionalBoolFlag) Set(s string) error { return nil } -func (o *optionalBoolFlag) IsBoolFlag() bool { +func (flag *optionalBoolFlag) IsBoolFlag() bool { return true } - -type stringSliceFlag []string - -func (flag *stringSliceFlag) String() string { - return strings.Join(*flag, ";") -} - -func (flag *stringSliceFlag) Set(s string) error { - *flag = append(*flag, s) - return nil -} diff --git a/cmd/forwarder/main.go b/cmd/relay/main.go similarity index 100% rename from cmd/forwarder/main.go rename to cmd/relay/main.go diff --git a/cmd/relay/server.go b/cmd/relay/server.go new file mode 100644 index 0000000..1dd787b --- /dev/null +++ b/cmd/relay/server.go @@ -0,0 +1,84 @@ +package main + +import ( + "crypto/tls" + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/congestion" + hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" + "github.com/tobyxdd/hysteria/pkg/core" + "io" + "log" + "net" +) + +func server(args []string) { + var config cmdServerConfig + err := loadConfig(&config, args) + if err != nil { + log.Fatalln("Unable to load configuration:", err) + } + if err := config.Check(); err != nil { + log.Fatalln("Configuration error:", err.Error()) + } + log.Printf("Configuration loaded: %+v\n", config) + // Load cert + cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) + if err != nil { + log.Fatalln("Unable to load the certificate:", err) + } + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{TLSAppProtocol}, + MinVersion: tls.VersionTLS13, + } + + quicConfig := &quic.Config{ + MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn, + MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient, + KeepAlive: true, + } + if quicConfig.MaxReceiveStreamFlowControlWindow == 0 { + quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow + } + if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 { + quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow + } + + server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig, + uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, + func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos { + return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) + }, + func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) { + // No authentication logic in relay, just log username and speed + log.Printf("Client %s connected, negotiated speed in Mbps: Up %d / Down %d\n", + addr.String(), sSend/mbpsToBps, sRecv/mbpsToBps) + return core.AuthSuccess, "" + }, + func(addr net.Addr, username string, err error) { + log.Printf("Client %s (%s) disconnected: %s\n", addr.String(), username, err.Error()) + }, + func(addr net.Addr, username string, id int, isUDP bool, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) { + log.Printf("Client %s (%s) opened stream ID %d\n", addr.String(), username, id) + if isUDP { + return core.ConnBlocked, "unsupported", nil + } + conn, err := net.Dial("tcp", config.RemoteAddr) + if err != nil { + log.Printf("TCP error when connecting to %s: %s", config.RemoteAddr, err.Error()) + return core.ConnFailed, err.Error(), nil + } + return core.ConnSuccess, "", conn + }, + func(addr net.Addr, username string, id int, isUDP bool, reqAddr string, err error) { + log.Printf("Client %s (%s) closed stream ID %d: %s", addr.String(), username, id, err.Error()) + }, + ) + if err != nil { + log.Fatalln("Server initialization failed:", err) + } + defer server.Close() + log.Println("The server is now up and running :)") + + log.Fatalln("Server error:", server.Serve()) +} diff --git a/internal/core/client.go b/internal/core/client.go new file mode 100644 index 0000000..1d17eba --- /dev/null +++ b/internal/core/client.go @@ -0,0 +1,175 @@ +package core + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "github.com/lucas-clemente/quic-go" + "github.com/tobyxdd/hysteria/internal/utils" + "io" + "net" + "sync" + "sync/atomic" +) + +var ( + ErrClosed = errors.New("client closed") +) + +type Client struct { + inboundBytes, outboundBytes uint64 // atomic + + reconnectMutex sync.Mutex + closed bool + quicSession quic.Session + serverAddr string + username, password string + tlsConfig *tls.Config + quicConfig *quic.Config + sendBPS, recvBPS uint64 + congestionFactory CongestionFactory +} + +func NewClient(serverAddr string, username string, password string, tlsConfig *tls.Config, quicConfig *quic.Config, + sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory) (*Client, error) { + c := &Client{ + serverAddr: serverAddr, + username: username, + password: password, + tlsConfig: tlsConfig, + quicConfig: quicConfig, + sendBPS: sendBPS, + recvBPS: recvBPS, + congestionFactory: congestionFactory, + } + if err := c.connectToServer(); err != nil { + return nil, err + } + return c, nil +} + +func (c *Client) Dial(udp bool, addr string) (io.ReadWriteCloser, error) { + stream, err := c.openStreamWithReconnect() + if err != nil { + return nil, err + } + // Send request + req := &ClientConnectRequest{Address: addr} + if udp { + req.Type = ConnectionType_UDP + } else { + req.Type = ConnectionType_TCP + } + err = writeClientConnectRequest(stream, req) + if err != nil { + _ = stream.Close() + return nil, err + } + // Read response + resp, err := readServerConnectResponse(stream) + if err != nil { + _ = stream.Close() + return nil, err + } + if resp.Result != ConnectResult_CONN_SUCCESS { + _ = stream.Close() + return nil, fmt.Errorf("server rejected the connection %s (msg: %s)", + resp.Result.String(), resp.Message) + } + if udp { + return &utils.PacketReadWriteCloser{Orig: stream}, nil + } else { + return stream, nil + } +} + +func (c *Client) Stats() (uint64, uint64) { + return atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes) +} + +func (c *Client) Close() error { + c.reconnectMutex.Lock() + defer c.reconnectMutex.Unlock() + err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic") + c.closed = true + return err +} + +func (c *Client) connectToServer() error { + qs, err := quic.DialAddr(c.serverAddr, c.tlsConfig, c.quicConfig) + if err != nil { + return err + } + // Control stream + ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) + ctlStream, err := qs.OpenStreamSync(ctx) + ctxCancel() + if err != nil { + _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") + return err + } + result, msg, err := c.handleControlStream(qs, ctlStream) + if err != nil { + _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") + return err + } + if result != AuthResult_AUTH_SUCCESS { + _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure") + return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg) + } + // All good + c.quicSession = qs + return nil +} + +func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthResult, string, error) { + err := writeClientAuthRequest(stream, &ClientAuthRequest{ + Credential: &Credential{ + Username: c.username, + Password: c.password, + }, + Speed: &Speed{ + SendBps: c.sendBPS, + ReceiveBps: c.recvBPS, + }, + }) + if err != nil { + return 0, "", err + } + // Response + resp, err := readServerAuthResponse(stream) + if err != nil { + return 0, "", err + } + // Set the congestion accordingly + if resp.Result == AuthResult_AUTH_SUCCESS && c.congestionFactory != nil { + qs.SetCongestion(c.congestionFactory(resp.Speed.ReceiveBps)) + } + return resp.Result, resp.Message, nil +} + +func (c *Client) openStreamWithReconnect() (quic.Stream, error) { + c.reconnectMutex.Lock() + defer c.reconnectMutex.Unlock() + if c.closed { + return nil, ErrClosed + } + stream, err := c.quicSession.OpenStream() + if err == nil { + // All good + return stream, nil + } + // Something is wrong + if nErr, ok := err.(net.Error); ok && nErr.Temporary() { + // Temporary error, just return + return nil, err + } + // Permanent error, need to reconnect + if err := c.connectToServer(); err != nil { + // Still error, oops + return nil, err + } + // We are not going to try again even if it still fails the second time + return c.quicSession.OpenStream() +} diff --git a/internal/core/control.go b/internal/core/control.go new file mode 100644 index 0000000..eaea770 --- /dev/null +++ b/internal/core/control.go @@ -0,0 +1,103 @@ +package core + +import ( + "encoding/binary" + "github.com/golang/protobuf/proto" + "io" +) + +const ( + closeErrorCodeGeneric = 0 + closeErrorCodeProtocolFailure = 1 +) + +func readDataBlock(r io.Reader) ([]byte, error) { + var sz uint32 + if err := binary.Read(r, controlProtocolEndian, &sz); err != nil { + return nil, err + } + buf := make([]byte, sz) + _, err := io.ReadFull(r, buf) + return buf, err +} + +func writeDataBlock(w io.Writer, data []byte) error { + sz := uint32(len(data)) + if err := binary.Write(w, controlProtocolEndian, &sz); err != nil { + return err + } + _, err := w.Write(data) + return err +} + +func readClientAuthRequest(r io.Reader) (*ClientAuthRequest, error) { + bs, err := readDataBlock(r) + if err != nil { + return nil, err + } + var req ClientAuthRequest + err = proto.Unmarshal(bs, &req) + return &req, err +} + +func writeClientAuthRequest(w io.Writer, req *ClientAuthRequest) error { + bs, err := proto.Marshal(req) + if err != nil { + return err + } + return writeDataBlock(w, bs) +} + +func readServerAuthResponse(r io.Reader) (*ServerAuthResponse, error) { + bs, err := readDataBlock(r) + if err != nil { + return nil, err + } + var resp ServerAuthResponse + err = proto.Unmarshal(bs, &resp) + return &resp, err +} + +func writeServerAuthResponse(w io.Writer, resp *ServerAuthResponse) error { + bs, err := proto.Marshal(resp) + if err != nil { + return err + } + return writeDataBlock(w, bs) +} + +func readClientConnectRequest(r io.Reader) (*ClientConnectRequest, error) { + bs, err := readDataBlock(r) + if err != nil { + return nil, err + } + var req ClientConnectRequest + err = proto.Unmarshal(bs, &req) + return &req, err +} + +func writeClientConnectRequest(w io.Writer, req *ClientConnectRequest) error { + bs, err := proto.Marshal(req) + if err != nil { + return err + } + return writeDataBlock(w, bs) +} + +func readServerConnectResponse(r io.Reader) (*ServerConnectResponse, error) { + bs, err := readDataBlock(r) + if err != nil { + return nil, err + } + var resp ServerConnectResponse + err = proto.Unmarshal(bs, &resp) + return &resp, err +} + +func writeServerConnectResponse(w io.Writer, resp *ServerConnectResponse) error { + bs, err := proto.Marshal(resp) + if err != nil { + return err + } + return writeDataBlock(w, bs) +} diff --git a/internal/core/control.pb.go b/internal/core/control.pb.go new file mode 100644 index 0000000..b9d5da5 --- /dev/null +++ b/internal/core/control.pb.go @@ -0,0 +1,439 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: control.proto + +package core + +import ( + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package + +type AuthResult int32 + +const ( + AuthResult_AUTH_SUCCESS AuthResult = 0 + AuthResult_AUTH_INVALID_CRED AuthResult = 1 + AuthResult_AUTH_INTERNAL_ERROR AuthResult = 2 +) + +var AuthResult_name = map[int32]string{ + 0: "AUTH_SUCCESS", + 1: "AUTH_INVALID_CRED", + 2: "AUTH_INTERNAL_ERROR", +} + +var AuthResult_value = map[string]int32{ + "AUTH_SUCCESS": 0, + "AUTH_INVALID_CRED": 1, + "AUTH_INTERNAL_ERROR": 2, +} + +func (x AuthResult) String() string { + return proto.EnumName(AuthResult_name, int32(x)) +} + +func (AuthResult) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{0} +} + +type ConnectionType int32 + +const ( + ConnectionType_TCP ConnectionType = 0 + ConnectionType_UDP ConnectionType = 1 +) + +var ConnectionType_name = map[int32]string{ + 0: "TCP", + 1: "UDP", +} + +var ConnectionType_value = map[string]int32{ + "TCP": 0, + "UDP": 1, +} + +func (x ConnectionType) String() string { + return proto.EnumName(ConnectionType_name, int32(x)) +} + +func (ConnectionType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{1} +} + +type ConnectResult int32 + +const ( + ConnectResult_CONN_SUCCESS ConnectResult = 0 + ConnectResult_CONN_FAILED ConnectResult = 1 + ConnectResult_CONN_BLOCKED ConnectResult = 2 +) + +var ConnectResult_name = map[int32]string{ + 0: "CONN_SUCCESS", + 1: "CONN_FAILED", + 2: "CONN_BLOCKED", +} + +var ConnectResult_value = map[string]int32{ + "CONN_SUCCESS": 0, + "CONN_FAILED": 1, + "CONN_BLOCKED": 2, +} + +func (x ConnectResult) String() string { + return proto.EnumName(ConnectResult_name, int32(x)) +} + +func (ConnectResult) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{2} +} + +type Speed struct { + SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"` + ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Speed) Reset() { *m = Speed{} } +func (m *Speed) String() string { return proto.CompactTextString(m) } +func (*Speed) ProtoMessage() {} +func (*Speed) Descriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{0} +} + +func (m *Speed) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Speed.Unmarshal(m, b) +} +func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Speed.Marshal(b, m, deterministic) +} +func (m *Speed) XXX_Merge(src proto.Message) { + xxx_messageInfo_Speed.Merge(m, src) +} +func (m *Speed) XXX_Size() int { + return xxx_messageInfo_Speed.Size(m) +} +func (m *Speed) XXX_DiscardUnknown() { + xxx_messageInfo_Speed.DiscardUnknown(m) +} + +var xxx_messageInfo_Speed proto.InternalMessageInfo + +func (m *Speed) GetSendBps() uint64 { + if m != nil { + return m.SendBps + } + return 0 +} + +func (m *Speed) GetReceiveBps() uint64 { + if m != nil { + return m.ReceiveBps + } + return 0 +} + +type Credential struct { + Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Credential) Reset() { *m = Credential{} } +func (m *Credential) String() string { return proto.CompactTextString(m) } +func (*Credential) ProtoMessage() {} +func (*Credential) Descriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{1} +} + +func (m *Credential) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Credential.Unmarshal(m, b) +} +func (m *Credential) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Credential.Marshal(b, m, deterministic) +} +func (m *Credential) XXX_Merge(src proto.Message) { + xxx_messageInfo_Credential.Merge(m, src) +} +func (m *Credential) XXX_Size() int { + return xxx_messageInfo_Credential.Size(m) +} +func (m *Credential) XXX_DiscardUnknown() { + xxx_messageInfo_Credential.DiscardUnknown(m) +} + +var xxx_messageInfo_Credential proto.InternalMessageInfo + +func (m *Credential) GetUsername() string { + if m != nil { + return m.Username + } + return "" +} + +func (m *Credential) GetPassword() string { + if m != nil { + return m.Password + } + return "" +} + +type ClientAuthRequest struct { + Credential *Credential `protobuf:"bytes,1,opt,name=credential,proto3" json:"credential,omitempty"` + Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ClientAuthRequest) Reset() { *m = ClientAuthRequest{} } +func (m *ClientAuthRequest) String() string { return proto.CompactTextString(m) } +func (*ClientAuthRequest) ProtoMessage() {} +func (*ClientAuthRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{2} +} + +func (m *ClientAuthRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ClientAuthRequest.Unmarshal(m, b) +} +func (m *ClientAuthRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ClientAuthRequest.Marshal(b, m, deterministic) +} +func (m *ClientAuthRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ClientAuthRequest.Merge(m, src) +} +func (m *ClientAuthRequest) XXX_Size() int { + return xxx_messageInfo_ClientAuthRequest.Size(m) +} +func (m *ClientAuthRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ClientAuthRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ClientAuthRequest proto.InternalMessageInfo + +func (m *ClientAuthRequest) GetCredential() *Credential { + if m != nil { + return m.Credential + } + return nil +} + +func (m *ClientAuthRequest) GetSpeed() *Speed { + if m != nil { + return m.Speed + } + return nil +} + +type ServerAuthResponse struct { + Result AuthResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.AuthResult" json:"result,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Speed *Speed `protobuf:"bytes,3,opt,name=speed,proto3" json:"speed,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ServerAuthResponse) Reset() { *m = ServerAuthResponse{} } +func (m *ServerAuthResponse) String() string { return proto.CompactTextString(m) } +func (*ServerAuthResponse) ProtoMessage() {} +func (*ServerAuthResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{3} +} + +func (m *ServerAuthResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ServerAuthResponse.Unmarshal(m, b) +} +func (m *ServerAuthResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ServerAuthResponse.Marshal(b, m, deterministic) +} +func (m *ServerAuthResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ServerAuthResponse.Merge(m, src) +} +func (m *ServerAuthResponse) XXX_Size() int { + return xxx_messageInfo_ServerAuthResponse.Size(m) +} +func (m *ServerAuthResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ServerAuthResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ServerAuthResponse proto.InternalMessageInfo + +func (m *ServerAuthResponse) GetResult() AuthResult { + if m != nil { + return m.Result + } + return AuthResult_AUTH_SUCCESS +} + +func (m *ServerAuthResponse) GetMessage() string { + if m != nil { + return m.Message + } + return "" +} + +func (m *ServerAuthResponse) GetSpeed() *Speed { + if m != nil { + return m.Speed + } + return nil +} + +type ClientConnectRequest struct { + Type ConnectionType `protobuf:"varint,1,opt,name=type,proto3,enum=core.ConnectionType" json:"type,omitempty"` + Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ClientConnectRequest) Reset() { *m = ClientConnectRequest{} } +func (m *ClientConnectRequest) String() string { return proto.CompactTextString(m) } +func (*ClientConnectRequest) ProtoMessage() {} +func (*ClientConnectRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{4} +} + +func (m *ClientConnectRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ClientConnectRequest.Unmarshal(m, b) +} +func (m *ClientConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ClientConnectRequest.Marshal(b, m, deterministic) +} +func (m *ClientConnectRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_ClientConnectRequest.Merge(m, src) +} +func (m *ClientConnectRequest) XXX_Size() int { + return xxx_messageInfo_ClientConnectRequest.Size(m) +} +func (m *ClientConnectRequest) XXX_DiscardUnknown() { + xxx_messageInfo_ClientConnectRequest.DiscardUnknown(m) +} + +var xxx_messageInfo_ClientConnectRequest proto.InternalMessageInfo + +func (m *ClientConnectRequest) GetType() ConnectionType { + if m != nil { + return m.Type + } + return ConnectionType_TCP +} + +func (m *ClientConnectRequest) GetAddress() string { + if m != nil { + return m.Address + } + return "" +} + +type ServerConnectResponse struct { + Result ConnectResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.ConnectResult" json:"result,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *ServerConnectResponse) Reset() { *m = ServerConnectResponse{} } +func (m *ServerConnectResponse) String() string { return proto.CompactTextString(m) } +func (*ServerConnectResponse) ProtoMessage() {} +func (*ServerConnectResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_0c5120591600887d, []int{5} +} + +func (m *ServerConnectResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_ServerConnectResponse.Unmarshal(m, b) +} +func (m *ServerConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_ServerConnectResponse.Marshal(b, m, deterministic) +} +func (m *ServerConnectResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_ServerConnectResponse.Merge(m, src) +} +func (m *ServerConnectResponse) XXX_Size() int { + return xxx_messageInfo_ServerConnectResponse.Size(m) +} +func (m *ServerConnectResponse) XXX_DiscardUnknown() { + xxx_messageInfo_ServerConnectResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_ServerConnectResponse proto.InternalMessageInfo + +func (m *ServerConnectResponse) GetResult() ConnectResult { + if m != nil { + return m.Result + } + return ConnectResult_CONN_SUCCESS +} + +func (m *ServerConnectResponse) GetMessage() string { + if m != nil { + return m.Message + } + return "" +} + +func init() { + proto.RegisterEnum("core.AuthResult", AuthResult_name, AuthResult_value) + proto.RegisterEnum("core.ConnectionType", ConnectionType_name, ConnectionType_value) + proto.RegisterEnum("core.ConnectResult", ConnectResult_name, ConnectResult_value) + proto.RegisterType((*Speed)(nil), "core.Speed") + proto.RegisterType((*Credential)(nil), "core.Credential") + proto.RegisterType((*ClientAuthRequest)(nil), "core.ClientAuthRequest") + proto.RegisterType((*ServerAuthResponse)(nil), "core.ServerAuthResponse") + proto.RegisterType((*ClientConnectRequest)(nil), "core.ClientConnectRequest") + proto.RegisterType((*ServerConnectResponse)(nil), "core.ServerConnectResponse") +} + +func init() { + proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d) +} + +var fileDescriptor_0c5120591600887d = []byte{ + // 431 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xd1, 0x6e, 0xd3, 0x30, + 0x14, 0x86, 0xd7, 0xae, 0x5b, 0xb7, 0x13, 0x36, 0x32, 0x6f, 0x13, 0x83, 0x1b, 0x20, 0x57, 0x55, + 0x91, 0x2a, 0x34, 0x9e, 0x20, 0x75, 0x82, 0xa8, 0xa8, 0xd2, 0xc9, 0x69, 0xb9, 0xe0, 0x82, 0x2a, + 0x4b, 0x8e, 0x58, 0xa5, 0xcc, 0x36, 0xb6, 0x33, 0x34, 0xf1, 0xf2, 0x28, 0x8e, 0x93, 0xae, 0x48, + 0x48, 0xbb, 0xeb, 0x39, 0xe7, 0xd7, 0xff, 0xf9, 0xab, 0x02, 0x27, 0xb9, 0xe0, 0x46, 0x89, 0x72, + 0x22, 0x95, 0x30, 0x82, 0x0c, 0x72, 0xa1, 0x30, 0xa0, 0x70, 0x90, 0x4a, 0xc4, 0x82, 0xbc, 0x86, + 0x23, 0x8d, 0xbc, 0x58, 0xdf, 0x4a, 0x7d, 0xd5, 0x7b, 0xd7, 0x1b, 0x0d, 0xd8, 0xb0, 0x9e, 0xa7, + 0x52, 0x93, 0xb7, 0xe0, 0x29, 0xcc, 0x71, 0xf3, 0x80, 0xf6, 0xda, 0xb7, 0x57, 0x70, 0xab, 0xa9, + 0xd4, 0x41, 0x04, 0x40, 0x15, 0x16, 0xc8, 0xcd, 0x26, 0x2b, 0xc9, 0x1b, 0x38, 0xaa, 0x34, 0x2a, + 0x9e, 0xdd, 0xa3, 0x6d, 0x3a, 0x66, 0xdd, 0x5c, 0xdf, 0x64, 0xa6, 0xf5, 0x6f, 0xa1, 0x0a, 0xdb, + 0x73, 0xcc, 0xba, 0x39, 0xb8, 0x83, 0x33, 0x5a, 0x6e, 0x90, 0x9b, 0xb0, 0x32, 0x77, 0x0c, 0x7f, + 0x55, 0xa8, 0x0d, 0xf9, 0x08, 0x90, 0x77, 0xd5, 0xb6, 0xce, 0xbb, 0xf6, 0x27, 0xf5, 0xd3, 0x27, + 0x5b, 0x24, 0x7b, 0x92, 0x21, 0xef, 0xe1, 0x40, 0xd7, 0x46, 0xb6, 0xdf, 0xbb, 0xf6, 0x9a, 0xb0, + 0x95, 0x64, 0xcd, 0x25, 0xf8, 0x03, 0x24, 0x45, 0xf5, 0x80, 0xaa, 0x21, 0x69, 0x29, 0xb8, 0x46, + 0x32, 0x82, 0x43, 0x85, 0xba, 0x2a, 0x8d, 0xc5, 0x9c, 0xb6, 0x18, 0x97, 0xa9, 0x4a, 0xc3, 0xdc, + 0x9d, 0x5c, 0xc1, 0xf0, 0x1e, 0xb5, 0xce, 0x7e, 0xa2, 0x93, 0x68, 0xc7, 0x2d, 0x7c, 0xff, 0xbf, + 0xf0, 0xef, 0x70, 0xd1, 0x68, 0x52, 0xc1, 0x39, 0xe6, 0xa6, 0x35, 0x1d, 0xc1, 0xc0, 0x3c, 0x4a, + 0x74, 0xf0, 0x0b, 0xe7, 0xd8, 0x64, 0x36, 0x82, 0x2f, 0x1f, 0x25, 0x32, 0x9b, 0xa8, 0xf1, 0x59, + 0x51, 0x28, 0xd4, 0xba, 0xc5, 0xbb, 0x31, 0xf8, 0x01, 0x97, 0x8d, 0x58, 0xd7, 0xed, 0xdc, 0x3e, + 0xfc, 0xe3, 0x76, 0xbe, 0x53, 0xff, 0x5c, 0xbd, 0x71, 0x02, 0xb0, 0xfd, 0x3b, 0x88, 0x0f, 0x2f, + 0xc2, 0xd5, 0xf2, 0xcb, 0x3a, 0x5d, 0x51, 0x1a, 0xa7, 0xa9, 0xbf, 0x47, 0x2e, 0xe1, 0xcc, 0x6e, + 0x66, 0xc9, 0xb7, 0x70, 0x3e, 0x8b, 0xd6, 0x94, 0xc5, 0x91, 0xdf, 0x23, 0xaf, 0xe0, 0xdc, 0xad, + 0x97, 0x31, 0x4b, 0xc2, 0xf9, 0x3a, 0x66, 0x6c, 0xc1, 0xfc, 0xfe, 0x38, 0x80, 0xd3, 0x5d, 0x43, + 0x32, 0x84, 0xfd, 0x25, 0xbd, 0xf1, 0xf7, 0xea, 0x1f, 0xab, 0xe8, 0xc6, 0xef, 0x8d, 0x23, 0x38, + 0xd9, 0x79, 0x66, 0x8d, 0xa5, 0x8b, 0x24, 0x79, 0x82, 0x7d, 0x09, 0x9e, 0xdd, 0x7c, 0x0e, 0x67, + 0x73, 0x0b, 0x6c, 0x23, 0xd3, 0xf9, 0x82, 0x7e, 0x8d, 0x23, 0xbf, 0x7f, 0x7b, 0x68, 0x3f, 0xfa, + 0x4f, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0x65, 0xfc, 0xeb, 0x5c, 0x05, 0x03, 0x00, 0x00, +} diff --git a/internal/core/control.proto b/internal/core/control.proto new file mode 100644 index 0000000..53d8279 --- /dev/null +++ b/internal/core/control.proto @@ -0,0 +1,50 @@ +syntax = "proto3"; +package core; + +message Speed { + uint64 send_bps = 1; + uint64 receive_bps = 2; +} + +message Credential { + string username = 1; + string password = 2; +} + +enum AuthResult { + AUTH_SUCCESS = 0; + AUTH_INVALID_CRED = 1; + AUTH_INTERNAL_ERROR = 2; +} + +message ClientAuthRequest { + Credential credential = 1; + Speed speed = 2; +} + +message ServerAuthResponse { + AuthResult result = 1; + string message = 2; + Speed speed = 3; +} + +enum ConnectionType { + TCP = 0; + UDP = 1; +} + +enum ConnectResult { + CONN_SUCCESS = 0; + CONN_FAILED = 1; + CONN_BLOCKED = 2; +} + +message ClientConnectRequest { + ConnectionType type = 1; + string address = 2; +} + +message ServerConnectResponse { + ConnectResult result = 1; + string message = 2; +} \ No newline at end of file diff --git a/internal/forwarder/protogen.go b/internal/core/protogen.go similarity index 72% rename from internal/forwarder/protogen.go rename to internal/core/protogen.go index aa1f7f3..afa8971 100644 --- a/internal/forwarder/protogen.go +++ b/internal/core/protogen.go @@ -1,3 +1,3 @@ -package forwarder +package core //go:generate protoc --go_out=. control.proto diff --git a/internal/core/server.go b/internal/core/server.go new file mode 100644 index 0000000..411da71 --- /dev/null +++ b/internal/core/server.go @@ -0,0 +1,199 @@ +package core + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "github.com/lucas-clemente/quic-go" + "github.com/tobyxdd/hysteria/internal/utils" + "io" + "net" + "sync/atomic" +) + +type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string) +type ClientDisconnectedFunc func(addr net.Addr, username string, err error) +type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser) +type RequestClosedFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string, err error) + +type Server struct { + inboundBytes, outboundBytes uint64 // atomic + + listener quic.Listener + sendBPS, recvBPS uint64 + + congestionFactory CongestionFactory + clientAuthFunc ClientAuthFunc + clientDisconnectedFunc ClientDisconnectedFunc + handleRequestFunc HandleRequestFunc + requestClosedFunc RequestClosedFunc +} + +func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, + sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, + clientAuthFunc ClientAuthFunc, + clientDisconnectedFunc ClientDisconnectedFunc, + handleRequestFunc HandleRequestFunc, + requestClosedFunc RequestClosedFunc) (*Server, error) { + listener, err := quic.ListenAddr(addr, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + s := &Server{ + listener: listener, + sendBPS: sendBPS, + recvBPS: recvBPS, + congestionFactory: congestionFactory, + clientAuthFunc: clientAuthFunc, + clientDisconnectedFunc: clientDisconnectedFunc, + handleRequestFunc: handleRequestFunc, + requestClosedFunc: requestClosedFunc, + } + return s, nil +} + +func (s *Server) Serve() error { + for { + cs, err := s.listener.Accept(context.Background()) + if err != nil { + return err + } + go s.handleClient(cs) + } +} + +func (s *Server) Stats() (uint64, uint64) { + return atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes) +} + +func (s *Server) Close() error { + return s.listener.Close() +} + +func (s *Server) handleClient(cs quic.Session) { + // Expect the client to create a control stream to send its own information + ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) + ctlStream, err := cs.AcceptStream(ctx) + ctxCancel() + if err != nil { + _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") + return + } + // Handle the control stream + username, ok, err := s.handleControlStream(cs, ctlStream) + if err != nil { + _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") + return + } + if !ok { + _ = cs.CloseWithError(closeErrorCodeGeneric, "authentication failure") + return + } + // Start accepting streams + var closeErr error + for { + stream, err := cs.AcceptStream(context.Background()) + if err != nil { + closeErr = err + break + } + go s.handleStream(cs.RemoteAddr(), username, stream) + } + s.clientDisconnectedFunc(cs.RemoteAddr(), username, closeErr) + _ = cs.CloseWithError(closeErrorCodeGeneric, "generic") +} + +// Auth & negotiate speed +func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (string, bool, error) { + req, err := readClientAuthRequest(stream) + if err != nil { + return "", false, err + } + // Speed + if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 { + return "", false, errors.New("incorrect speed provided by the client") + } + serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps + if s.sendBPS > 0 && serverSendBPS > s.sendBPS { + serverSendBPS = s.sendBPS + } + if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS { + serverReceiveBPS = s.recvBPS + } + // Auth + if req.Credential == nil { + return "", false, errors.New("incorrect credential provided by the client") + } + authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password, + serverSendBPS, serverReceiveBPS) + // Response + err = writeServerAuthResponse(stream, &ServerAuthResponse{ + Result: authResult, + Message: msg, + Speed: &Speed{ + SendBps: serverSendBPS, + ReceiveBps: serverReceiveBPS, + }, + }) + if err != nil { + return "", false, err + } + // Set the congestion accordingly + if authResult == AuthResult_AUTH_SUCCESS && s.congestionFactory != nil { + cs.SetCongestion(s.congestionFactory(serverSendBPS)) + } + return req.Credential.Username, authResult == AuthResult_AUTH_SUCCESS, nil +} + +func (s *Server) handleStream(addr net.Addr, username string, stream quic.Stream) { + defer stream.Close() + // Read request + req, err := readClientConnectRequest(stream) + if err != nil { + return + } + // Create connection with the handler + result, msg, conn := s.handleRequestFunc(addr, username, int(stream.StreamID()), req.Type, req.Address) + defer func() { + if conn != nil { + _ = conn.Close() + } + }() + // Send response + err = writeServerConnectResponse(stream, &ServerConnectResponse{ + Result: result, + Message: msg, + }) + if err != nil { + s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, err) + return + } + if result != ConnectResult_CONN_SUCCESS { + s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, + fmt.Errorf("handler returned an unsuccessful state %s (msg: %s)", result.String(), msg)) + return + } + switch req.Type { + case ConnectionType_TCP: + err = s.pipePair(stream, conn) + case ConnectionType_UDP: + err = s.pipePair(&utils.PacketReadWriteCloser{Orig: stream}, conn) + default: + err = fmt.Errorf("unsupported connection type %s", req.Type.String()) + } + s.requestClosedFunc(addr, username, int(stream.StreamID()), req.Type, req.Address, err) +} + +func (s *Server) pipePair(rw1, rw2 io.ReadWriter) error { + // Pipes + errChan := make(chan error, 2) + go func() { + errChan <- utils.Pipe(rw2, rw1, &s.outboundBytes) + }() + go func() { + errChan <- utils.Pipe(rw1, rw2, &s.inboundBytes) + }() + // We only need the first error + return <-errChan +} diff --git a/internal/core/types.go b/internal/core/types.go new file mode 100644 index 0000000..6d3f4ce --- /dev/null +++ b/internal/core/types.go @@ -0,0 +1,13 @@ +package core + +import ( + "encoding/binary" + "github.com/lucas-clemente/quic-go/congestion" + "time" +) + +const controlStreamTimeout = 10 * time.Second + +var controlProtocolEndian = binary.BigEndian + +type CongestionFactory func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos diff --git a/internal/forwarder/client.go b/internal/forwarder/client.go deleted file mode 100644 index 71d313f..0000000 --- a/internal/forwarder/client.go +++ /dev/null @@ -1,203 +0,0 @@ -package forwarder - -import ( - "context" - "crypto/tls" - "errors" - "github.com/lucas-clemente/quic-go" - "github.com/tobyxdd/hysteria/internal/utils" - "net" - "sync" - "sync/atomic" -) - -type QUICClient struct { - inboundBytes, outboundBytes uint64 // atomic - - reconnectMutex sync.Mutex - quicSession quic.Session - listener net.Listener - remoteAddr string - name string - tlsConfig *tls.Config - sendBPS, recvBPS uint64 - recvWindowConn, recvWindow uint64 - closed bool - - newCongestion CongestionFactory - onServerConnected ServerConnectedCallback - onServerError ServerErrorCallback - onNewTCPConnection NewTCPConnectionCallback - onTCPConnectionClosed TCPConnectionClosedCallback -} - -func NewQUICClient(addr string, remoteAddr string, name string, tlsConfig *tls.Config, - sendBPS uint64, recvBPS uint64, recvWindowConn uint64, recvWindow uint64, - newCongestion CongestionFactory, - onServerConnected ServerConnectedCallback, - onServerError ServerErrorCallback, - onNewTCPConnection NewTCPConnectionCallback, - onTCPConnectionClosed TCPConnectionClosedCallback) (*QUICClient, error) { - // Local TCP listener - listener, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - c := &QUICClient{ - listener: listener, - remoteAddr: remoteAddr, - name: name, - tlsConfig: tlsConfig, - sendBPS: sendBPS, - recvBPS: recvBPS, - recvWindowConn: recvWindowConn, - recvWindow: recvWindow, - newCongestion: newCongestion, - onServerConnected: onServerConnected, - onServerError: onServerError, - onNewTCPConnection: onNewTCPConnection, - onTCPConnectionClosed: onTCPConnectionClosed, - } - if err := c.connectToServer(); err != nil { - _ = c.listener.Close() - return nil, err - } - go c.acceptLoop() - return c, nil -} - -func (c *QUICClient) Close() error { - err1 := c.listener.Close() - c.reconnectMutex.Lock() - err2 := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic") - c.closed = true - c.reconnectMutex.Unlock() - if err1 != nil { - return err1 - } - return err2 -} - -func (c *QUICClient) Stats() (string, uint64, uint64) { - return c.remoteAddr, atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes) -} - -func (c *QUICClient) acceptLoop() { - for { - conn, err := c.listener.Accept() - if err != nil { - break - } - go c.handleConn(conn) - } -} - -func (c *QUICClient) connectToServer() error { - qs, err := quic.DialAddr(c.remoteAddr, c.tlsConfig, &quic.Config{ - MaxReceiveStreamFlowControlWindow: c.recvWindowConn, - MaxReceiveConnectionFlowControlWindow: c.recvWindow, - KeepAlive: true, - }) - if err != nil { - c.onServerError(err) - return err - } - // Control stream - ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) - ctlStream, err := qs.OpenStreamSync(ctx) - ctxCancel() - if err != nil { - _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") - c.onServerError(err) - return err - } - banner, cSendBPS, cRecvBPS, err := handleControlStream(qs, ctlStream, c.name, c.sendBPS, c.recvBPS, c.newCongestion) - if err != nil { - _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") - c.onServerError(err) - return err - } - // All good - c.quicSession = qs - c.onServerConnected(qs.RemoteAddr(), banner, cSendBPS, cRecvBPS) - return nil -} - -func (c *QUICClient) openStreamWithReconnect() (quic.Stream, error) { - c.reconnectMutex.Lock() - defer c.reconnectMutex.Unlock() - if c.closed { - return nil, errors.New("client closed") - } - stream, err := c.quicSession.OpenStream() - if err == nil { - // All good - return stream, nil - } - // Something is wrong - c.onServerError(err) - if nErr, ok := err.(net.Error); ok && nErr.Temporary() { - // Temporary error, just return - return nil, err - } - // Permanent error, need to reconnect - if err := c.connectToServer(); err != nil { - // Still error, oops - return nil, err - } - // We are not going to try again even if it still fails the second time - stream, err = c.quicSession.OpenStream() - if err != nil { - c.onServerError(err) - } - return stream, err -} - -// Negotiate speed, return banner, send & receive speed -func handleControlStream(qs quic.Session, stream quic.Stream, name string, sendBPS uint64, recvBPS uint64, - newCongestion CongestionFactory) (string, uint64, uint64, error) { - err := writeClientSpeedRequest(stream, &ClientSpeedRequest{ - Name: name, - Speed: &Speed{ - SendBps: sendBPS, - ReceiveBps: recvBPS, - }, - }) - if err != nil { - return "", 0, 0, err - } - // Response - resp, err := readServerSpeedResponse(stream) - if err != nil { - return "", 0, 0, err - } - // Set the congestion accordingly - if newCongestion != nil { - qs.SetCongestion(newCongestion(resp.Speed.ReceiveBps)) - } - return resp.Banner, resp.Speed.ReceiveBps, resp.Speed.SendBps, nil -} - -func (c *QUICClient) handleConn(conn net.Conn) { - c.onNewTCPConnection(conn.RemoteAddr()) - defer conn.Close() - stream, err := c.openStreamWithReconnect() - if err != nil { - c.onTCPConnectionClosed(conn.RemoteAddr(), err) - return - } - defer stream.Close() - // Pipes - errChan := make(chan error, 2) - go func() { - // TCP to QUIC - errChan <- utils.Pipe(conn, stream, &c.outboundBytes) - }() - go func() { - // QUIC to TCP - errChan <- utils.Pipe(stream, conn, &c.inboundBytes) - }() - // We only need the first error - err = <-errChan - c.onTCPConnectionClosed(conn.RemoteAddr(), err) -} diff --git a/internal/forwarder/control.go b/internal/forwarder/control.go deleted file mode 100644 index 398ef59..0000000 --- a/internal/forwarder/control.go +++ /dev/null @@ -1,67 +0,0 @@ -package forwarder - -import ( - "encoding/binary" - "github.com/golang/protobuf/proto" - "io" -) - -const ( - closeErrorCodeGeneric = 0 - closeErrorCodeProtocolFailure = 1 -) - -func readDataBlock(r io.Reader) ([]byte, error) { - var sz uint32 - if err := binary.Read(r, controlProtocolEndian, &sz); err != nil { - return nil, err - } - buf := make([]byte, sz) - _, err := io.ReadFull(r, buf) - return buf, err -} - -func writeDataBlock(w io.Writer, data []byte) error { - sz := uint32(len(data)) - if err := binary.Write(w, controlProtocolEndian, &sz); err != nil { - return err - } - _, err := w.Write(data) - return err -} - -func readClientSpeedRequest(r io.Reader) (*ClientSpeedRequest, error) { - bs, err := readDataBlock(r) - if err != nil { - return nil, err - } - var req ClientSpeedRequest - err = proto.Unmarshal(bs, &req) - return &req, err -} - -func writeClientSpeedRequest(w io.Writer, req *ClientSpeedRequest) error { - bs, err := proto.Marshal(req) - if err != nil { - return err - } - return writeDataBlock(w, bs) -} - -func readServerSpeedResponse(r io.Reader) (*ServerSpeedResponse, error) { - bs, err := readDataBlock(r) - if err != nil { - return nil, err - } - var resp ServerSpeedResponse - err = proto.Unmarshal(bs, &resp) - return &resp, err -} - -func writeServerSpeedResponse(w io.Writer, resp *ServerSpeedResponse) error { - bs, err := proto.Marshal(resp) - if err != nil { - return err - } - return writeDataBlock(w, bs) -} diff --git a/internal/forwarder/control.pb.go b/internal/forwarder/control.pb.go deleted file mode 100644 index 21b5df9..0000000 --- a/internal/forwarder/control.pb.go +++ /dev/null @@ -1,206 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// source: control.proto - -package forwarder - -import ( - fmt "fmt" - proto "github.com/golang/protobuf/proto" - math "math" -) - -// Reference imports to suppress errors if they are not otherwise used. -var _ = proto.Marshal -var _ = fmt.Errorf -var _ = math.Inf - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the proto package it is being compiled against. -// A compilation error at this line likely means your copy of the -// proto package needs to be updated. -const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package - -type Speed struct { - SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"` - ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *Speed) Reset() { *m = Speed{} } -func (m *Speed) String() string { return proto.CompactTextString(m) } -func (*Speed) ProtoMessage() {} -func (*Speed) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{0} -} - -func (m *Speed) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_Speed.Unmarshal(m, b) -} -func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_Speed.Marshal(b, m, deterministic) -} -func (m *Speed) XXX_Merge(src proto.Message) { - xxx_messageInfo_Speed.Merge(m, src) -} -func (m *Speed) XXX_Size() int { - return xxx_messageInfo_Speed.Size(m) -} -func (m *Speed) XXX_DiscardUnknown() { - xxx_messageInfo_Speed.DiscardUnknown(m) -} - -var xxx_messageInfo_Speed proto.InternalMessageInfo - -func (m *Speed) GetSendBps() uint64 { - if m != nil { - return m.SendBps - } - return 0 -} - -func (m *Speed) GetReceiveBps() uint64 { - if m != nil { - return m.ReceiveBps - } - return 0 -} - -type ClientSpeedRequest struct { - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ClientSpeedRequest) Reset() { *m = ClientSpeedRequest{} } -func (m *ClientSpeedRequest) String() string { return proto.CompactTextString(m) } -func (*ClientSpeedRequest) ProtoMessage() {} -func (*ClientSpeedRequest) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{1} -} - -func (m *ClientSpeedRequest) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ClientSpeedRequest.Unmarshal(m, b) -} -func (m *ClientSpeedRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ClientSpeedRequest.Marshal(b, m, deterministic) -} -func (m *ClientSpeedRequest) XXX_Merge(src proto.Message) { - xxx_messageInfo_ClientSpeedRequest.Merge(m, src) -} -func (m *ClientSpeedRequest) XXX_Size() int { - return xxx_messageInfo_ClientSpeedRequest.Size(m) -} -func (m *ClientSpeedRequest) XXX_DiscardUnknown() { - xxx_messageInfo_ClientSpeedRequest.DiscardUnknown(m) -} - -var xxx_messageInfo_ClientSpeedRequest proto.InternalMessageInfo - -func (m *ClientSpeedRequest) GetName() string { - if m != nil { - return m.Name - } - return "" -} - -func (m *ClientSpeedRequest) GetSpeed() *Speed { - if m != nil { - return m.Speed - } - return nil -} - -type ServerSpeedResponse struct { - Banner string `protobuf:"bytes,1,opt,name=banner,proto3" json:"banner,omitempty"` - Limited bool `protobuf:"varint,2,opt,name=limited,proto3" json:"limited,omitempty"` - Limit *Speed `protobuf:"bytes,3,opt,name=limit,proto3" json:"limit,omitempty"` - Speed *Speed `protobuf:"bytes,4,opt,name=speed,proto3" json:"speed,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` -} - -func (m *ServerSpeedResponse) Reset() { *m = ServerSpeedResponse{} } -func (m *ServerSpeedResponse) String() string { return proto.CompactTextString(m) } -func (*ServerSpeedResponse) ProtoMessage() {} -func (*ServerSpeedResponse) Descriptor() ([]byte, []int) { - return fileDescriptor_0c5120591600887d, []int{2} -} - -func (m *ServerSpeedResponse) XXX_Unmarshal(b []byte) error { - return xxx_messageInfo_ServerSpeedResponse.Unmarshal(m, b) -} -func (m *ServerSpeedResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { - return xxx_messageInfo_ServerSpeedResponse.Marshal(b, m, deterministic) -} -func (m *ServerSpeedResponse) XXX_Merge(src proto.Message) { - xxx_messageInfo_ServerSpeedResponse.Merge(m, src) -} -func (m *ServerSpeedResponse) XXX_Size() int { - return xxx_messageInfo_ServerSpeedResponse.Size(m) -} -func (m *ServerSpeedResponse) XXX_DiscardUnknown() { - xxx_messageInfo_ServerSpeedResponse.DiscardUnknown(m) -} - -var xxx_messageInfo_ServerSpeedResponse proto.InternalMessageInfo - -func (m *ServerSpeedResponse) GetBanner() string { - if m != nil { - return m.Banner - } - return "" -} - -func (m *ServerSpeedResponse) GetLimited() bool { - if m != nil { - return m.Limited - } - return false -} - -func (m *ServerSpeedResponse) GetLimit() *Speed { - if m != nil { - return m.Limit - } - return nil -} - -func (m *ServerSpeedResponse) GetSpeed() *Speed { - if m != nil { - return m.Speed - } - return nil -} - -func init() { - proto.RegisterType((*Speed)(nil), "forwarder.Speed") - proto.RegisterType((*ClientSpeedRequest)(nil), "forwarder.ClientSpeedRequest") - proto.RegisterType((*ServerSpeedResponse)(nil), "forwarder.ServerSpeedResponse") -} - -func init() { - proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d) -} - -var fileDescriptor_0c5120591600887d = []byte{ - // 220 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x90, 0x4d, 0x4a, 0xc6, 0x30, - 0x10, 0x86, 0xa9, 0xf6, 0xfb, 0x9b, 0x0f, 0x41, 0x46, 0x90, 0xba, 0x52, 0xba, 0x10, 0x57, 0x5d, - 0xe8, 0x0d, 0xbe, 0x5e, 0x40, 0xd2, 0x03, 0x48, 0x7f, 0x46, 0x08, 0xb4, 0x49, 0x9c, 0x89, 0xf5, - 0x28, 0x5e, 0x57, 0x3a, 0x8d, 0xba, 0xd2, 0xdd, 0xbc, 0x79, 0x92, 0x67, 0x5e, 0x02, 0x17, 0xbd, - 0x77, 0x91, 0xfd, 0x58, 0x05, 0xf6, 0xd1, 0xe3, 0xe1, 0xd5, 0xf3, 0x47, 0xcb, 0x03, 0x71, 0x59, - 0xc3, 0xa6, 0x09, 0x44, 0x03, 0xde, 0xc0, 0x5e, 0xc8, 0x0d, 0x2f, 0x5d, 0x90, 0x22, 0xbb, 0xcb, - 0x1e, 0x72, 0xb3, 0x5b, 0xf2, 0x29, 0x08, 0xde, 0xc2, 0x91, 0xa9, 0x27, 0x3b, 0x93, 0xd2, 0x33, - 0xa5, 0x90, 0x8e, 0x4e, 0x41, 0xca, 0x67, 0xc0, 0x7a, 0xb4, 0xe4, 0xa2, 0xaa, 0x0c, 0xbd, 0xbd, - 0x93, 0x44, 0x44, 0xc8, 0x5d, 0x3b, 0x91, 0xda, 0x0e, 0x46, 0x67, 0xbc, 0x87, 0x8d, 0x2c, 0x77, - 0x54, 0x72, 0x7c, 0xbc, 0xac, 0x7e, 0x9a, 0x54, 0xeb, 0xdb, 0x15, 0x97, 0x9f, 0x19, 0x5c, 0x35, - 0xc4, 0x33, 0x71, 0x52, 0x4a, 0xf0, 0x4e, 0x08, 0xaf, 0x61, 0xdb, 0xb5, 0xce, 0x11, 0x27, 0x6b, - 0x4a, 0x58, 0xc0, 0x6e, 0xb4, 0x93, 0x8d, 0xc9, 0xbc, 0x37, 0xdf, 0x71, 0xd9, 0xa8, 0x63, 0x71, - 0xfe, 0xd7, 0x46, 0xc5, 0xbf, 0xcd, 0xf2, 0x7f, 0x9b, 0x75, 0x5b, 0xfd, 0xc2, 0xa7, 0xaf, 0x00, - 0x00, 0x00, 0xff, 0xff, 0xb2, 0x10, 0x5a, 0xf2, 0x53, 0x01, 0x00, 0x00, -} diff --git a/internal/forwarder/control.proto b/internal/forwarder/control.proto deleted file mode 100644 index f6c1b9a..0000000 --- a/internal/forwarder/control.proto +++ /dev/null @@ -1,19 +0,0 @@ -syntax = "proto3"; -package forwarder; - -message Speed { - uint64 send_bps = 1; - uint64 receive_bps = 2; -} - -message ClientSpeedRequest { - string name = 1; - Speed speed = 2; -} - -message ServerSpeedResponse { - string banner = 1; - bool limited = 2; - Speed limit = 3; - Speed speed = 4; -} \ No newline at end of file diff --git a/internal/forwarder/params.go b/internal/forwarder/params.go deleted file mode 100644 index 0b240a3..0000000 --- a/internal/forwarder/params.go +++ /dev/null @@ -1,10 +0,0 @@ -package forwarder - -import ( - "encoding/binary" - "time" -) - -const controlStreamTimeout = 10 * time.Second - -var controlProtocolEndian = binary.BigEndian diff --git a/internal/forwarder/server.go b/internal/forwarder/server.go deleted file mode 100644 index a3ac0f6..0000000 --- a/internal/forwarder/server.go +++ /dev/null @@ -1,176 +0,0 @@ -package forwarder - -import ( - "context" - "crypto/tls" - "errors" - "github.com/lucas-clemente/quic-go" - "github.com/tobyxdd/hysteria/internal/utils" - "net" - "sync/atomic" -) - -type QUICServer struct { - inboundBytes, outboundBytes uint64 // atomic - - listener quic.Listener - remoteAddr string - banner string - sendBPS, recvBPS uint64 - - newCongestion CongestionFactory - onClientConnected ClientConnectedCallback - onClientDisconnected ClientDisconnectedCallback - onClientNewStream ClientNewStreamCallback - onClientStreamClosed ClientStreamClosedCallback - onTCPError TCPErrorCallback -} - -func NewQUICServer(addr string, remoteAddr string, banner string, tlsConfig *tls.Config, - sendBPS uint64, recvBPS uint64, recvWindowConn uint64, recvWindowClients uint64, - clientMaxConn int, newCongestion CongestionFactory, - onClientConnected ClientConnectedCallback, - onClientDisconnected ClientDisconnectedCallback, - onClientNewStream ClientNewStreamCallback, - onClientStreamClosed ClientStreamClosedCallback, - onTCPError TCPErrorCallback) (*QUICServer, error) { - listener, err := quic.ListenAddr(addr, tlsConfig, &quic.Config{ - MaxReceiveStreamFlowControlWindow: recvWindowConn, - MaxReceiveConnectionFlowControlWindow: recvWindowClients, - MaxIncomingStreams: clientMaxConn, - KeepAlive: true, - }) - if err != nil { - return nil, err - } - s := &QUICServer{ - listener: listener, - remoteAddr: remoteAddr, - banner: banner, - sendBPS: sendBPS, - recvBPS: recvBPS, - newCongestion: newCongestion, - onClientConnected: onClientConnected, - onClientDisconnected: onClientDisconnected, - onClientNewStream: onClientNewStream, - onClientStreamClosed: onClientStreamClosed, - onTCPError: onTCPError, - } - go s.acceptLoop() - return s, nil -} - -func (s *QUICServer) Close() error { - return s.listener.Close() -} - -func (s *QUICServer) Stats() (string, uint64, uint64) { - return s.remoteAddr, atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes) -} - -func (s *QUICServer) acceptLoop() { - for { - cs, err := s.listener.Accept(context.Background()) - if err != nil { - break - } - go s.handleClient(cs) - } -} - -func (s *QUICServer) handleClient(cs quic.Session) { - // Expect the client to create a control stream and send its own information - ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) - ctlStream, err := cs.AcceptStream(ctx) - ctxCancel() - if err != nil { - _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") - return - } - name, sSend, sRecv, err := s.handleControlStream(cs, ctlStream) - if err != nil { - _ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") - return - } - // Only after a successful exchange of information do we consider this a valid client - s.onClientConnected(cs.RemoteAddr(), name, sSend, sRecv) - // Start accepting streams to be forwarded - var closeErr error - for { - stream, err := cs.AcceptStream(context.Background()) - if err != nil { - closeErr = err - break - } - go s.handleStream(cs.RemoteAddr(), name, stream) - } - s.onClientDisconnected(cs.RemoteAddr(), name, closeErr) - _ = cs.CloseWithError(closeErrorCodeGeneric, "generic") -} - -// Negotiate speed & return client name -func (s *QUICServer) handleControlStream(cs quic.Session, stream quic.Stream) (string, uint64, uint64, error) { - req, err := readClientSpeedRequest(stream) - if err != nil { - return "", 0, 0, err - } - if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 { - return "", 0, 0, errors.New("incorrect speed information provided by the client") - } - limited := false - serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps - if s.sendBPS > 0 && serverSendBPS > s.sendBPS { - limited = true - serverSendBPS = s.sendBPS - } - if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS { - limited = true - serverReceiveBPS = s.recvBPS - } - // Response - err = writeServerSpeedResponse(stream, &ServerSpeedResponse{ - Banner: s.banner, - Limited: limited, - Limit: &Speed{ - SendBps: s.sendBPS, - ReceiveBps: s.recvBPS, - }, - Speed: &Speed{ - SendBps: serverSendBPS, - ReceiveBps: serverReceiveBPS, - }, - }) - if err != nil { - return "", 0, 0, err - } - // Set the congestion accordingly - if s.newCongestion != nil { - cs.SetCongestion(s.newCongestion(serverSendBPS)) - } - return req.Name, serverSendBPS, serverReceiveBPS, nil -} - -func (s *QUICServer) handleStream(addr net.Addr, name string, stream quic.Stream) { - s.onClientNewStream(addr, name, int(stream.StreamID())) - defer stream.Close() - tcpConn, err := net.Dial("tcp", s.remoteAddr) - if err != nil { - s.onTCPError(s.remoteAddr, err) - s.onClientStreamClosed(addr, name, int(stream.StreamID()), err) - return - } - defer tcpConn.Close() - // Pipes - errChan := make(chan error, 2) - go func() { - // TCP to QUIC - errChan <- utils.Pipe(tcpConn, stream, &s.outboundBytes) - }() - go func() { - // QUIC to TCP - errChan <- utils.Pipe(stream, tcpConn, &s.inboundBytes) - }() - // We only need the first error - err = <-errChan - s.onClientStreamClosed(addr, name, int(stream.StreamID()), err) -} diff --git a/internal/forwarder/types.go b/internal/forwarder/types.go deleted file mode 100644 index f8c3156..0000000 --- a/internal/forwarder/types.go +++ /dev/null @@ -1,21 +0,0 @@ -package forwarder - -import ( - "github.com/lucas-clemente/quic-go/congestion" - "net" -) - -type CongestionFactory func(refBPS uint64) congestion.SendAlgorithmWithDebugInfos - -// For server -type ClientConnectedCallback func(addr net.Addr, name string, sSend uint64, sRecv uint64) -type ClientDisconnectedCallback func(addr net.Addr, name string, err error) -type ClientNewStreamCallback func(addr net.Addr, name string, id int) -type ClientStreamClosedCallback func(addr net.Addr, name string, id int, err error) -type TCPErrorCallback func(remoteAddr string, err error) - -// For client -type ServerConnectedCallback func(addr net.Addr, banner string, cSend uint64, cRecv uint64) -type ServerErrorCallback func(err error) -type NewTCPConnectionCallback func(addr net.Addr) -type TCPConnectionClosedCallback func(addr net.Addr, err error) diff --git a/internal/utils/packet_readwritecloser.go b/internal/utils/packet_readwritecloser.go new file mode 100644 index 0000000..ecbcc66 --- /dev/null +++ b/internal/utils/packet_readwritecloser.go @@ -0,0 +1,35 @@ +package utils + +import ( + "encoding/binary" + "fmt" + "io" +) + +type PacketReadWriteCloser struct { + Orig io.ReadWriteCloser +} + +func (rw *PacketReadWriteCloser) Read(p []byte) (n int, err error) { + var sz uint32 + if err := binary.Read(rw.Orig, binary.BigEndian, &sz); err != nil { + return 0, err + } + if int(sz) <= len(p) { + return io.ReadFull(rw.Orig, p[:sz]) + } else { + return 0, fmt.Errorf("the buffer is too small to hold %d bytes of packet data", sz) + } +} + +func (rw *PacketReadWriteCloser) Write(p []byte) (n int, err error) { + sz := uint32(len(p)) + if err := binary.Write(rw.Orig, binary.BigEndian, &sz); err != nil { + return 0, err + } + return rw.Orig.Write(p) +} + +func (rw *PacketReadWriteCloser) Close() error { + return rw.Orig.Close() +} diff --git a/internal/utils/pipe.go b/internal/utils/pipe.go index a0dd250..fd35b45 100644 --- a/internal/utils/pipe.go +++ b/internal/utils/pipe.go @@ -5,7 +5,7 @@ import ( "sync/atomic" ) -const pipeBufferSize = 16384 +const pipeBufferSize = 65536 func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error { buf := make([]byte, pipeBufferSize) @@ -13,7 +13,9 @@ func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error { rn, err := src.Read(buf) if rn > 0 { wn, err := dst.Write(buf[:rn]) - atomic.AddUint64(atomicCounter, uint64(wn)) + if atomicCounter != nil { + atomic.AddUint64(atomicCounter, uint64(wn)) + } if err != nil { return err } diff --git a/pkg/core/interface.go b/pkg/core/interface.go new file mode 100644 index 0000000..04bdec1 --- /dev/null +++ b/pkg/core/interface.go @@ -0,0 +1,71 @@ +package core + +import ( + "crypto/tls" + "github.com/lucas-clemente/quic-go" + "github.com/tobyxdd/hysteria/internal/core" + "io" + "net" +) + +type AuthResult int32 + +const ( + AuthSuccess = AuthResult(iota) + AuthInvalidCred + AuthInternalError +) + +type ConnectResult int32 + +const ( + ConnSuccess = ConnectResult(iota) + ConnFailed + ConnBlocked +) + +type CongestionFactory core.CongestionFactory +type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string) +type ClientDisconnectedFunc core.ClientDisconnectedFunc +type HandleRequestFunc func(addr net.Addr, username string, id int, isUDP bool, reqAddr string) (ConnectResult, string, io.ReadWriteCloser) +type RequestClosedFunc func(addr net.Addr, username string, id int, isUDP bool, reqAddr string, err error) + +type Server interface { + Serve() error + Stats() (inbound uint64, outbound uint64) + Close() error +} + +func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, + sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, + clientAuthFunc ClientAuthFunc, + clientDisconnectedFunc ClientDisconnectedFunc, + handleRequestFunc HandleRequestFunc, + requestClosedFunc RequestClosedFunc) (Server, error) { + return core.NewServer(addr, tlsConfig, quicConfig, sendBPS, recvBPS, core.CongestionFactory(congestionFactory), + func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) { + r, msg := clientAuthFunc(addr, username, password, sSend, sRecv) + return core.AuthResult(r), msg + }, + core.ClientDisconnectedFunc(clientDisconnectedFunc), + func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) { + r, msg, conn := handleRequestFunc(addr, username, id, reqType == core.ConnectionType_UDP, reqAddr) + return core.ConnectResult(r), msg, conn + }, + func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) { + requestClosedFunc(addr, username, id, reqType == core.ConnectionType_UDP, reqAddr, err) + }) +} + +type Client interface { + Dial(udp bool, addr string) (io.ReadWriteCloser, error) + Stats() (inbound uint64, outbound uint64) + Close() error +} + +func NewClient(serverAddr string, username string, password string, + tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64, + congestionFactory CongestionFactory) (Client, error) { + return core.NewClient(serverAddr, username, password, tlsConfig, quicConfig, sendBPS, recvBPS, + core.CongestionFactory(congestionFactory)) +} diff --git a/pkg/forwarder/client.go b/pkg/forwarder/client.go deleted file mode 100644 index c75e993..0000000 --- a/pkg/forwarder/client.go +++ /dev/null @@ -1,70 +0,0 @@ -package forwarder - -import ( - "crypto/tls" - "errors" - "github.com/tobyxdd/hysteria/internal/forwarder" - "net" -) - -type client struct { - qc *forwarder.QUICClient -} - -func NewClient(localAddr string, remoteAddr string, config ClientConfig, callbacks ClientCallbacks) (Client, error) { - // Fix config first - if config.Speed == nil || config.Speed.SendBPS == 0 || config.Speed.ReceiveBPS == 0 { - return nil, errors.New("invalid speed") - } - if config.TLSConfig == nil { - config.TLSConfig = &tls.Config{NextProtos: []string{TLSAppProtocol}} - } - if config.MaxReceiveWindowPerConnection == 0 { - config.MaxReceiveWindowPerConnection = defaultReceiveWindowConn - } - if config.MaxReceiveWindow == 0 { - config.MaxReceiveWindow = defaultReceiveWindow - } - qc, err := forwarder.NewQUICClient(localAddr, remoteAddr, config.Name, config.TLSConfig, - config.Speed.SendBPS, config.Speed.ReceiveBPS, - config.MaxReceiveWindowPerConnection, config.MaxReceiveWindow, - forwarder.CongestionFactory(config.CongestionFactory), - func(addr net.Addr, banner string, cSend uint64, cRecv uint64) { - if callbacks.ServerConnectedCallback != nil { - callbacks.ServerConnectedCallback(addr, banner, cSend, cRecv) - } - }, - func(err error) { - if callbacks.ServerErrorCallback != nil { - callbacks.ServerErrorCallback(err) - } - }, - func(addr net.Addr) { - if callbacks.NewTCPConnectionCallback != nil { - callbacks.NewTCPConnectionCallback(addr) - } - }, - func(addr net.Addr, err error) { - if callbacks.TCPConnectionClosedCallback != nil { - callbacks.TCPConnectionClosedCallback(addr, err) - } - }, - ) - if err != nil { - return nil, err - } - return &client{qc: qc}, nil -} - -func (c *client) Stats() Stats { - addr, in, out := c.qc.Stats() - return Stats{ - RemoteAddr: addr, - inboundBytes: in, - outboundBytes: out, - } -} - -func (c *client) Close() error { - return c.Close() -} diff --git a/pkg/forwarder/interface.go b/pkg/forwarder/interface.go deleted file mode 100644 index 758acfb..0000000 --- a/pkg/forwarder/interface.go +++ /dev/null @@ -1,89 +0,0 @@ -package forwarder - -import ( - "crypto/tls" - "github.com/tobyxdd/hysteria/internal/forwarder" - "net" -) - -type CongestionFactory forwarder.CongestionFactory - -// A server can support multiple forwarding entries (listenAddr/remoteAddr pairs) -type Server interface { - Add(listenAddr, remoteAddr string) error - Remove(listenAddr string) error - Stats() map[string]Stats -} - -// An empty ServerConfig is a valid one -type ServerConfig struct { - // A banner message that will be sent to the client after the connection is established. - // No message if not set. - BannerMessage string - // TLSConfig is used to configure the TLS server. - // Use an insecure self-signed certificate if not set. - TLSConfig *tls.Config - // MaxSpeedPerClient is the maximum allowed sending and receiving speed for each client. - // Sending speed will never exceed this limit, even if a client demands a larger value. - // No restrictions if not set. - MaxSpeedPerClient *Speed - // Corresponds to MaxReceiveStreamFlowControlWindow in QUIC. - MaxReceiveWindowPerConnection uint64 - // Corresponds to MaxReceiveConnectionFlowControlWindow in QUIC. - MaxReceiveWindowPerClient uint64 - // Max number of simultaneous connections allowed for a client - MaxConnectionPerClient int - // Congestion factory - CongestionFactory CongestionFactory -} - -type ServerCallbacks struct { - ClientConnectedCallback func(listenAddr string, clientAddr net.Addr, name string, sSend uint64, sRecv uint64) - ClientDisconnectedCallback func(listenAddr string, clientAddr net.Addr, name string, err error) - ClientNewStreamCallback func(listenAddr string, clientAddr net.Addr, name string, id int) - ClientStreamClosedCallback func(listenAddr string, clientAddr net.Addr, name string, id int, err error) - TCPErrorCallback func(listenAddr string, remoteAddr string, err error) -} - -// A client supports one forwarding entry -type Client interface { - Stats() Stats - Close() error -} - -// An empty ClientConfig is NOT a valid one, as Speed must be set -type ClientConfig struct { - // A client can report its name to the server after the connection is established. - // No name if not set. - Name string - // TLSConfig is used to configure the TLS client. - // Use default settings if not set. - TLSConfig *tls.Config - // Speed reported by the client when negotiating with the server. - // The actual speed will also depend on the configuration of the server. - Speed *Speed - // Corresponds to MaxReceiveStreamFlowControlWindow in QUIC. - MaxReceiveWindowPerConnection uint64 - // Corresponds to MaxReceiveConnectionFlowControlWindow in QUIC. - MaxReceiveWindow uint64 - // Congestion factory - CongestionFactory CongestionFactory -} - -type ClientCallbacks struct { - ServerConnectedCallback func(addr net.Addr, banner string, cSend uint64, cRecv uint64) - ServerErrorCallback func(err error) - NewTCPConnectionCallback func(addr net.Addr) - TCPConnectionClosedCallback func(addr net.Addr, err error) -} - -type Speed struct { - SendBPS uint64 - ReceiveBPS uint64 -} - -type Stats struct { - RemoteAddr string - inboundBytes uint64 - outboundBytes uint64 -} diff --git a/pkg/forwarder/params.go b/pkg/forwarder/params.go deleted file mode 100644 index ffa6d90..0000000 --- a/pkg/forwarder/params.go +++ /dev/null @@ -1,9 +0,0 @@ -package forwarder - -const ( - TLSAppProtocol = "hysteria-forwarder" - - defaultReceiveWindowConn = 33554432 - defaultReceiveWindow = 67108864 - defaultMaxClientConn = 100 -) diff --git a/pkg/forwarder/server.go b/pkg/forwarder/server.go deleted file mode 100644 index e05d34b..0000000 --- a/pkg/forwarder/server.go +++ /dev/null @@ -1,119 +0,0 @@ -package forwarder - -import ( - "crypto/rand" - "crypto/rsa" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "github.com/tobyxdd/hysteria/internal/forwarder" - "math/big" - "net" -) - -type server struct { - config ServerConfig - callbacks ServerCallbacks - entries map[string]*forwarder.QUICServer -} - -func NewServer(config ServerConfig, callbacks ServerCallbacks) Server { - // Fix config first - if config.TLSConfig == nil { - config.TLSConfig = generateInsecureTLSConfig() - } - if config.MaxSpeedPerClient == nil { - config.MaxSpeedPerClient = &Speed{0, 0} - } - if config.MaxReceiveWindowPerConnection == 0 { - config.MaxReceiveWindowPerConnection = defaultReceiveWindowConn - } - if config.MaxReceiveWindowPerClient == 0 { - config.MaxReceiveWindowPerClient = defaultReceiveWindow - } - if config.MaxConnectionPerClient <= 0 { - config.MaxConnectionPerClient = defaultMaxClientConn - } - return &server{config: config, callbacks: callbacks, entries: make(map[string]*forwarder.QUICServer)} -} - -func (s *server) Add(listenAddr, remoteAddr string) error { - qs, err := forwarder.NewQUICServer(listenAddr, remoteAddr, s.config.BannerMessage, s.config.TLSConfig, - s.config.MaxSpeedPerClient.SendBPS, s.config.MaxSpeedPerClient.ReceiveBPS, - s.config.MaxReceiveWindowPerConnection, s.config.MaxReceiveWindowPerClient, - s.config.MaxConnectionPerClient, forwarder.CongestionFactory(s.config.CongestionFactory), - func(addr net.Addr, name string, sSend uint64, sRecv uint64) { - if s.callbacks.ClientConnectedCallback != nil { - s.callbacks.ClientConnectedCallback(listenAddr, addr, name, sSend, sRecv) - } - }, - func(addr net.Addr, name string, err error) { - if s.callbacks.ClientDisconnectedCallback != nil { - s.callbacks.ClientDisconnectedCallback(listenAddr, addr, name, err) - } - }, - func(addr net.Addr, name string, id int) { - if s.callbacks.ClientNewStreamCallback != nil { - s.callbacks.ClientNewStreamCallback(listenAddr, addr, name, id) - } - }, - func(addr net.Addr, name string, id int, err error) { - if s.callbacks.ClientStreamClosedCallback != nil { - s.callbacks.ClientStreamClosedCallback(listenAddr, addr, name, id, err) - } - }, - func(remoteAddr string, err error) { - if s.callbacks.TCPErrorCallback != nil { - s.callbacks.TCPErrorCallback(listenAddr, remoteAddr, err) - } - }, - ) - if err != nil { - return err - } - s.entries[listenAddr] = qs - return nil -} - -func (s *server) Remove(listenAddr string) error { - defer delete(s.entries, listenAddr) - if qs, ok := s.entries[listenAddr]; ok && qs != nil { - return qs.Close() - } - return nil -} - -func (s *server) Stats() map[string]Stats { - r := make(map[string]Stats, len(s.entries)) - for laddr, sv := range s.entries { - addr, in, out := sv.Stats() - r[laddr] = Stats{ - RemoteAddr: addr, - inboundBytes: in, - outboundBytes: out, - } - } - return r -} - -func generateInsecureTLSConfig() *tls.Config { - key, err := rsa.GenerateKey(rand.Reader, 1024) - if err != nil { - panic(err) - } - template := x509.Certificate{SerialNumber: big.NewInt(1)} - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - panic(err) - } - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - panic(err) - } - return &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - NextProtos: []string{TLSAppProtocol}, - } -}