mirror of
https://github.com/cedar2025/hysteria.git
synced 2025-08-21 08:41:48 +00:00
Most things work fine now, except:
- UDP support has been temporarily removed, pending upstream QUIC library support for unreliable messages - SOCKS5 server needs some rework - Authentication
This commit is contained in:
@@ -3,11 +3,6 @@ package main
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/sirupsen/logrus"
|
||||
@@ -17,42 +12,38 @@ import (
|
||||
hyHTTP "github.com/tobyxdd/hysteria/pkg/http"
|
||||
"github.com/tobyxdd/hysteria/pkg/obfs"
|
||||
"github.com/tobyxdd/hysteria/pkg/socks5"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func proxyClient(args []string) {
|
||||
var config proxyClientConfig
|
||||
err := loadConfig(&config, args)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Unable to load configuration")
|
||||
}
|
||||
if err := config.Check(); err != nil {
|
||||
logrus.WithField("error", err).Fatal("Configuration error")
|
||||
}
|
||||
logrus.WithField("config", config.String()).Info("Configuration loaded")
|
||||
|
||||
func client(config *clientConfig) {
|
||||
logrus.WithField("config", config.String()).Info("Client configuration loaded")
|
||||
// TLS
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: config.Insecure,
|
||||
NextProtos: []string{proxyTLSProtocol},
|
||||
NextProtos: []string{tlsProtocolName},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
// Load CA
|
||||
if len(config.CustomCAFile) > 0 {
|
||||
bs, err := ioutil.ReadFile(config.CustomCAFile)
|
||||
if len(config.CustomCA) > 0 {
|
||||
bs, err := ioutil.ReadFile(config.CustomCA)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"file": config.CustomCAFile,
|
||||
}).Fatal("Unable to load CA file")
|
||||
"file": config.CustomCA,
|
||||
}).Fatal("Failed to load CA")
|
||||
}
|
||||
cp := x509.NewCertPool()
|
||||
if !cp.AppendCertsFromPEM(bs) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"file": config.CustomCAFile,
|
||||
}).Fatal("Unable to parse CA file")
|
||||
"file": config.CustomCA,
|
||||
}).Fatal("Failed to parse CA")
|
||||
}
|
||||
tlsConfig.RootCAs = cp
|
||||
}
|
||||
|
||||
// QUIC config
|
||||
quicConfig := &quic.Config{
|
||||
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
|
||||
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow,
|
||||
@@ -64,46 +55,54 @@ func proxyClient(args []string) {
|
||||
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
// Auth
|
||||
var auth []byte
|
||||
if len(config.Auth) > 0 {
|
||||
auth = config.Auth
|
||||
} else {
|
||||
auth = []byte(config.AuthString)
|
||||
}
|
||||
// Obfuscator
|
||||
var obfuscator core.Obfuscator
|
||||
if len(config.Obfs) > 0 {
|
||||
obfuscator = obfs.XORObfuscator(config.Obfs)
|
||||
}
|
||||
|
||||
// ACL
|
||||
var aclEngine *acl.Engine
|
||||
if len(config.ACLFile) > 0 {
|
||||
aclEngine, err = acl.LoadFromFile(config.ACLFile)
|
||||
if len(config.ACL) > 0 {
|
||||
var err error
|
||||
aclEngine, err = acl.LoadFromFile(config.ACL)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"file": config.ACLFile,
|
||||
}).Fatal("Unable to parse ACL")
|
||||
"file": config.ACL,
|
||||
}).Fatal("Failed to parse ACL")
|
||||
}
|
||||
}
|
||||
|
||||
client, err := core.NewClient(config.ServerAddr, config.Username, config.Password, tlsConfig, quicConfig,
|
||||
// Client
|
||||
client, err := core.NewClient(config.Server, auth, tlsConfig, quicConfig,
|
||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||
func(refBPS uint64) congestion.CongestionControl {
|
||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||
}, obfuscator)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Client initialization failed")
|
||||
logrus.WithField("error", err).Fatal("Failed to initialize client")
|
||||
}
|
||||
defer client.Close()
|
||||
logrus.WithField("addr", config.ServerAddr).Info("Connected")
|
||||
logrus.WithField("addr", config.Server).Info("Connected")
|
||||
|
||||
// Local
|
||||
errChan := make(chan error)
|
||||
|
||||
if len(config.SOCKS5Addr) > 0 {
|
||||
if len(config.SOCKS5.Listen) > 0 {
|
||||
go func() {
|
||||
var authFunc func(user, password string) bool
|
||||
if config.SOCKS5User != "" && config.SOCKS5Password != "" {
|
||||
if config.SOCKS5.User != "" && config.SOCKS5.Password != "" {
|
||||
authFunc = func(user, password string) bool {
|
||||
return config.SOCKS5User == user && config.SOCKS5Password == password
|
||||
return config.SOCKS5.User == user && config.SOCKS5.Password == password
|
||||
}
|
||||
}
|
||||
socks5server, err := socks5.NewServer(client, config.SOCKS5Addr, authFunc, config.SOCKS5Timeout, aclEngine,
|
||||
config.SOCKS5DisableUDP,
|
||||
socks5server, err := socks5.NewServer(client, config.SOCKS5.Listen, authFunc, config.SOCKS5.Timeout, aclEngine,
|
||||
config.SOCKS5.DisableUDP,
|
||||
func(addr net.Addr, reqAddr string, action acl.Action, arg string) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": actionToString(action, arg),
|
||||
@@ -144,22 +143,22 @@ func proxyClient(args []string) {
|
||||
}).Debug("SOCKS5 UDP tunnel closed")
|
||||
})
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("SOCKS5 server initialization failed")
|
||||
logrus.WithField("error", err).Fatal("Failed to initialize SOCKS5 server")
|
||||
}
|
||||
logrus.WithField("addr", config.SOCKS5Addr).Info("SOCKS5 server up and running")
|
||||
logrus.WithField("addr", config.SOCKS5.Listen).Info("SOCKS5 server up and running")
|
||||
errChan <- socks5server.ListenAndServe()
|
||||
}()
|
||||
}
|
||||
|
||||
if len(config.HTTPAddr) > 0 {
|
||||
if len(config.HTTP.Listen) > 0 {
|
||||
go func() {
|
||||
var authFunc func(user, password string) bool
|
||||
if config.HTTPUser != "" && config.HTTPPassword != "" {
|
||||
if config.HTTP.User != "" && config.HTTP.Password != "" {
|
||||
authFunc = func(user, password string) bool {
|
||||
return config.HTTPUser == user && config.HTTPPassword == password
|
||||
return config.HTTP.User == user && config.HTTP.Password == password
|
||||
}
|
||||
}
|
||||
proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTPTimeout)*time.Second, aclEngine,
|
||||
proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTP.Timeout)*time.Second, aclEngine,
|
||||
func(reqAddr string, action acl.Action, arg string) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": actionToString(action, arg),
|
||||
@@ -168,14 +167,14 @@ func proxyClient(args []string) {
|
||||
},
|
||||
authFunc)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("HTTP server initialization failed")
|
||||
logrus.WithField("error", err).Fatal("Failed to initialize HTTP server")
|
||||
}
|
||||
if config.HTTPSCert != "" && config.HTTPSKey != "" {
|
||||
logrus.WithField("addr", config.HTTPAddr).Info("HTTPS server up and running")
|
||||
errChan <- http.ListenAndServeTLS(config.HTTPAddr, config.HTTPSCert, config.HTTPSKey, proxy)
|
||||
if config.HTTP.Cert != "" && config.HTTP.Key != "" {
|
||||
logrus.WithField("addr", config.HTTP.Listen).Info("HTTPS server up and running")
|
||||
errChan <- http.ListenAndServeTLS(config.HTTP.Listen, config.HTTP.Cert, config.HTTP.Key, proxy)
|
||||
} else {
|
||||
logrus.WithField("addr", config.HTTPAddr).Info("HTTP server up and running")
|
||||
errChan <- http.ListenAndServe(config.HTTPAddr, proxy)
|
||||
logrus.WithField("addr", config.HTTP.Listen).Info("HTTP server up and running")
|
||||
errChan <- http.ListenAndServe(config.HTTP.Listen, proxy)
|
||||
}
|
||||
}()
|
||||
}
|
175
cmd/config.go
175
cmd/config.go
@@ -1,86 +1,127 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
mbpsToBps = 125000
|
||||
dialTimeout = 10 * time.Second
|
||||
mbpsToBps = 125000
|
||||
|
||||
DefaultMaxReceiveStreamFlowControlWindow = 33554432
|
||||
DefaultMaxReceiveConnectionFlowControlWindow = 67108864
|
||||
DefaultMaxIncomingStreams = 200
|
||||
DefaultMaxIncomingStreams = 1024
|
||||
|
||||
tlsProtocolName = "hysteria"
|
||||
)
|
||||
|
||||
func loadConfig(cfg interface{}, args []string) error {
|
||||
cfgVal := reflect.ValueOf(cfg).Elem()
|
||||
fs := flag.NewFlagSet("", flag.ContinueOnError)
|
||||
fsValMap := make(map[reflect.Value]interface{}, cfgVal.NumField())
|
||||
for i := 0; i < cfgVal.NumField(); i++ {
|
||||
structField := cfgVal.Type().Field(i)
|
||||
tag := structField.Tag
|
||||
switch structField.Type.Kind() {
|
||||
case reflect.String:
|
||||
fsValMap[cfgVal.Field(i)] =
|
||||
fs.String(jsonTagToFlagName(tag.Get("json")), "", tag.Get("desc"))
|
||||
case reflect.Int:
|
||||
fsValMap[cfgVal.Field(i)] =
|
||||
fs.Int(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc"))
|
||||
case reflect.Uint64:
|
||||
fsValMap[cfgVal.Field(i)] =
|
||||
fs.Uint64(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc"))
|
||||
case reflect.Bool:
|
||||
var bf optionalBoolFlag
|
||||
fs.Var(&bf, jsonTagToFlagName(tag.Get("json")), tag.Get("desc"))
|
||||
fsValMap[cfgVal.Field(i)] = &bf
|
||||
}
|
||||
type serverConfig struct {
|
||||
Listen string `json:"listen"`
|
||||
CertFile string `json:"cert"`
|
||||
KeyFile string `json:"key"`
|
||||
// Optional below
|
||||
UpMbps int `json:"up_mbps"`
|
||||
DownMbps int `json:"down_mbps"`
|
||||
DisableUDP bool `json:"disable_udp"`
|
||||
ACL string `json:"acl"`
|
||||
Obfs string `json:"obfs"`
|
||||
Auth struct {
|
||||
Mode string `json:"mode"`
|
||||
Config interface{} `json:"config"`
|
||||
} `json:"auth"`
|
||||
ReceiveWindowConn uint64 `json:"recv_window_conn"`
|
||||
ReceiveWindowClient uint64 `json:"recv_window_client"`
|
||||
MaxConnClient int `json:"max_conn_client"`
|
||||
}
|
||||
|
||||
func (c *serverConfig) Check() error {
|
||||
if len(c.Listen) == 0 {
|
||||
return errors.New("no listen address")
|
||||
}
|
||||
configFile := fs.String("config", "", "Configuration file")
|
||||
// Parse
|
||||
if err := fs.Parse(args); err != nil {
|
||||
os.Exit(1)
|
||||
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
|
||||
return errors.New("TLS cert or key not provided")
|
||||
}
|
||||
// Put together the config
|
||||
if len(*configFile) > 0 {
|
||||
cb, err := ioutil.ReadFile(*configFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := json.Unmarshal(cb, cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.UpMbps < 0 || c.DownMbps < 0 {
|
||||
return errors.New("invalid speed")
|
||||
}
|
||||
// Flags override config from file
|
||||
for field, val := range fsValMap {
|
||||
switch v := val.(type) {
|
||||
case *string:
|
||||
if len(*v) > 0 {
|
||||
field.SetString(*v)
|
||||
}
|
||||
case *int:
|
||||
if *v != 0 {
|
||||
field.SetInt(int64(*v))
|
||||
}
|
||||
case *uint64:
|
||||
if *v != 0 {
|
||||
field.SetUint(*v)
|
||||
}
|
||||
case *optionalBoolFlag:
|
||||
if v.Exists {
|
||||
field.SetBool(v.Value)
|
||||
}
|
||||
}
|
||||
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
|
||||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
|
||||
return errors.New("invalid receive window size")
|
||||
}
|
||||
if c.MaxConnClient < 0 {
|
||||
return errors.New("invalid max connections per client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jsonTagToFlagName(tag string) string {
|
||||
return strings.ReplaceAll(tag, "_", "-")
|
||||
func (c *serverConfig) String() string {
|
||||
return fmt.Sprintf("%+v", *c)
|
||||
}
|
||||
|
||||
type clientConfig struct {
|
||||
Server string `json:"server"`
|
||||
UpMbps int `json:"up_mbps"`
|
||||
DownMbps int `json:"down_mbps"`
|
||||
// Optional below
|
||||
SOCKS5 struct {
|
||||
Listen string `json:"listen"`
|
||||
Timeout int `json:"timeout"`
|
||||
DisableUDP bool `json:"disable_udp"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
} `json:"socks5"`
|
||||
HTTP struct {
|
||||
Listen string `json:"listen"`
|
||||
Timeout int `json:"timeout"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password"`
|
||||
Cert string `json:"cert"`
|
||||
Key string `json:"key"`
|
||||
} `json:"http"`
|
||||
Relay struct {
|
||||
Listen string `json:"listen"`
|
||||
Remote string `json:"remote"`
|
||||
Timeout int `json:"timeout"`
|
||||
} `json:"relay"`
|
||||
ACL string `json:"acl"`
|
||||
Obfs string `json:"obfs"`
|
||||
Auth []byte `json:"auth"`
|
||||
AuthString string `json:"auth_str"`
|
||||
Insecure bool `json:"insecure"`
|
||||
CustomCA string `json:"ca"`
|
||||
ReceiveWindowConn uint64 `json:"recv_window_conn"`
|
||||
ReceiveWindow uint64 `json:"recv_window"`
|
||||
}
|
||||
|
||||
func (c *clientConfig) Check() error {
|
||||
if len(c.SOCKS5.Listen) == 0 && len(c.HTTP.Listen) == 0 && len(c.Relay.Listen) == 0 {
|
||||
return errors.New("no SOCKS5, HTTP or relay listen address")
|
||||
}
|
||||
if len(c.Relay.Listen) > 0 && len(c.Relay.Remote) == 0 {
|
||||
return errors.New("no relay remote address")
|
||||
}
|
||||
if c.SOCKS5.Timeout != 0 && c.SOCKS5.Timeout <= 4 {
|
||||
return errors.New("invalid SOCKS5 timeout")
|
||||
}
|
||||
if c.HTTP.Timeout != 0 && c.HTTP.Timeout <= 4 {
|
||||
return errors.New("invalid HTTP timeout")
|
||||
}
|
||||
if c.Relay.Timeout != 0 && c.Relay.Timeout <= 4 {
|
||||
return errors.New("invalid relay timeout")
|
||||
}
|
||||
if len(c.Server) == 0 {
|
||||
return errors.New("no server address")
|
||||
}
|
||||
if c.UpMbps <= 0 || c.DownMbps <= 0 {
|
||||
return errors.New("invalid speed")
|
||||
}
|
||||
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
|
||||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
|
||||
return errors.New("invalid receive window size")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientConfig) String() string {
|
||||
return fmt.Sprintf("%+v", *c)
|
||||
}
|
||||
|
83
cmd/main.go
83
cmd/main.go
@@ -1,11 +1,13 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/yosuke-furukawa/json5/encoding/json5"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Injected when compiling
|
||||
@@ -15,12 +17,10 @@ var (
|
||||
appDate = "Unknown"
|
||||
)
|
||||
|
||||
var modeMap = map[string]func(args []string){
|
||||
"relay server": relayServer,
|
||||
"relay client": relayClient,
|
||||
"proxy server": proxyServer,
|
||||
"proxy client": proxyClient,
|
||||
}
|
||||
var (
|
||||
configPath = flag.String("config", "config.json", "Config file")
|
||||
showVersion = flag.Bool("version", false, "Show version")
|
||||
)
|
||||
|
||||
func init() {
|
||||
logrus.SetOutput(os.Stdout)
|
||||
@@ -50,37 +50,68 @@ func init() {
|
||||
TimestampFormat: tsFormat,
|
||||
})
|
||||
}
|
||||
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
func main() {
|
||||
if len(os.Args) == 2 && strings.ToLower(strings.TrimSpace(os.Args[1])) == "version" {
|
||||
if *showVersion {
|
||||
// Print version and quit
|
||||
fmt.Printf("%-10s%s\n", "Version:", appVersion)
|
||||
fmt.Printf("%-10s%s\n", "Commit:", appCommit)
|
||||
fmt.Printf("%-10s%s\n", "Date:", appDate)
|
||||
return
|
||||
}
|
||||
if len(os.Args) < 3 {
|
||||
fmt.Println()
|
||||
fmt.Printf("Usage: %s MODE SUBMODE [OPTIONS]\n\n"+
|
||||
"Available mode/submode combinations: "+getModes()+"\n"+
|
||||
"Use -h to see the available options for a mode.\n\n", os.Args[0])
|
||||
return
|
||||
cb, err := ioutil.ReadFile(*configPath)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"file": *configPath,
|
||||
"error": err,
|
||||
}).Fatal("Failed to read configuration")
|
||||
}
|
||||
modeStr := fmt.Sprintf("%s %s", strings.ToLower(strings.TrimSpace(os.Args[1])),
|
||||
strings.ToLower(strings.TrimSpace(os.Args[2])))
|
||||
f := modeMap[modeStr]
|
||||
if f != nil {
|
||||
f(os.Args[3:])
|
||||
mode := flag.Arg(0)
|
||||
if strings.EqualFold(mode, "server") {
|
||||
// server mode
|
||||
c, err := parseServerConfig(cb)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"file": *configPath,
|
||||
"error": err,
|
||||
}).Fatal("Failed to parse server configuration")
|
||||
}
|
||||
server(c)
|
||||
} else if len(mode) == 0 || strings.EqualFold(mode, "client") {
|
||||
// client mode
|
||||
c, err := parseClientConfig(cb)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"file": *configPath,
|
||||
"error": err,
|
||||
}).Fatal("Failed to parse client configuration")
|
||||
}
|
||||
client(c)
|
||||
} else {
|
||||
fmt.Println("Invalid mode:", modeStr)
|
||||
// invalid
|
||||
fmt.Println()
|
||||
fmt.Printf("Usage: %s MODE [OPTIONS]\n\n"+
|
||||
"Available modes: server, client\n\n", os.Args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func getModes() string {
|
||||
modes := make([]string, 0, len(modeMap))
|
||||
for mode := range modeMap {
|
||||
modes = append(modes, mode)
|
||||
func parseServerConfig(cb []byte) (*serverConfig, error) {
|
||||
var c serverConfig
|
||||
err := json5.Unmarshal(cb, &c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strings.Join(modes, ", ")
|
||||
return &c, c.Check()
|
||||
}
|
||||
|
||||
func parseClientConfig(cb []byte) (*clientConfig, error) {
|
||||
var c clientConfig
|
||||
err := json5.Unmarshal(cb, &c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &c, c.Check()
|
||||
}
|
||||
|
@@ -1,99 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const proxyTLSProtocol = "hysteria-proxy"
|
||||
|
||||
type proxyClientConfig struct {
|
||||
SOCKS5Addr string `json:"socks5_addr" desc:"SOCKS5 listen address"`
|
||||
SOCKS5Timeout int `json:"socks5_timeout" desc:"SOCKS5 connection timeout in seconds"`
|
||||
SOCKS5DisableUDP bool `json:"socks5_disable_udp" desc:"Disable SOCKS5 UDP support"`
|
||||
SOCKS5User string `json:"socks5_user" desc:"SOCKS5 auth username"`
|
||||
SOCKS5Password string `json:"socks5_password" desc:"SOCKS5 auth password"`
|
||||
HTTPAddr string `json:"http_addr" desc:"HTTP listen address"`
|
||||
HTTPTimeout int `json:"http_timeout" desc:"HTTP connection timeout in seconds"`
|
||||
HTTPUser string `json:"http_user" desc:"HTTP basic auth username"`
|
||||
HTTPPassword string `json:"http_password" desc:"HTTP basic auth password"`
|
||||
HTTPSCert string `json:"https_cert" desc:"HTTPS certificate file"`
|
||||
HTTPSKey string `json:"https_key" desc:"HTTPS key file"`
|
||||
ACLFile string `json:"acl" desc:"Access control list"`
|
||||
ServerAddr string `json:"server" desc:"Server address"`
|
||||
Username string `json:"username" desc:"Authentication username"`
|
||||
Password string `json:"password" desc:"Authentication password"`
|
||||
Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"`
|
||||
CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"`
|
||||
UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"`
|
||||
DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"`
|
||||
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
|
||||
ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"`
|
||||
Obfs string `json:"obfs" desc:"Obfuscation key"`
|
||||
}
|
||||
|
||||
func (c *proxyClientConfig) Check() error {
|
||||
if len(c.SOCKS5Addr) == 0 && len(c.HTTPAddr) == 0 {
|
||||
return errors.New("no SOCKS5 or HTTP listen address")
|
||||
}
|
||||
if c.SOCKS5Timeout != 0 && c.SOCKS5Timeout <= 4 {
|
||||
return errors.New("invalid SOCKS5 timeout")
|
||||
}
|
||||
if c.HTTPTimeout != 0 && c.HTTPTimeout <= 4 {
|
||||
return errors.New("invalid HTTP timeout")
|
||||
}
|
||||
if len(c.ServerAddr) == 0 {
|
||||
return errors.New("no server address")
|
||||
}
|
||||
if c.UpMbps <= 0 || c.DownMbps <= 0 {
|
||||
return errors.New("invalid speed")
|
||||
}
|
||||
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
|
||||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
|
||||
return errors.New("invalid receive window size")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *proxyClientConfig) String() string {
|
||||
return fmt.Sprintf("%+v", *c)
|
||||
}
|
||||
|
||||
type proxyServerConfig struct {
|
||||
ListenAddr string `json:"listen" desc:"Server listen address"`
|
||||
DisableUDP bool `json:"disable_udp" desc:"Disable UDP support"`
|
||||
ACLFile string `json:"acl" desc:"Access control list"`
|
||||
CertFile string `json:"cert" desc:"TLS certificate file"`
|
||||
KeyFile string `json:"key" desc:"TLS key file"`
|
||||
AuthFile string `json:"auth" desc:"Authentication file"`
|
||||
UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"`
|
||||
DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"`
|
||||
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
|
||||
ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"`
|
||||
MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"`
|
||||
Obfs string `json:"obfs" desc:"Obfuscation key"`
|
||||
}
|
||||
|
||||
func (c *proxyServerConfig) Check() error {
|
||||
if len(c.ListenAddr) == 0 {
|
||||
return errors.New("no listen address")
|
||||
}
|
||||
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
|
||||
return errors.New("TLS cert or key not provided")
|
||||
}
|
||||
if c.UpMbps < 0 || c.DownMbps < 0 {
|
||||
return errors.New("invalid speed")
|
||||
}
|
||||
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
|
||||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
|
||||
return errors.New("invalid receive window size")
|
||||
}
|
||||
if c.MaxConnClient < 0 {
|
||||
return errors.New("invalid max connections per client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *proxyServerConfig) String() string {
|
||||
return fmt.Sprintf("%+v", *c)
|
||||
}
|
@@ -1,298 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
|
||||
"github.com/tobyxdd/hysteria/pkg/core"
|
||||
"github.com/tobyxdd/hysteria/pkg/obfs"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func proxyServer(args []string) {
|
||||
var config proxyServerConfig
|
||||
err := loadConfig(&config, args)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Unable to load configuration")
|
||||
}
|
||||
if err := config.Check(); err != nil {
|
||||
logrus.WithField("error", err).Fatal("Configuration error")
|
||||
}
|
||||
logrus.WithField("config", config.String()).Info("Configuration loaded")
|
||||
// Load cert
|
||||
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"cert": config.CertFile,
|
||||
"key": config.KeyFile,
|
||||
}).Fatal("Unable to load the certificate")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
NextProtos: []string{proxyTLSProtocol},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
|
||||
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
|
||||
MaxIncomingStreams: int64(config.MaxConnClient),
|
||||
KeepAlive: true,
|
||||
}
|
||||
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxIncomingStreams == 0 {
|
||||
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
|
||||
}
|
||||
|
||||
if len(config.AuthFile) == 0 {
|
||||
logrus.Warn("No authentication configured, this server can be used by anyone")
|
||||
}
|
||||
|
||||
var obfuscator core.Obfuscator
|
||||
if len(config.Obfs) > 0 {
|
||||
obfuscator = obfs.XORObfuscator(config.Obfs)
|
||||
}
|
||||
|
||||
var aclEngine *acl.Engine
|
||||
if len(config.ACLFile) > 0 {
|
||||
aclEngine, err = acl.LoadFromFile(config.ACLFile)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"file": config.ACLFile,
|
||||
}).Fatal("Unable to parse ACL")
|
||||
}
|
||||
aclEngine.DefaultAction = acl.ActionDirect
|
||||
}
|
||||
|
||||
server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig,
|
||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||
func(refBPS uint64) congestion.CongestionControl {
|
||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||
},
|
||||
obfuscator,
|
||||
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
|
||||
if len(config.AuthFile) == 0 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
"up": sSend / mbpsToBps,
|
||||
"down": sRecv / mbpsToBps,
|
||||
}).Info("Client connected")
|
||||
return core.AuthResultSuccess, ""
|
||||
} else {
|
||||
// Need auth
|
||||
ok, err := checkAuth(config.AuthFile, username, password)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err.Error(),
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
}).Error("Client authentication error")
|
||||
return core.AuthResultInternalError, "Server auth error"
|
||||
}
|
||||
if ok {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
"up": sSend / mbpsToBps,
|
||||
"down": sRecv / mbpsToBps,
|
||||
}).Info("Client authenticated")
|
||||
return core.AuthResultSuccess, ""
|
||||
} else {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
"up": sSend / mbpsToBps,
|
||||
"down": sRecv / mbpsToBps,
|
||||
}).Info("Client rejected due to invalid credential")
|
||||
return core.AuthResultInvalidCred, "Invalid credential"
|
||||
}
|
||||
}
|
||||
},
|
||||
func(addr net.Addr, username string, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err.Error(),
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
}).Info("Client disconnected")
|
||||
},
|
||||
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
|
||||
packet := reqType == core.ConnectionTypePacket
|
||||
if packet && config.DisableUDP {
|
||||
return core.ConnectResultBlocked, "UDP disabled", nil
|
||||
}
|
||||
host, port, err := net.SplitHostPort(reqAddr)
|
||||
if err != nil {
|
||||
return core.ConnectResultFailed, err.Error(), nil
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
// IP request, clear host for ACL engine
|
||||
host = ""
|
||||
}
|
||||
action, arg := acl.ActionDirect, ""
|
||||
if aclEngine != nil {
|
||||
action, arg = aclEngine.Lookup(host, ip)
|
||||
}
|
||||
switch action {
|
||||
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||
if !packet {
|
||||
// TCP
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": "direct",
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("New TCP request")
|
||||
conn, err := net.DialTimeout("tcp", reqAddr, dialTimeout)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"dst": reqAddr,
|
||||
}).Error("TCP error")
|
||||
return core.ConnectResultFailed, err.Error(), nil
|
||||
}
|
||||
return core.ConnectResultSuccess, "", conn
|
||||
} else {
|
||||
// UDP
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": "direct",
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("New UDP request")
|
||||
conn, err := net.Dial("udp", reqAddr)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"dst": reqAddr,
|
||||
}).Error("UDP error")
|
||||
return core.ConnectResultFailed, err.Error(), nil
|
||||
}
|
||||
return core.ConnectResultSuccess, "", conn
|
||||
}
|
||||
case acl.ActionBlock:
|
||||
if !packet {
|
||||
// TCP
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": "block",
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("New TCP request")
|
||||
return core.ConnectResultBlocked, "blocked by ACL", nil
|
||||
} else {
|
||||
// UDP
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": "block",
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("New UDP request")
|
||||
return core.ConnectResultBlocked, "blocked by ACL", nil
|
||||
}
|
||||
case acl.ActionHijack:
|
||||
hijackAddr := net.JoinHostPort(arg, port)
|
||||
if !packet {
|
||||
// TCP
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": "hijack",
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
"rdst": arg,
|
||||
}).Debug("New TCP request")
|
||||
conn, err := net.DialTimeout("tcp", hijackAddr, dialTimeout)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"dst": hijackAddr,
|
||||
}).Error("TCP error")
|
||||
return core.ConnectResultFailed, err.Error(), nil
|
||||
}
|
||||
return core.ConnectResultSuccess, "", conn
|
||||
} else {
|
||||
// UDP
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"action": "hijack",
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
"rdst": arg,
|
||||
}).Debug("New UDP request")
|
||||
conn, err := net.Dial("udp", hijackAddr)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"dst": hijackAddr,
|
||||
}).Error("UDP error")
|
||||
return core.ConnectResultFailed, err.Error(), nil
|
||||
}
|
||||
return core.ConnectResultSuccess, "", conn
|
||||
}
|
||||
default:
|
||||
return core.ConnectResultFailed, "server ACL error", nil
|
||||
}
|
||||
},
|
||||
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
|
||||
packet := reqType == core.ConnectionTypePacket
|
||||
if !packet {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("TCP request closed")
|
||||
} else {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("UDP request closed")
|
||||
}
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Server initialization failed")
|
||||
}
|
||||
defer server.Close()
|
||||
logrus.WithField("addr", config.ListenAddr).Info("Server up and running")
|
||||
|
||||
err = server.Serve()
|
||||
logrus.WithField("error", err).Fatal("Server shutdown")
|
||||
}
|
||||
|
||||
func checkAuth(authFile, username, password string) (bool, error) {
|
||||
f, err := os.Open(authFile)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
pair := strings.Fields(scanner.Text())
|
||||
if len(pair) != 2 {
|
||||
// Invalid format
|
||||
continue
|
||||
}
|
||||
if username == pair[0] && password == pair[1] {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
@@ -1,119 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/sirupsen/logrus"
|
||||
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
|
||||
"github.com/tobyxdd/hysteria/pkg/core"
|
||||
"github.com/tobyxdd/hysteria/pkg/obfs"
|
||||
"github.com/tobyxdd/hysteria/pkg/utils"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os/user"
|
||||
)
|
||||
|
||||
func relayClient(args []string) {
|
||||
var config relayClientConfig
|
||||
err := loadConfig(&config, args)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Unable to load configuration")
|
||||
}
|
||||
if err := config.Check(); err != nil {
|
||||
logrus.WithField("error", err).Fatal("Configuration error")
|
||||
}
|
||||
if len(config.Name) == 0 {
|
||||
usr, err := user.Current()
|
||||
if err == nil {
|
||||
config.Name = usr.Name
|
||||
}
|
||||
}
|
||||
logrus.WithField("config", config.String()).Info("Configuration loaded")
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: config.Insecure,
|
||||
NextProtos: []string{relayTLSProtocol},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
// Load CA
|
||||
if len(config.CustomCAFile) > 0 {
|
||||
bs, err := ioutil.ReadFile(config.CustomCAFile)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"file": config.CustomCAFile,
|
||||
}).Fatal("Unable to load CA file")
|
||||
}
|
||||
cp := x509.NewCertPool()
|
||||
if !cp.AppendCertsFromPEM(bs) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"file": config.CustomCAFile,
|
||||
}).Fatal("Unable to parse CA file")
|
||||
}
|
||||
tlsConfig.RootCAs = cp
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
|
||||
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow,
|
||||
KeepAlive: true,
|
||||
}
|
||||
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
|
||||
var obfuscator core.Obfuscator
|
||||
if len(config.Obfs) > 0 {
|
||||
obfuscator = obfs.XORObfuscator(config.Obfs)
|
||||
}
|
||||
|
||||
client, err := core.NewClient(config.ServerAddr, config.Name, "", tlsConfig, quicConfig,
|
||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||
func(refBPS uint64) congestion.CongestionControl {
|
||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||
}, obfuscator)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Client initialization failed")
|
||||
}
|
||||
defer client.Close()
|
||||
logrus.WithField("addr", config.ServerAddr).Info("Connected")
|
||||
|
||||
listener, err := net.Listen("tcp", config.ListenAddr)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("TCP listen failed")
|
||||
}
|
||||
defer listener.Close()
|
||||
logrus.WithField("addr", listener.Addr().String()).Info("TCP server listening")
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("TCP accept failed")
|
||||
}
|
||||
go relayClientHandleConn(conn, client)
|
||||
}
|
||||
}
|
||||
|
||||
func relayClientHandleConn(conn net.Conn, client *core.Client) {
|
||||
logrus.WithField("src", conn.RemoteAddr().String()).Debug("New connection")
|
||||
var closeErr error
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": closeErr,
|
||||
"src": conn.RemoteAddr().String(),
|
||||
}).Debug("Connection closed")
|
||||
}()
|
||||
rwc, err := client.Dial(false, "")
|
||||
if err != nil {
|
||||
closeErr = err
|
||||
return
|
||||
}
|
||||
defer rwc.Close()
|
||||
closeErr = utils.PipePair(conn, rwc, nil, nil)
|
||||
}
|
@@ -1,82 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const relayTLSProtocol = "hysteria-relay"
|
||||
|
||||
type relayClientConfig struct {
|
||||
ListenAddr string `json:"listen" desc:"TCP listen address"`
|
||||
ServerAddr string `json:"server" desc:"Server address"`
|
||||
Name string `json:"name" desc:"Client name presented to the server"`
|
||||
Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"`
|
||||
CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"`
|
||||
UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"`
|
||||
DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"`
|
||||
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
|
||||
ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"`
|
||||
Obfs string `json:"obfs" desc:"Obfuscation key"`
|
||||
}
|
||||
|
||||
func (c *relayClientConfig) Check() error {
|
||||
if len(c.ListenAddr) == 0 {
|
||||
return errors.New("no listen address")
|
||||
}
|
||||
if len(c.ServerAddr) == 0 {
|
||||
return errors.New("no server address")
|
||||
}
|
||||
if c.UpMbps <= 0 || c.DownMbps <= 0 {
|
||||
return errors.New("invalid speed")
|
||||
}
|
||||
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
|
||||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
|
||||
return errors.New("invalid receive window size")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *relayClientConfig) String() string {
|
||||
return fmt.Sprintf("%+v", *c)
|
||||
}
|
||||
|
||||
type relayServerConfig struct {
|
||||
ListenAddr string `json:"listen" desc:"Server listen address"`
|
||||
RemoteAddr string `json:"remote" desc:"Remote relay address"`
|
||||
CertFile string `json:"cert" desc:"TLS certificate file"`
|
||||
KeyFile string `json:"key" desc:"TLS key file"`
|
||||
UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"`
|
||||
DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"`
|
||||
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
|
||||
ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"`
|
||||
MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"`
|
||||
Obfs string `json:"obfs" desc:"Obfuscation key"`
|
||||
}
|
||||
|
||||
func (c *relayServerConfig) Check() error {
|
||||
if len(c.ListenAddr) == 0 {
|
||||
return errors.New("no listen address")
|
||||
}
|
||||
if len(c.RemoteAddr) == 0 {
|
||||
return errors.New("no remote address")
|
||||
}
|
||||
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
|
||||
return errors.New("TLS cert or key not provided")
|
||||
}
|
||||
if c.UpMbps < 0 || c.DownMbps < 0 {
|
||||
return errors.New("invalid speed")
|
||||
}
|
||||
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
|
||||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
|
||||
return errors.New("invalid receive window size")
|
||||
}
|
||||
if c.MaxConnClient < 0 {
|
||||
return errors.New("invalid max connections per client")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *relayServerConfig) String() string {
|
||||
return fmt.Sprintf("%+v", *c)
|
||||
}
|
@@ -1,121 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/sirupsen/logrus"
|
||||
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
|
||||
"github.com/tobyxdd/hysteria/pkg/core"
|
||||
"github.com/tobyxdd/hysteria/pkg/obfs"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
func relayServer(args []string) {
|
||||
var config relayServerConfig
|
||||
err := loadConfig(&config, args)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Unable to load configuration")
|
||||
}
|
||||
if err := config.Check(); err != nil {
|
||||
logrus.WithField("error", err).Fatal("Configuration error")
|
||||
}
|
||||
logrus.WithField("config", config.String()).Info("Configuration loaded")
|
||||
// Load cert
|
||||
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"cert": config.CertFile,
|
||||
"key": config.KeyFile,
|
||||
}).Fatal("Unable to load the certificate")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
NextProtos: []string{relayTLSProtocol},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
quicConfig := &quic.Config{
|
||||
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
|
||||
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
|
||||
MaxIncomingStreams: int64(config.MaxConnClient),
|
||||
KeepAlive: true,
|
||||
}
|
||||
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxIncomingStreams == 0 {
|
||||
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
|
||||
}
|
||||
|
||||
var obfuscator core.Obfuscator
|
||||
if len(config.Obfs) > 0 {
|
||||
obfuscator = obfs.XORObfuscator(config.Obfs)
|
||||
}
|
||||
|
||||
server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig,
|
||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||
func(refBPS uint64) congestion.CongestionControl {
|
||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||
},
|
||||
obfuscator,
|
||||
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
|
||||
// No authentication logic in relay, just log username and speed
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
"up": sSend / mbpsToBps,
|
||||
"down": sRecv / mbpsToBps,
|
||||
}).Info("Client connected")
|
||||
return core.AuthResultSuccess, ""
|
||||
},
|
||||
func(addr net.Addr, username string, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err.Error(),
|
||||
"addr": addr.String(),
|
||||
"username": username,
|
||||
}).Info("Client disconnected")
|
||||
},
|
||||
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
|
||||
packet := reqType == core.ConnectionTypePacket
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"id": id,
|
||||
}).Debug("New stream")
|
||||
if packet {
|
||||
return core.ConnectResultBlocked, "unsupported", nil
|
||||
}
|
||||
conn, err := net.DialTimeout("tcp", config.RemoteAddr, dialTimeout)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"dst": config.RemoteAddr,
|
||||
}).Error("TCP error")
|
||||
return core.ConnectResultFailed, err.Error(), nil
|
||||
}
|
||||
return core.ConnectResultSuccess, "", conn
|
||||
},
|
||||
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"username": username,
|
||||
"src": addr.String(),
|
||||
"id": id,
|
||||
}).Debug("Stream closed")
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Server initialization failed")
|
||||
}
|
||||
defer server.Close()
|
||||
logrus.WithField("addr", config.ListenAddr).Info("Server up and running")
|
||||
|
||||
err = server.Serve()
|
||||
logrus.WithField("error", err).Fatal("Server shutdown")
|
||||
}
|
99
cmd/server.go
Normal file
99
cmd/server.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
|
||||
"github.com/tobyxdd/hysteria/pkg/core"
|
||||
"github.com/tobyxdd/hysteria/pkg/obfs"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func server(config *serverConfig) {
|
||||
logrus.WithField("config", config.String()).Info("Server configuration loaded")
|
||||
// Load cert
|
||||
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"cert": config.CertFile,
|
||||
"key": config.KeyFile,
|
||||
}).Fatal("Failed to load the certificate")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
NextProtos: []string{tlsProtocolName},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
// QUIC config
|
||||
quicConfig := &quic.Config{
|
||||
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
|
||||
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
|
||||
MaxIncomingStreams: int64(config.MaxConnClient),
|
||||
KeepAlive: true,
|
||||
}
|
||||
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
|
||||
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
|
||||
}
|
||||
if quicConfig.MaxIncomingStreams == 0 {
|
||||
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
|
||||
}
|
||||
// Auth
|
||||
var authFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
|
||||
if len(config.Auth.Mode) == 0 || strings.EqualFold(config.Auth.Mode, "none") {
|
||||
logrus.Warn("No authentication configured")
|
||||
authFunc = func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
|
||||
return true, "Welcome"
|
||||
}
|
||||
} else {
|
||||
// TODO
|
||||
logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode")
|
||||
}
|
||||
// Obfuscator
|
||||
var obfuscator core.Obfuscator
|
||||
if len(config.Obfs) > 0 {
|
||||
obfuscator = obfs.XORObfuscator(config.Obfs)
|
||||
}
|
||||
// ACL
|
||||
var aclEngine *acl.Engine
|
||||
if len(config.ACL) > 0 {
|
||||
aclEngine, err = acl.LoadFromFile(config.ACL)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"error": err,
|
||||
"file": config.ACL,
|
||||
}).Fatal("Failed to parse ACL")
|
||||
}
|
||||
aclEngine.DefaultAction = acl.ActionDirect
|
||||
}
|
||||
// Server
|
||||
server, err := core.NewServer(config.Listen, tlsConfig, quicConfig,
|
||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||
func(refBPS uint64) congestion.CongestionControl {
|
||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||
}, aclEngine, obfuscator, authFunc, func(addr net.Addr, auth []byte, udp bool, reqAddr string) {
|
||||
if !udp {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": addr.String(),
|
||||
"dst": reqAddr,
|
||||
}).Debug("New TCP request")
|
||||
} else {
|
||||
// TODO
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Failed to initialize server")
|
||||
}
|
||||
defer server.Close()
|
||||
logrus.WithField("addr", config.Listen).Info("Server up and running")
|
||||
|
||||
err = server.Serve()
|
||||
logrus.WithField("error", err).Fatal("Server shutdown")
|
||||
}
|
28
cmd/utils.go
28
cmd/utils.go
@@ -1,28 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type optionalBoolFlag struct {
|
||||
Exists bool
|
||||
Value bool
|
||||
}
|
||||
|
||||
func (flag *optionalBoolFlag) String() string {
|
||||
return strconv.FormatBool(flag.Value)
|
||||
}
|
||||
|
||||
func (flag *optionalBoolFlag) Set(s string) error {
|
||||
v, err := strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
flag.Exists = true
|
||||
flag.Value = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (flag *optionalBoolFlag) IsBoolFlag() bool {
|
||||
return true
|
||||
}
|
2
go.mod
2
go.mod
@@ -8,11 +8,13 @@ require (
|
||||
github.com/golang/protobuf v1.4.2
|
||||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/lucas-clemente/quic-go v0.19.3
|
||||
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/sirupsen/logrus v1.6.0
|
||||
github.com/txthinking/runnergroup v0.0.0-20200327135940-540a793bb997 // indirect
|
||||
github.com/txthinking/socks5 v0.0.0-20200327133705-caf148ab5e9d
|
||||
github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 // indirect
|
||||
github.com/yosuke-furukawa/json5 v0.1.1
|
||||
)
|
||||
|
||||
replace github.com/lucas-clemente/quic-go => github.com/tobyxdd/quic-go v0.19.4-0.20210127052624-0ecb862c82b5
|
||||
|
4
go.sum
4
go.sum
@@ -87,6 +87,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc=
|
||||
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg=
|
||||
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
|
||||
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc=
|
||||
@@ -166,6 +168,8 @@ github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 h1:ngJOce33YJJT1PFTfC
|
||||
github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9/go.mod h1:WgqbSEmUYSjEV3B1qmee/PpP2NYEz4bL9/+mF1ma+s4=
|
||||
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||
github.com/yosuke-furukawa/json5 v0.1.1 h1:0F9mNwTvOuDNH243hoPqvf+dxa5QsKnZzU20uNsh3ZI=
|
||||
github.com/yosuke-furukawa/json5 v0.1.1/go.mod h1:sw49aWDqNdRJ6DYUtIQiaA3xyj2IL9tjeNYmX2ixwcU=
|
||||
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
|
||||
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
|
||||
go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE=
|
||||
|
@@ -6,44 +6,45 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/tobyxdd/hysteria/pkg/core/pb"
|
||||
"github.com/tobyxdd/hysteria/pkg/utils"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"github.com/lunixbochs/struc"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("client closed")
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
inboundBytes, outboundBytes uint64 // atomic
|
||||
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
|
||||
|
||||
reconnectMutex sync.Mutex
|
||||
closed bool
|
||||
quicSession quic.Session
|
||||
serverAddr string
|
||||
username, password string
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
sendBPS, recvBPS uint64
|
||||
congestionFactory CongestionFactory
|
||||
obfuscator Obfuscator
|
||||
type Client struct {
|
||||
serverAddr string
|
||||
sendBPS, recvBPS uint64
|
||||
auth []byte
|
||||
congestionFactory CongestionFactory
|
||||
obfuscator Obfuscator
|
||||
|
||||
tlsConfig *tls.Config
|
||||
quicConfig *quic.Config
|
||||
|
||||
quicSession quic.Session
|
||||
reconnectMutex sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewClient(serverAddr string, username string, password string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) {
|
||||
c := &Client{
|
||||
serverAddr: serverAddr,
|
||||
username: username,
|
||||
password: password,
|
||||
tlsConfig: tlsConfig,
|
||||
quicConfig: quicConfig,
|
||||
sendBPS: sendBPS,
|
||||
recvBPS: recvBPS,
|
||||
auth: auth,
|
||||
congestionFactory: congestionFactory,
|
||||
obfuscator: obfuscator,
|
||||
tlsConfig: tlsConfig,
|
||||
quicConfig: quicConfig,
|
||||
}
|
||||
if err := c.connectToServer(); err != nil {
|
||||
return nil, err
|
||||
@@ -51,58 +52,6 @@ func NewClient(serverAddr string, username string, password string, tlsConfig *t
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Dial(packet bool, addr string) (net.Conn, error) {
|
||||
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send request
|
||||
req := &pb.ClientConnectRequest{Address: addr}
|
||||
if packet {
|
||||
req.Type = pb.ConnectionType_Packet
|
||||
} else {
|
||||
req.Type = pb.ConnectionType_Stream
|
||||
}
|
||||
err = writeClientConnectRequest(stream, req)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
resp, err := readServerConnectResponse(stream)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if resp.Result != pb.ConnectResult_CONN_SUCCESS {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("server rejected the connection %s (msg: %s)",
|
||||
resp.Result.String(), resp.Message)
|
||||
}
|
||||
connWrap := &utils.QUICStreamWrapperConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: localAddr,
|
||||
PseudoRemoteAddr: remoteAddr,
|
||||
}
|
||||
if packet {
|
||||
return &utils.PacketWrapperConn{Orig: connWrap}, nil
|
||||
} else {
|
||||
return connWrap, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Stats() (uint64, uint64) {
|
||||
return atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes)
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic")
|
||||
c.closed = true
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) connectToServer() error {
|
||||
serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||
if err != nil {
|
||||
@@ -124,51 +73,50 @@ func (c *Client) connectToServer() error {
|
||||
return err
|
||||
}
|
||||
// Control stream
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
|
||||
ctlStream, err := qs.OpenStreamSync(ctx)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
|
||||
stream, err := qs.OpenStreamSync(ctx)
|
||||
ctxCancel()
|
||||
if err != nil {
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
|
||||
return err
|
||||
}
|
||||
result, msg, err := c.handleControlStream(qs, ctlStream)
|
||||
ok, msg, err := c.handleControlStream(qs, stream)
|
||||
if err != nil {
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
|
||||
return err
|
||||
}
|
||||
if result != pb.AuthResult_AUTH_SUCCESS {
|
||||
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure")
|
||||
return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg)
|
||||
if !ok {
|
||||
_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
|
||||
return fmt.Errorf("auth error: %s", msg)
|
||||
}
|
||||
// All good
|
||||
c.quicSession = qs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (pb.AuthResult, string, error) {
|
||||
err := writeClientAuthRequest(stream, &pb.ClientAuthRequest{
|
||||
Credential: &pb.Credential{
|
||||
Username: c.username,
|
||||
Password: c.password,
|
||||
},
|
||||
Speed: &pb.Speed{
|
||||
SendBps: c.sendBPS,
|
||||
ReceiveBps: c.recvBPS,
|
||||
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, string, error) {
|
||||
// Send client hello
|
||||
err := struc.Pack(stream, &clientHello{
|
||||
Rate: transmissionRate{
|
||||
SendBPS: c.sendBPS,
|
||||
RecvBPS: c.recvBPS,
|
||||
},
|
||||
Auth: c.auth,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
return false, "", err
|
||||
}
|
||||
// Response
|
||||
resp, err := readServerAuthResponse(stream)
|
||||
// Receive server hello
|
||||
var sh serverHello
|
||||
err = struc.Unpack(stream, &sh)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
return false, "", err
|
||||
}
|
||||
// Set the congestion accordingly
|
||||
if resp.Result == pb.AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
|
||||
qs.SetCongestionControl(c.congestionFactory(resp.Speed.ReceiveBps))
|
||||
if sh.OK && c.congestionFactory != nil {
|
||||
qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
|
||||
}
|
||||
return resp.Result, resp.Message, nil
|
||||
return true, sh.Message, nil
|
||||
}
|
||||
|
||||
func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) {
|
||||
@@ -196,3 +144,81 @@ func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, err
|
||||
stream, err = c.quicSession.OpenStream()
|
||||
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err
|
||||
}
|
||||
|
||||
func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
||||
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send request
|
||||
err = struc.Pack(stream, &clientRequest{
|
||||
UDP: false,
|
||||
Address: addr,
|
||||
})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
var sr serverResponse
|
||||
err = struc.Unpack(stream, &sr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !sr.OK {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
|
||||
}
|
||||
return &quicConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: localAddr,
|
||||
PseudoRemoteAddr: remoteAddr,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
|
||||
c.closed = true
|
||||
return err
|
||||
}
|
||||
|
||||
type quicConn struct {
|
||||
Orig quic.Stream
|
||||
PseudoLocalAddr net.Addr
|
||||
PseudoRemoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *quicConn) Read(b []byte) (n int, err error) {
|
||||
return w.Orig.Read(b)
|
||||
}
|
||||
|
||||
func (w *quicConn) Write(b []byte) (n int, err error) {
|
||||
return w.Orig.Write(b)
|
||||
}
|
||||
|
||||
func (w *quicConn) Close() error {
|
||||
return w.Orig.Close()
|
||||
}
|
||||
|
||||
func (w *quicConn) LocalAddr() net.Addr {
|
||||
return w.PseudoLocalAddr
|
||||
}
|
||||
|
||||
func (w *quicConn) RemoteAddr() net.Addr {
|
||||
return w.PseudoRemoteAddr
|
||||
}
|
||||
|
||||
func (w *quicConn) SetDeadline(t time.Time) error {
|
||||
return w.Orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (w *quicConn) SetReadDeadline(t time.Time) error {
|
||||
return w.Orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (w *quicConn) SetWriteDeadline(t time.Time) error {
|
||||
return w.Orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
@@ -1,104 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/tobyxdd/hysteria/pkg/core/pb"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
closeErrorCodeGeneric = 0
|
||||
closeErrorCodeProtocolFailure = 1
|
||||
)
|
||||
|
||||
func readDataBlock(r io.Reader) ([]byte, error) {
|
||||
var sz uint32
|
||||
if err := binary.Read(r, controlProtocolEndian, &sz); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf := make([]byte, sz)
|
||||
_, err := io.ReadFull(r, buf)
|
||||
return buf, err
|
||||
}
|
||||
|
||||
func writeDataBlock(w io.Writer, data []byte) error {
|
||||
sz := uint32(len(data))
|
||||
if err := binary.Write(w, controlProtocolEndian, &sz); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func readClientAuthRequest(r io.Reader) (*pb.ClientAuthRequest, error) {
|
||||
bs, err := readDataBlock(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var req pb.ClientAuthRequest
|
||||
err = proto.Unmarshal(bs, &req)
|
||||
return &req, err
|
||||
}
|
||||
|
||||
func writeClientAuthRequest(w io.Writer, req *pb.ClientAuthRequest) error {
|
||||
bs, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeDataBlock(w, bs)
|
||||
}
|
||||
|
||||
func readServerAuthResponse(r io.Reader) (*pb.ServerAuthResponse, error) {
|
||||
bs, err := readDataBlock(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp pb.ServerAuthResponse
|
||||
err = proto.Unmarshal(bs, &resp)
|
||||
return &resp, err
|
||||
}
|
||||
|
||||
func writeServerAuthResponse(w io.Writer, resp *pb.ServerAuthResponse) error {
|
||||
bs, err := proto.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeDataBlock(w, bs)
|
||||
}
|
||||
|
||||
func readClientConnectRequest(r io.Reader) (*pb.ClientConnectRequest, error) {
|
||||
bs, err := readDataBlock(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var req pb.ClientConnectRequest
|
||||
err = proto.Unmarshal(bs, &req)
|
||||
return &req, err
|
||||
}
|
||||
|
||||
func writeClientConnectRequest(w io.Writer, req *pb.ClientConnectRequest) error {
|
||||
bs, err := proto.Marshal(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeDataBlock(w, bs)
|
||||
}
|
||||
|
||||
func readServerConnectResponse(r io.Reader) (*pb.ServerConnectResponse, error) {
|
||||
bs, err := readDataBlock(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp pb.ServerConnectResponse
|
||||
err = proto.Unmarshal(bs, &resp)
|
||||
return &resp, err
|
||||
}
|
||||
|
||||
func writeServerConnectResponse(w io.Writer, resp *pb.ServerConnectResponse) error {
|
||||
bs, err := proto.Marshal(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeDataBlock(w, bs)
|
||||
}
|
@@ -1,440 +0,0 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// source: control.proto
|
||||
|
||||
package pb
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
math "math"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
var _ = proto.Marshal
|
||||
var _ = fmt.Errorf
|
||||
var _ = math.Inf
|
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the proto package it is being compiled against.
|
||||
// A compilation error at this line likely means your copy of the
|
||||
// proto package needs to be updated.
|
||||
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
|
||||
|
||||
type AuthResult int32
|
||||
|
||||
const (
|
||||
AuthResult_AUTH_SUCCESS AuthResult = 0
|
||||
AuthResult_AUTH_INVALID_CRED AuthResult = 1
|
||||
AuthResult_AUTH_INTERNAL_ERROR AuthResult = 2
|
||||
)
|
||||
|
||||
var AuthResult_name = map[int32]string{
|
||||
0: "AUTH_SUCCESS",
|
||||
1: "AUTH_INVALID_CRED",
|
||||
2: "AUTH_INTERNAL_ERROR",
|
||||
}
|
||||
|
||||
var AuthResult_value = map[string]int32{
|
||||
"AUTH_SUCCESS": 0,
|
||||
"AUTH_INVALID_CRED": 1,
|
||||
"AUTH_INTERNAL_ERROR": 2,
|
||||
}
|
||||
|
||||
func (x AuthResult) String() string {
|
||||
return proto.EnumName(AuthResult_name, int32(x))
|
||||
}
|
||||
|
||||
func (AuthResult) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{0}
|
||||
}
|
||||
|
||||
type ConnectionType int32
|
||||
|
||||
const (
|
||||
ConnectionType_Stream ConnectionType = 0
|
||||
ConnectionType_Packet ConnectionType = 1
|
||||
)
|
||||
|
||||
var ConnectionType_name = map[int32]string{
|
||||
0: "Stream",
|
||||
1: "Packet",
|
||||
}
|
||||
|
||||
var ConnectionType_value = map[string]int32{
|
||||
"Stream": 0,
|
||||
"Packet": 1,
|
||||
}
|
||||
|
||||
func (x ConnectionType) String() string {
|
||||
return proto.EnumName(ConnectionType_name, int32(x))
|
||||
}
|
||||
|
||||
func (ConnectionType) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{1}
|
||||
}
|
||||
|
||||
type ConnectResult int32
|
||||
|
||||
const (
|
||||
ConnectResult_CONN_SUCCESS ConnectResult = 0
|
||||
ConnectResult_CONN_FAILED ConnectResult = 1
|
||||
ConnectResult_CONN_BLOCKED ConnectResult = 2
|
||||
)
|
||||
|
||||
var ConnectResult_name = map[int32]string{
|
||||
0: "CONN_SUCCESS",
|
||||
1: "CONN_FAILED",
|
||||
2: "CONN_BLOCKED",
|
||||
}
|
||||
|
||||
var ConnectResult_value = map[string]int32{
|
||||
"CONN_SUCCESS": 0,
|
||||
"CONN_FAILED": 1,
|
||||
"CONN_BLOCKED": 2,
|
||||
}
|
||||
|
||||
func (x ConnectResult) String() string {
|
||||
return proto.EnumName(ConnectResult_name, int32(x))
|
||||
}
|
||||
|
||||
func (ConnectResult) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{2}
|
||||
}
|
||||
|
||||
type Speed struct {
|
||||
SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"`
|
||||
ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Speed) Reset() { *m = Speed{} }
|
||||
func (m *Speed) String() string { return proto.CompactTextString(m) }
|
||||
func (*Speed) ProtoMessage() {}
|
||||
func (*Speed) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{0}
|
||||
}
|
||||
|
||||
func (m *Speed) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Speed.Unmarshal(m, b)
|
||||
}
|
||||
func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_Speed.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *Speed) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Speed.Merge(m, src)
|
||||
}
|
||||
func (m *Speed) XXX_Size() int {
|
||||
return xxx_messageInfo_Speed.Size(m)
|
||||
}
|
||||
func (m *Speed) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Speed.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Speed proto.InternalMessageInfo
|
||||
|
||||
func (m *Speed) GetSendBps() uint64 {
|
||||
if m != nil {
|
||||
return m.SendBps
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (m *Speed) GetReceiveBps() uint64 {
|
||||
if m != nil {
|
||||
return m.ReceiveBps
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type Credential struct {
|
||||
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
|
||||
Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Credential) Reset() { *m = Credential{} }
|
||||
func (m *Credential) String() string { return proto.CompactTextString(m) }
|
||||
func (*Credential) ProtoMessage() {}
|
||||
func (*Credential) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{1}
|
||||
}
|
||||
|
||||
func (m *Credential) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_Credential.Unmarshal(m, b)
|
||||
}
|
||||
func (m *Credential) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_Credential.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *Credential) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Credential.Merge(m, src)
|
||||
}
|
||||
func (m *Credential) XXX_Size() int {
|
||||
return xxx_messageInfo_Credential.Size(m)
|
||||
}
|
||||
func (m *Credential) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Credential.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Credential proto.InternalMessageInfo
|
||||
|
||||
func (m *Credential) GetUsername() string {
|
||||
if m != nil {
|
||||
return m.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *Credential) GetPassword() string {
|
||||
if m != nil {
|
||||
return m.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ClientAuthRequest struct {
|
||||
Credential *Credential `protobuf:"bytes,1,opt,name=credential,proto3" json:"credential,omitempty"`
|
||||
Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *ClientAuthRequest) Reset() { *m = ClientAuthRequest{} }
|
||||
func (m *ClientAuthRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*ClientAuthRequest) ProtoMessage() {}
|
||||
func (*ClientAuthRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{2}
|
||||
}
|
||||
|
||||
func (m *ClientAuthRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ClientAuthRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *ClientAuthRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_ClientAuthRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *ClientAuthRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_ClientAuthRequest.Merge(m, src)
|
||||
}
|
||||
func (m *ClientAuthRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_ClientAuthRequest.Size(m)
|
||||
}
|
||||
func (m *ClientAuthRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_ClientAuthRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_ClientAuthRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *ClientAuthRequest) GetCredential() *Credential {
|
||||
if m != nil {
|
||||
return m.Credential
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *ClientAuthRequest) GetSpeed() *Speed {
|
||||
if m != nil {
|
||||
return m.Speed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ServerAuthResponse struct {
|
||||
Result AuthResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.AuthResult" json:"result,omitempty"`
|
||||
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
|
||||
Speed *Speed `protobuf:"bytes,3,opt,name=speed,proto3" json:"speed,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *ServerAuthResponse) Reset() { *m = ServerAuthResponse{} }
|
||||
func (m *ServerAuthResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*ServerAuthResponse) ProtoMessage() {}
|
||||
func (*ServerAuthResponse) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{3}
|
||||
}
|
||||
|
||||
func (m *ServerAuthResponse) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ServerAuthResponse.Unmarshal(m, b)
|
||||
}
|
||||
func (m *ServerAuthResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_ServerAuthResponse.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *ServerAuthResponse) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_ServerAuthResponse.Merge(m, src)
|
||||
}
|
||||
func (m *ServerAuthResponse) XXX_Size() int {
|
||||
return xxx_messageInfo_ServerAuthResponse.Size(m)
|
||||
}
|
||||
func (m *ServerAuthResponse) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_ServerAuthResponse.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_ServerAuthResponse proto.InternalMessageInfo
|
||||
|
||||
func (m *ServerAuthResponse) GetResult() AuthResult {
|
||||
if m != nil {
|
||||
return m.Result
|
||||
}
|
||||
return AuthResult_AUTH_SUCCESS
|
||||
}
|
||||
|
||||
func (m *ServerAuthResponse) GetMessage() string {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *ServerAuthResponse) GetSpeed() *Speed {
|
||||
if m != nil {
|
||||
return m.Speed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientConnectRequest struct {
|
||||
Type ConnectionType `protobuf:"varint,1,opt,name=type,proto3,enum=core.ConnectionType" json:"type,omitempty"`
|
||||
Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *ClientConnectRequest) Reset() { *m = ClientConnectRequest{} }
|
||||
func (m *ClientConnectRequest) String() string { return proto.CompactTextString(m) }
|
||||
func (*ClientConnectRequest) ProtoMessage() {}
|
||||
func (*ClientConnectRequest) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{4}
|
||||
}
|
||||
|
||||
func (m *ClientConnectRequest) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ClientConnectRequest.Unmarshal(m, b)
|
||||
}
|
||||
func (m *ClientConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_ClientConnectRequest.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *ClientConnectRequest) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_ClientConnectRequest.Merge(m, src)
|
||||
}
|
||||
func (m *ClientConnectRequest) XXX_Size() int {
|
||||
return xxx_messageInfo_ClientConnectRequest.Size(m)
|
||||
}
|
||||
func (m *ClientConnectRequest) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_ClientConnectRequest.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_ClientConnectRequest proto.InternalMessageInfo
|
||||
|
||||
func (m *ClientConnectRequest) GetType() ConnectionType {
|
||||
if m != nil {
|
||||
return m.Type
|
||||
}
|
||||
return ConnectionType_Stream
|
||||
}
|
||||
|
||||
func (m *ClientConnectRequest) GetAddress() string {
|
||||
if m != nil {
|
||||
return m.Address
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type ServerConnectResponse struct {
|
||||
Result ConnectResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.ConnectResult" json:"result,omitempty"`
|
||||
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
|
||||
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
||||
XXX_unrecognized []byte `json:"-"`
|
||||
XXX_sizecache int32 `json:"-"`
|
||||
}
|
||||
|
||||
func (m *ServerConnectResponse) Reset() { *m = ServerConnectResponse{} }
|
||||
func (m *ServerConnectResponse) String() string { return proto.CompactTextString(m) }
|
||||
func (*ServerConnectResponse) ProtoMessage() {}
|
||||
func (*ServerConnectResponse) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_0c5120591600887d, []int{5}
|
||||
}
|
||||
|
||||
func (m *ServerConnectResponse) XXX_Unmarshal(b []byte) error {
|
||||
return xxx_messageInfo_ServerConnectResponse.Unmarshal(m, b)
|
||||
}
|
||||
func (m *ServerConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
return xxx_messageInfo_ServerConnectResponse.Marshal(b, m, deterministic)
|
||||
}
|
||||
func (m *ServerConnectResponse) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_ServerConnectResponse.Merge(m, src)
|
||||
}
|
||||
func (m *ServerConnectResponse) XXX_Size() int {
|
||||
return xxx_messageInfo_ServerConnectResponse.Size(m)
|
||||
}
|
||||
func (m *ServerConnectResponse) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_ServerConnectResponse.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_ServerConnectResponse proto.InternalMessageInfo
|
||||
|
||||
func (m *ServerConnectResponse) GetResult() ConnectResult {
|
||||
if m != nil {
|
||||
return m.Result
|
||||
}
|
||||
return ConnectResult_CONN_SUCCESS
|
||||
}
|
||||
|
||||
func (m *ServerConnectResponse) GetMessage() string {
|
||||
if m != nil {
|
||||
return m.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("core.AuthResult", AuthResult_name, AuthResult_value)
|
||||
proto.RegisterEnum("core.ConnectionType", ConnectionType_name, ConnectionType_value)
|
||||
proto.RegisterEnum("core.ConnectResult", ConnectResult_name, ConnectResult_value)
|
||||
proto.RegisterType((*Speed)(nil), "core.Speed")
|
||||
proto.RegisterType((*Credential)(nil), "core.Credential")
|
||||
proto.RegisterType((*ClientAuthRequest)(nil), "core.ClientAuthRequest")
|
||||
proto.RegisterType((*ServerAuthResponse)(nil), "core.ServerAuthResponse")
|
||||
proto.RegisterType((*ClientConnectRequest)(nil), "core.ClientConnectRequest")
|
||||
proto.RegisterType((*ServerConnectResponse)(nil), "core.ServerConnectResponse")
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d)
|
||||
}
|
||||
|
||||
var fileDescriptor_0c5120591600887d = []byte{
|
||||
// 434 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xd1, 0x6e, 0xd3, 0x30,
|
||||
0x14, 0x86, 0xd7, 0xd2, 0x75, 0xdb, 0x09, 0x1b, 0x99, 0xb7, 0x89, 0xc1, 0x0d, 0x90, 0xab, 0xaa,
|
||||
0x48, 0x15, 0x1a, 0x4f, 0x90, 0x3a, 0x41, 0x54, 0x54, 0x29, 0x72, 0x3a, 0x2e, 0xb8, 0xa0, 0xca,
|
||||
0x92, 0x23, 0x56, 0x91, 0xda, 0xc6, 0x76, 0x86, 0x26, 0x5e, 0x1e, 0xc5, 0x71, 0xd2, 0x16, 0x09,
|
||||
0x89, 0xbb, 0x9e, 0x73, 0x7e, 0xfd, 0x9f, 0xbf, 0x2a, 0x70, 0x9a, 0x0b, 0x6e, 0x94, 0x28, 0x27,
|
||||
0x52, 0x09, 0x23, 0xc8, 0x20, 0x17, 0x0a, 0x03, 0x0a, 0x87, 0xa9, 0x44, 0x2c, 0xc8, 0x0b, 0x38,
|
||||
0xd6, 0xc8, 0x8b, 0xd5, 0x9d, 0xd4, 0xd7, 0xbd, 0xd7, 0xbd, 0xd1, 0x80, 0x1d, 0xd5, 0xf3, 0x54,
|
||||
0x6a, 0xf2, 0x0a, 0x3c, 0x85, 0x39, 0xae, 0x1f, 0xd0, 0x5e, 0xfb, 0xf6, 0x0a, 0x6e, 0x35, 0x95,
|
||||
0x3a, 0x88, 0x00, 0xa8, 0xc2, 0x02, 0xb9, 0x59, 0x67, 0x25, 0x79, 0x09, 0xc7, 0x95, 0x46, 0xc5,
|
||||
0xb3, 0x0d, 0xda, 0xa6, 0x13, 0xd6, 0xcd, 0xf5, 0x4d, 0x66, 0x5a, 0xff, 0x12, 0xaa, 0xb0, 0x3d,
|
||||
0x27, 0xac, 0x9b, 0x83, 0x7b, 0x38, 0xa7, 0xe5, 0x1a, 0xb9, 0x09, 0x2b, 0x73, 0xcf, 0xf0, 0x67,
|
||||
0x85, 0xda, 0x90, 0x77, 0x00, 0x79, 0x57, 0x6d, 0xeb, 0xbc, 0x1b, 0x7f, 0x52, 0x3f, 0x7d, 0xb2,
|
||||
0x45, 0xb2, 0x9d, 0x0c, 0x79, 0x03, 0x87, 0xba, 0x36, 0xb2, 0xfd, 0xde, 0x8d, 0xd7, 0x84, 0xad,
|
||||
0x24, 0x6b, 0x2e, 0xc1, 0x6f, 0x20, 0x29, 0xaa, 0x07, 0x54, 0x0d, 0x49, 0x4b, 0xc1, 0x35, 0x92,
|
||||
0x11, 0x0c, 0x15, 0xea, 0xaa, 0x34, 0x16, 0x73, 0xd6, 0x62, 0x5c, 0xa6, 0x2a, 0x0d, 0x73, 0x77,
|
||||
0x72, 0x0d, 0x47, 0x1b, 0xd4, 0x3a, 0xfb, 0x8e, 0x4e, 0xa2, 0x1d, 0xb7, 0xf0, 0x27, 0xff, 0x84,
|
||||
0x7f, 0x85, 0xcb, 0x46, 0x93, 0x0a, 0xce, 0x31, 0x37, 0xad, 0xe9, 0x08, 0x06, 0xe6, 0x51, 0xa2,
|
||||
0x83, 0x5f, 0x3a, 0xc7, 0x26, 0xb3, 0x16, 0x7c, 0xf9, 0x28, 0x91, 0xd9, 0x44, 0x8d, 0xcf, 0x8a,
|
||||
0x42, 0xa1, 0xd6, 0x2d, 0xde, 0x8d, 0xc1, 0x37, 0xb8, 0x6a, 0xc4, 0xba, 0x6e, 0xe7, 0xf6, 0xf6,
|
||||
0x2f, 0xb7, 0x8b, 0xbd, 0xfa, 0xff, 0xd5, 0x1b, 0x27, 0x00, 0xdb, 0xbf, 0x83, 0xf8, 0xf0, 0x34,
|
||||
0xbc, 0x5d, 0x7e, 0x5c, 0xa5, 0xb7, 0x94, 0xc6, 0x69, 0xea, 0x1f, 0x90, 0x2b, 0x38, 0xb7, 0x9b,
|
||||
0x59, 0xf2, 0x25, 0x9c, 0xcf, 0xa2, 0x15, 0x65, 0x71, 0xe4, 0xf7, 0xc8, 0x73, 0xb8, 0x70, 0xeb,
|
||||
0x65, 0xcc, 0x92, 0x70, 0xbe, 0x8a, 0x19, 0x5b, 0x30, 0xbf, 0x3f, 0x1e, 0xc1, 0xd9, 0xbe, 0x21,
|
||||
0x01, 0x18, 0xa6, 0x46, 0x61, 0xb6, 0xf1, 0x0f, 0xea, 0xdf, 0x9f, 0xb3, 0xfc, 0x07, 0x1a, 0xbf,
|
||||
0x37, 0x8e, 0xe0, 0x74, 0xef, 0xb1, 0x35, 0x9c, 0x2e, 0x92, 0x64, 0x07, 0xfe, 0x0c, 0x3c, 0xbb,
|
||||
0xf9, 0x10, 0xce, 0xe6, 0x16, 0xdb, 0x46, 0xa6, 0xf3, 0x05, 0xfd, 0x14, 0x47, 0x7e, 0xff, 0x6e,
|
||||
0x68, 0x3f, 0xfd, 0xf7, 0x7f, 0x02, 0x00, 0x00, 0xff, 0xff, 0xd2, 0x2f, 0x8d, 0x6f, 0x0b, 0x03,
|
||||
0x00, 0x00,
|
||||
}
|
@@ -1,50 +0,0 @@
|
||||
syntax = "proto3";
|
||||
package core;
|
||||
|
||||
message Speed {
|
||||
uint64 send_bps = 1;
|
||||
uint64 receive_bps = 2;
|
||||
}
|
||||
|
||||
message Credential {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
enum AuthResult {
|
||||
AUTH_SUCCESS = 0;
|
||||
AUTH_INVALID_CRED = 1;
|
||||
AUTH_INTERNAL_ERROR = 2;
|
||||
}
|
||||
|
||||
message ClientAuthRequest {
|
||||
Credential credential = 1;
|
||||
Speed speed = 2;
|
||||
}
|
||||
|
||||
message ServerAuthResponse {
|
||||
AuthResult result = 1;
|
||||
string message = 2;
|
||||
Speed speed = 3;
|
||||
}
|
||||
|
||||
enum ConnectionType {
|
||||
Stream = 0;
|
||||
Packet = 1;
|
||||
}
|
||||
|
||||
enum ConnectResult {
|
||||
CONN_SUCCESS = 0;
|
||||
CONN_FAILED = 1;
|
||||
CONN_BLOCKED = 2;
|
||||
}
|
||||
|
||||
message ClientConnectRequest {
|
||||
ConnectionType type = 1;
|
||||
string address = 2;
|
||||
}
|
||||
|
||||
message ServerConnectResponse {
|
||||
ConnectResult result = 1;
|
||||
string message = 2;
|
||||
}
|
@@ -1,3 +0,0 @@
|
||||
package pb
|
||||
|
||||
//go:generate protoc --go_out=. control.proto
|
45
pkg/core/protocol.go
Normal file
45
pkg/core/protocol.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersion = uint8(1)
|
||||
protocolTimeout = 10 * time.Second
|
||||
|
||||
closeErrorCodeGeneric = 0
|
||||
closeErrorCodeProtocol = 1
|
||||
closeErrorCodeAuth = 2
|
||||
)
|
||||
|
||||
type transmissionRate struct {
|
||||
SendBPS uint64
|
||||
RecvBPS uint64
|
||||
}
|
||||
|
||||
type clientHello struct {
|
||||
Rate transmissionRate
|
||||
AuthLen uint16 `struc:"sizeof=Auth"`
|
||||
Auth []byte
|
||||
}
|
||||
|
||||
type serverHello struct {
|
||||
OK bool
|
||||
Rate transmissionRate
|
||||
MessageLen uint16 `struc:"sizeof=Message"`
|
||||
Message string
|
||||
}
|
||||
|
||||
type clientRequest struct {
|
||||
UDP bool
|
||||
AddressLen uint16 `struc:"sizeof=Address"`
|
||||
Address string
|
||||
}
|
||||
|
||||
type serverResponse struct {
|
||||
OK bool
|
||||
UDPSessionID uint32
|
||||
MessageLen uint16 `struc:"sizeof=Message"`
|
||||
Message string
|
||||
}
|
@@ -4,61 +4,32 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/tobyxdd/hysteria/pkg/core/pb"
|
||||
"github.com/lunixbochs/struc"
|
||||
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||
"github.com/tobyxdd/hysteria/pkg/utils"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type AuthResult int32
|
||||
type ConnectionType int32
|
||||
type ConnectResult int32
|
||||
const dialTimeout = 10 * time.Second
|
||||
|
||||
const (
|
||||
AuthResultSuccess AuthResult = iota
|
||||
AuthResultInvalidCred
|
||||
AuthResultInternalError
|
||||
)
|
||||
|
||||
const (
|
||||
ConnectionTypeStream ConnectionType = iota
|
||||
ConnectionTypePacket
|
||||
)
|
||||
|
||||
const (
|
||||
ConnectResultSuccess ConnectResult = iota
|
||||
ConnectResultFailed
|
||||
ConnectResultBlocked
|
||||
)
|
||||
|
||||
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
|
||||
type ClientDisconnectedFunc func(addr net.Addr, username string, err error)
|
||||
type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
|
||||
type RequestClosedFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string, err error)
|
||||
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
|
||||
type RequestFunc func(addr net.Addr, auth []byte, udp bool, reqAddr string)
|
||||
|
||||
type Server struct {
|
||||
inboundBytes, outboundBytes uint64 // atomic
|
||||
sendBPS, recvBPS uint64
|
||||
congestionFactory CongestionFactory
|
||||
authFunc AuthFunc
|
||||
requestFunc RequestFunc
|
||||
aclEngine *acl.Engine
|
||||
|
||||
listener quic.Listener
|
||||
sendBPS, recvBPS uint64
|
||||
|
||||
congestionFactory CongestionFactory
|
||||
clientAuthFunc ClientAuthFunc
|
||||
clientDisconnectedFunc ClientDisconnectedFunc
|
||||
handleRequestFunc HandleRequestFunc
|
||||
requestClosedFunc RequestClosedFunc
|
||||
listener quic.Listener
|
||||
}
|
||||
|
||||
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
|
||||
obfuscator Obfuscator,
|
||||
clientAuthFunc ClientAuthFunc,
|
||||
clientDisconnectedFunc ClientDisconnectedFunc,
|
||||
handleRequestFunc HandleRequestFunc,
|
||||
requestClosedFunc RequestClosedFunc) (*Server, error) {
|
||||
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine,
|
||||
obfuscator Obfuscator, authFunc AuthFunc, requestFunc RequestFunc) (*Server, error) {
|
||||
packetConn, err := net.ListenPacket("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -75,14 +46,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
return nil, err
|
||||
}
|
||||
s := &Server{
|
||||
listener: listener,
|
||||
sendBPS: sendBPS,
|
||||
recvBPS: recvBPS,
|
||||
congestionFactory: congestionFactory,
|
||||
clientAuthFunc: clientAuthFunc,
|
||||
clientDisconnectedFunc: clientDisconnectedFunc,
|
||||
handleRequestFunc: handleRequestFunc,
|
||||
requestClosedFunc: requestClosedFunc,
|
||||
listener: listener,
|
||||
sendBPS: sendBPS,
|
||||
recvBPS: recvBPS,
|
||||
congestionFactory: congestionFactory,
|
||||
authFunc: authFunc,
|
||||
requestFunc: requestFunc,
|
||||
aclEngine: aclEngine,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@@ -97,128 +67,156 @@ func (s *Server) Serve() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Stats() (uint64, uint64) {
|
||||
return atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes)
|
||||
}
|
||||
|
||||
func (s *Server) Close() error {
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *Server) handleClient(cs quic.Session) {
|
||||
// Expect the client to create a control stream to send its own information
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
|
||||
ctlStream, err := cs.AcceptStream(ctx)
|
||||
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
|
||||
stream, err := cs.AcceptStream(ctx)
|
||||
ctxCancel()
|
||||
if err != nil {
|
||||
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
|
||||
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
|
||||
return
|
||||
}
|
||||
// Handle the control stream
|
||||
username, ok, err := s.handleControlStream(cs, ctlStream)
|
||||
auth, ok, err := s.handleControlStream(cs, stream)
|
||||
if err != nil {
|
||||
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
|
||||
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
_ = cs.CloseWithError(closeErrorCodeGeneric, "authentication failure")
|
||||
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
|
||||
return
|
||||
}
|
||||
// Start accepting streams
|
||||
var closeErr error
|
||||
for {
|
||||
stream, err := cs.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
closeErr = err
|
||||
break
|
||||
}
|
||||
go s.handleStream(cs.LocalAddr(), cs.RemoteAddr(), username, stream)
|
||||
go s.handleStream(cs.RemoteAddr(), auth, stream)
|
||||
}
|
||||
s.clientDisconnectedFunc(cs.RemoteAddr(), username, closeErr)
|
||||
_ = cs.CloseWithError(closeErrorCodeGeneric, "generic")
|
||||
_ = cs.CloseWithError(closeErrorCodeGeneric, "")
|
||||
}
|
||||
|
||||
// Auth & negotiate speed
|
||||
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (string, bool, error) {
|
||||
req, err := readClientAuthRequest(stream)
|
||||
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) {
|
||||
var ch clientHello
|
||||
err := struc.Unpack(stream, &ch)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
return nil, false, err
|
||||
}
|
||||
// Speed
|
||||
if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 {
|
||||
return "", false, errors.New("incorrect speed provided by the client")
|
||||
if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
|
||||
return nil, false, errors.New("invalid rate from client")
|
||||
}
|
||||
serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps
|
||||
serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
|
||||
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
|
||||
serverSendBPS = s.sendBPS
|
||||
}
|
||||
if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS {
|
||||
serverReceiveBPS = s.recvBPS
|
||||
if s.recvBPS > 0 && serverRecvBPS > s.recvBPS {
|
||||
serverRecvBPS = s.recvBPS
|
||||
}
|
||||
// Auth
|
||||
if req.Credential == nil {
|
||||
return "", false, errors.New("incorrect credential provided by the client")
|
||||
}
|
||||
authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password,
|
||||
serverSendBPS, serverReceiveBPS)
|
||||
ok, msg := s.authFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
|
||||
// Response
|
||||
err = writeServerAuthResponse(stream, &pb.ServerAuthResponse{
|
||||
Result: pb.AuthResult(authResult),
|
||||
Message: msg,
|
||||
Speed: &pb.Speed{
|
||||
SendBps: serverSendBPS,
|
||||
ReceiveBps: serverReceiveBPS,
|
||||
err = struc.Pack(stream, &serverHello{
|
||||
OK: ok,
|
||||
Rate: transmissionRate{
|
||||
SendBPS: serverSendBPS,
|
||||
RecvBPS: serverRecvBPS,
|
||||
},
|
||||
Message: msg,
|
||||
})
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
return nil, false, err
|
||||
}
|
||||
// Set the congestion accordingly
|
||||
if authResult == AuthResultSuccess && s.congestionFactory != nil {
|
||||
if ok && s.congestionFactory != nil {
|
||||
cs.SetCongestionControl(s.congestionFactory(serverSendBPS))
|
||||
}
|
||||
return req.Credential.Username, authResult == AuthResultSuccess, nil
|
||||
return ch.Auth, ok, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username string, stream quic.Stream) {
|
||||
func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) {
|
||||
defer stream.Close()
|
||||
// Read request
|
||||
req, err := readClientConnectRequest(stream)
|
||||
var req clientRequest
|
||||
err := struc.Unpack(stream, &req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Create connection with the handler
|
||||
result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address)
|
||||
defer func() {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
s.requestFunc(remoteAddr, auth, req.UDP, req.Address)
|
||||
if !req.UDP {
|
||||
// TCP connection
|
||||
s.handleTCP(stream, req.Address)
|
||||
} else {
|
||||
// UDP connection
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleTCP(stream quic.Stream, reqAddr string) {
|
||||
host, port, err := net.SplitHostPort(reqAddr)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "invalid address",
|
||||
})
|
||||
return
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
// IP request, clear host for ACL engine
|
||||
host = ""
|
||||
}
|
||||
action, arg := acl.ActionDirect, ""
|
||||
if s.aclEngine != nil {
|
||||
action, arg = s.aclEngine.Lookup(host, ip)
|
||||
}
|
||||
|
||||
var conn net.Conn // Connection to be piped
|
||||
switch action {
|
||||
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
}()
|
||||
// Send response
|
||||
err = writeServerConnectResponse(stream, &pb.ServerConnectResponse{
|
||||
Result: pb.ConnectResult(result),
|
||||
Message: msg,
|
||||
case acl.ActionBlock:
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "blocked by ACL",
|
||||
})
|
||||
return
|
||||
case acl.ActionHijack:
|
||||
hijackAddr := net.JoinHostPort(arg, port)
|
||||
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
default:
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "ACL error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// So far so good if we reach here
|
||||
err = struc.Pack(stream, &serverResponse{
|
||||
OK: true,
|
||||
})
|
||||
if err != nil {
|
||||
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
|
||||
return
|
||||
}
|
||||
if result != ConnectResultSuccess {
|
||||
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address,
|
||||
fmt.Errorf("handler returned an unsuccessful state %d (msg: %s)", result, msg))
|
||||
return
|
||||
}
|
||||
switch req.Type {
|
||||
case pb.ConnectionType_Stream:
|
||||
err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes)
|
||||
case pb.ConnectionType_Packet:
|
||||
err = utils.PipePair(&utils.PacketWrapperConn{Orig: &utils.QUICStreamWrapperConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: localAddr,
|
||||
PseudoRemoteAddr: remoteAddr,
|
||||
}}, conn, &s.outboundBytes, &s.inboundBytes)
|
||||
default:
|
||||
err = fmt.Errorf("unsupported connection type %s", req.Type.String())
|
||||
}
|
||||
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
|
||||
_ = utils.Pipe2Way(stream, conn)
|
||||
}
|
||||
|
@@ -1,13 +0,0 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"github.com/lucas-clemente/quic-go/congestion"
|
||||
"time"
|
||||
)
|
||||
|
||||
const controlStreamTimeout = 10 * time.Second
|
||||
|
||||
var controlProtocolEndian = binary.BigEndian
|
||||
|
||||
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
|
@@ -42,9 +42,9 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
|
||||
case acl.ActionDirect:
|
||||
return net.Dial(network, addr)
|
||||
case acl.ActionProxy:
|
||||
return hyClient.Dial(false, addr)
|
||||
return hyClient.DialTCP(addr)
|
||||
case acl.ActionBlock:
|
||||
return nil, errors.New("blocked in ACL")
|
||||
return nil, errors.New("blocked by ACL")
|
||||
case acl.ActionHijack:
|
||||
return net.Dial(network, net.JoinHostPort(arg, port))
|
||||
default:
|
||||
@@ -53,7 +53,6 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
|
||||
},
|
||||
IdleConnTimeout: idleTimeout,
|
||||
// TODO: Disable HTTP2 support? ref: https://github.com/elazarl/goproxy/issues/361
|
||||
//TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
|
||||
}
|
||||
proxy.ConnectDial = nil
|
||||
if basicAuthFunc != nil {
|
||||
|
@@ -156,12 +156,8 @@ func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error {
|
||||
return s.handleTCP(c, r)
|
||||
} else if r.Cmd == socks5.CmdUDP {
|
||||
// UDP
|
||||
if !s.DisableUDP {
|
||||
return s.handleUDP(c, r)
|
||||
} else {
|
||||
_ = sendReply(c, socks5.RepCommandNotSupported)
|
||||
return ErrUnsupportedCmd
|
||||
}
|
||||
_ = sendReply(c, socks5.RepCommandNotSupported)
|
||||
return ErrUnsupportedCmd
|
||||
} else {
|
||||
_ = sendReply(c, socks5.RepCommandNotSupported)
|
||||
return ErrUnsupportedCmd
|
||||
@@ -193,7 +189,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
|
||||
closeErr = pipePair(c, rc, s.TCPDeadline)
|
||||
return nil
|
||||
case acl.ActionProxy:
|
||||
rc, err := s.HyClient.Dial(false, addr)
|
||||
rc, err := s.HyClient.DialTCP(addr)
|
||||
if err != nil {
|
||||
_ = sendReply(c, socks5.RepHostUnreachable)
|
||||
closeErr = err
|
||||
@@ -225,132 +221,6 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error {
|
||||
s.NewUDPAssociateFunc(c.RemoteAddr())
|
||||
var closeErr error
|
||||
defer func() {
|
||||
s.UDPAssociateClosedFunc(c.RemoteAddr(), closeErr)
|
||||
}()
|
||||
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{
|
||||
IP: s.TCPAddr.IP,
|
||||
Zone: s.TCPAddr.Zone,
|
||||
})
|
||||
if err != nil {
|
||||
_ = sendReply(c, socks5.RepServerFailure)
|
||||
closeErr = err
|
||||
return err
|
||||
}
|
||||
defer udpConn.Close()
|
||||
// Send UDP server addr to the client
|
||||
atyp, addr, port, err := socks5.ParseAddress(udpConn.LocalAddr().String())
|
||||
if err != nil {
|
||||
_ = sendReply(c, socks5.RepServerFailure)
|
||||
closeErr = err
|
||||
return err
|
||||
}
|
||||
_, _ = socks5.NewReply(socks5.RepSuccess, atyp, addr, port).WriteTo(c)
|
||||
// Let UDP server do its job, we hold the TCP connection here
|
||||
go s.udpServer(udpConn)
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
if s.TCPDeadline != 0 {
|
||||
_ = c.SetDeadline(time.Now().Add(time.Duration(s.TCPDeadline) * time.Second))
|
||||
}
|
||||
_, err := c.Read(buf)
|
||||
if err != nil {
|
||||
closeErr = err
|
||||
break
|
||||
}
|
||||
}
|
||||
// As the TCP connection closes, so does the UDP listener
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) udpServer(c *net.UDPConn) {
|
||||
var clientAddr *net.UDPAddr
|
||||
remoteMap := make(map[string]io.ReadWriteCloser) // Remote addr <-> Remote conn
|
||||
buf := make([]byte, utils.PipeBufferSize)
|
||||
var closeErr error
|
||||
|
||||
for {
|
||||
n, caddr, err := c.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
closeErr = err
|
||||
break
|
||||
}
|
||||
d, err := socks5.NewDatagramFromBytes(buf[:n])
|
||||
if err != nil || d.Frag != 0 {
|
||||
// Ignore bad packets
|
||||
continue
|
||||
}
|
||||
if clientAddr == nil {
|
||||
// Whoever sends the first valid packet is our client :P
|
||||
clientAddr = caddr
|
||||
} else if caddr.String() != clientAddr.String() {
|
||||
// We already have a client and you're not it!
|
||||
continue
|
||||
}
|
||||
domain, ip, port, addr := parseDatagramRequestAddress(d)
|
||||
rc := remoteMap[addr]
|
||||
if rc == nil {
|
||||
// Need a new entry
|
||||
action, arg := acl.ActionProxy, ""
|
||||
if s.ACLEngine != nil {
|
||||
action, arg = s.ACLEngine.Lookup(domain, ip)
|
||||
}
|
||||
s.NewUDPTunnelFunc(clientAddr, addr, action, arg)
|
||||
// Handle according to the action
|
||||
switch action {
|
||||
case acl.ActionDirect:
|
||||
rc, err = net.Dial("udp", addr)
|
||||
if err != nil {
|
||||
s.UDPTunnelClosedFunc(clientAddr, addr, err)
|
||||
continue
|
||||
}
|
||||
// The other direction
|
||||
go udpReversePipe(clientAddr, c, rc)
|
||||
remoteMap[addr] = rc
|
||||
case acl.ActionProxy:
|
||||
rc, err = s.HyClient.Dial(true, addr)
|
||||
if err != nil {
|
||||
s.UDPTunnelClosedFunc(clientAddr, addr, err)
|
||||
continue
|
||||
}
|
||||
// The other direction
|
||||
go udpReversePipe(clientAddr, c, rc)
|
||||
remoteMap[addr] = rc
|
||||
case acl.ActionBlock:
|
||||
s.UDPTunnelClosedFunc(clientAddr, addr, errors.New("blocked in ACL"))
|
||||
continue
|
||||
case acl.ActionHijack:
|
||||
rc, err = net.Dial("udp", net.JoinHostPort(arg, port))
|
||||
if err != nil {
|
||||
s.UDPTunnelClosedFunc(clientAddr, addr, err)
|
||||
continue
|
||||
}
|
||||
// The other direction
|
||||
go udpReversePipe(clientAddr, c, rc)
|
||||
remoteMap[addr] = rc
|
||||
default:
|
||||
s.UDPTunnelClosedFunc(clientAddr, addr, fmt.Errorf("unknown action %d", action))
|
||||
continue
|
||||
}
|
||||
}
|
||||
_, err = rc.Write(d.Data)
|
||||
if err != nil {
|
||||
// The connection is no longer valid, close & remove from map
|
||||
_ = rc.Close()
|
||||
delete(remoteMap, addr)
|
||||
s.UDPTunnelClosedFunc(clientAddr, addr, err)
|
||||
}
|
||||
}
|
||||
// Close all remote connections
|
||||
for raddr, rc := range remoteMap {
|
||||
_ = rc.Close()
|
||||
s.UDPTunnelClosedFunc(clientAddr, raddr, closeErr)
|
||||
}
|
||||
}
|
||||
|
||||
func sendReply(conn *net.TCPConn, rep byte) error {
|
||||
p := socks5.NewReply(rep, socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00})
|
||||
_, err := p.WriteTo(conn)
|
||||
@@ -367,24 +237,15 @@ func parseRequestAddress(r *socks5.Request) (domain string, ip net.IP, port stri
|
||||
}
|
||||
}
|
||||
|
||||
func parseDatagramRequestAddress(r *socks5.Datagram) (domain string, ip net.IP, port string, addr string) {
|
||||
p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort)))
|
||||
if r.Atyp == socks5.ATYPDomain {
|
||||
d := string(r.DstAddr[1:])
|
||||
return d, nil, p, net.JoinHostPort(d, p)
|
||||
} else {
|
||||
return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p)
|
||||
}
|
||||
}
|
||||
|
||||
func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, deadline int) error {
|
||||
deadlineDuration := time.Duration(deadline) * time.Second
|
||||
errChan := make(chan error, 2)
|
||||
// TCP to stream
|
||||
go func() {
|
||||
buf := make([]byte, utils.PipeBufferSize)
|
||||
for {
|
||||
if deadline != 0 {
|
||||
_ = conn.SetDeadline(time.Now().Add(time.Duration(deadline) * time.Second))
|
||||
_ = conn.SetDeadline(time.Now().Add(deadlineDuration))
|
||||
}
|
||||
rn, err := conn.Read(buf)
|
||||
if rn > 0 {
|
||||
@@ -402,22 +263,7 @@ func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, deadline int) error
|
||||
}()
|
||||
// Stream to TCP
|
||||
go func() {
|
||||
errChan <- utils.Pipe(stream, conn, nil)
|
||||
errChan <- utils.Pipe(stream, conn)
|
||||
}()
|
||||
return <-errChan
|
||||
}
|
||||
|
||||
func udpReversePipe(clientAddr *net.UDPAddr, c *net.UDPConn, rc io.ReadWriteCloser) {
|
||||
buf := make([]byte, utils.PipeBufferSize)
|
||||
for {
|
||||
n, err := rc.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
d := socks5.NewDatagram(socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}, buf[:n])
|
||||
_, err = c.WriteTo(d.Bytes(), clientAddr)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,96 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PacketWrapperConn struct {
|
||||
Orig net.Conn
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) Read(b []byte) (n int, err error) {
|
||||
var sz uint32
|
||||
if err := binary.Read(w.Orig, binary.BigEndian, &sz); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if int(sz) <= len(b) {
|
||||
return io.ReadFull(w.Orig, b[:sz])
|
||||
} else {
|
||||
return 0, fmt.Errorf("the buffer is too small to hold %d bytes of packet data", sz)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) Write(b []byte) (n int, err error) {
|
||||
sz := uint32(len(b))
|
||||
if err := binary.Write(w.Orig, binary.BigEndian, &sz); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return w.Orig.Write(b)
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) Close() error {
|
||||
return w.Orig.Close()
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) LocalAddr() net.Addr {
|
||||
return w.Orig.LocalAddr()
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) RemoteAddr() net.Addr {
|
||||
return w.Orig.RemoteAddr()
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) SetDeadline(t time.Time) error {
|
||||
return w.Orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) SetReadDeadline(t time.Time) error {
|
||||
return w.Orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (w *PacketWrapperConn) SetWriteDeadline(t time.Time) error {
|
||||
return w.Orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type QUICStreamWrapperConn struct {
|
||||
Orig quic.Stream
|
||||
PseudoLocalAddr net.Addr
|
||||
PseudoRemoteAddr net.Addr
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) Read(b []byte) (n int, err error) {
|
||||
return w.Orig.Read(b)
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) Write(b []byte) (n int, err error) {
|
||||
return w.Orig.Write(b)
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) Close() error {
|
||||
return w.Orig.Close()
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) LocalAddr() net.Addr {
|
||||
return w.PseudoLocalAddr
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) RemoteAddr() net.Addr {
|
||||
return w.PseudoRemoteAddr
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) SetDeadline(t time.Time) error {
|
||||
return w.Orig.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) SetReadDeadline(t time.Time) error {
|
||||
return w.Orig.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (w *QUICStreamWrapperConn) SetWriteDeadline(t time.Time) error {
|
||||
return w.Orig.SetWriteDeadline(t)
|
||||
}
|
@@ -2,20 +2,16 @@ package utils
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const PipeBufferSize = 65536
|
||||
|
||||
func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error {
|
||||
func Pipe(src, dst io.ReadWriter) error {
|
||||
buf := make([]byte, PipeBufferSize)
|
||||
for {
|
||||
rn, err := src.Read(buf)
|
||||
if rn > 0 {
|
||||
wn, err := dst.Write(buf[:rn])
|
||||
if atomicCounter != nil {
|
||||
atomic.AddUint64(atomicCounter, uint64(wn))
|
||||
}
|
||||
_, err := dst.Write(buf[:rn])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -26,13 +22,13 @@ func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error {
|
||||
}
|
||||
}
|
||||
|
||||
func PipePair(rw1, rw2 io.ReadWriter, rw1WriteCounter, rw2WriteCounter *uint64) error {
|
||||
func Pipe2Way(rw1, rw2 io.ReadWriter) error {
|
||||
errChan := make(chan error, 2)
|
||||
go func() {
|
||||
errChan <- Pipe(rw2, rw1, rw1WriteCounter)
|
||||
errChan <- Pipe(rw2, rw1)
|
||||
}()
|
||||
go func() {
|
||||
errChan <- Pipe(rw1, rw2, rw2WriteCounter)
|
||||
errChan <- Pipe(rw1, rw2)
|
||||
}()
|
||||
// We only need the first error
|
||||
return <-errChan
|
||||
|
Reference in New Issue
Block a user