Most things work fine now, except:

- UDP support has been temporarily removed, pending upstream QUIC library support for unreliable messages
- SOCKS5 server needs some rework
- Authentication
This commit is contained in:
Toby
2021-01-28 23:57:53 -08:00
parent d9d07a5b2a
commit 7d280393a3
24 changed files with 618 additions and 1985 deletions

View File

@@ -3,11 +3,6 @@ package main
import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net"
"net/http"
"time"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus"
@@ -17,42 +12,38 @@ import (
hyHTTP "github.com/tobyxdd/hysteria/pkg/http"
"github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/socks5"
"io/ioutil"
"net"
"net/http"
"time"
)
func proxyClient(args []string) {
var config proxyClientConfig
err := loadConfig(&config, args)
if err != nil {
logrus.WithField("error", err).Fatal("Unable to load configuration")
}
if err := config.Check(); err != nil {
logrus.WithField("error", err).Fatal("Configuration error")
}
logrus.WithField("config", config.String()).Info("Configuration loaded")
func client(config *clientConfig) {
logrus.WithField("config", config.String()).Info("Client configuration loaded")
// TLS
tlsConfig := &tls.Config{
InsecureSkipVerify: config.Insecure,
NextProtos: []string{proxyTLSProtocol},
NextProtos: []string{tlsProtocolName},
MinVersion: tls.VersionTLS13,
}
// Load CA
if len(config.CustomCAFile) > 0 {
bs, err := ioutil.ReadFile(config.CustomCAFile)
if len(config.CustomCA) > 0 {
bs, err := ioutil.ReadFile(config.CustomCA)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"file": config.CustomCAFile,
}).Fatal("Unable to load CA file")
"file": config.CustomCA,
}).Fatal("Failed to load CA")
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(bs) {
logrus.WithFields(logrus.Fields{
"file": config.CustomCAFile,
}).Fatal("Unable to parse CA file")
"file": config.CustomCA,
}).Fatal("Failed to parse CA")
}
tlsConfig.RootCAs = cp
}
// QUIC config
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow,
@@ -64,46 +55,54 @@ func proxyClient(args []string) {
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
// Auth
var auth []byte
if len(config.Auth) > 0 {
auth = config.Auth
} else {
auth = []byte(config.AuthString)
}
// Obfuscator
var obfuscator core.Obfuscator
if len(config.Obfs) > 0 {
obfuscator = obfs.XORObfuscator(config.Obfs)
}
// ACL
var aclEngine *acl.Engine
if len(config.ACLFile) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACLFile)
if len(config.ACL) > 0 {
var err error
aclEngine, err = acl.LoadFromFile(config.ACL)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"file": config.ACLFile,
}).Fatal("Unable to parse ACL")
"file": config.ACL,
}).Fatal("Failed to parse ACL")
}
}
client, err := core.NewClient(config.ServerAddr, config.Username, config.Password, tlsConfig, quicConfig,
// Client
client, err := core.NewClient(config.Server, auth, tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
}, obfuscator)
if err != nil {
logrus.WithField("error", err).Fatal("Client initialization failed")
logrus.WithField("error", err).Fatal("Failed to initialize client")
}
defer client.Close()
logrus.WithField("addr", config.ServerAddr).Info("Connected")
logrus.WithField("addr", config.Server).Info("Connected")
// Local
errChan := make(chan error)
if len(config.SOCKS5Addr) > 0 {
if len(config.SOCKS5.Listen) > 0 {
go func() {
var authFunc func(user, password string) bool
if config.SOCKS5User != "" && config.SOCKS5Password != "" {
if config.SOCKS5.User != "" && config.SOCKS5.Password != "" {
authFunc = func(user, password string) bool {
return config.SOCKS5User == user && config.SOCKS5Password == password
return config.SOCKS5.User == user && config.SOCKS5.Password == password
}
}
socks5server, err := socks5.NewServer(client, config.SOCKS5Addr, authFunc, config.SOCKS5Timeout, aclEngine,
config.SOCKS5DisableUDP,
socks5server, err := socks5.NewServer(client, config.SOCKS5.Listen, authFunc, config.SOCKS5.Timeout, aclEngine,
config.SOCKS5.DisableUDP,
func(addr net.Addr, reqAddr string, action acl.Action, arg string) {
logrus.WithFields(logrus.Fields{
"action": actionToString(action, arg),
@@ -144,22 +143,22 @@ func proxyClient(args []string) {
}).Debug("SOCKS5 UDP tunnel closed")
})
if err != nil {
logrus.WithField("error", err).Fatal("SOCKS5 server initialization failed")
logrus.WithField("error", err).Fatal("Failed to initialize SOCKS5 server")
}
logrus.WithField("addr", config.SOCKS5Addr).Info("SOCKS5 server up and running")
logrus.WithField("addr", config.SOCKS5.Listen).Info("SOCKS5 server up and running")
errChan <- socks5server.ListenAndServe()
}()
}
if len(config.HTTPAddr) > 0 {
if len(config.HTTP.Listen) > 0 {
go func() {
var authFunc func(user, password string) bool
if config.HTTPUser != "" && config.HTTPPassword != "" {
if config.HTTP.User != "" && config.HTTP.Password != "" {
authFunc = func(user, password string) bool {
return config.HTTPUser == user && config.HTTPPassword == password
return config.HTTP.User == user && config.HTTP.Password == password
}
}
proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTPTimeout)*time.Second, aclEngine,
proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTP.Timeout)*time.Second, aclEngine,
func(reqAddr string, action acl.Action, arg string) {
logrus.WithFields(logrus.Fields{
"action": actionToString(action, arg),
@@ -168,14 +167,14 @@ func proxyClient(args []string) {
},
authFunc)
if err != nil {
logrus.WithField("error", err).Fatal("HTTP server initialization failed")
logrus.WithField("error", err).Fatal("Failed to initialize HTTP server")
}
if config.HTTPSCert != "" && config.HTTPSKey != "" {
logrus.WithField("addr", config.HTTPAddr).Info("HTTPS server up and running")
errChan <- http.ListenAndServeTLS(config.HTTPAddr, config.HTTPSCert, config.HTTPSKey, proxy)
if config.HTTP.Cert != "" && config.HTTP.Key != "" {
logrus.WithField("addr", config.HTTP.Listen).Info("HTTPS server up and running")
errChan <- http.ListenAndServeTLS(config.HTTP.Listen, config.HTTP.Cert, config.HTTP.Key, proxy)
} else {
logrus.WithField("addr", config.HTTPAddr).Info("HTTP server up and running")
errChan <- http.ListenAndServe(config.HTTPAddr, proxy)
logrus.WithField("addr", config.HTTP.Listen).Info("HTTP server up and running")
errChan <- http.ListenAndServe(config.HTTP.Listen, proxy)
}
}()
}

View File

@@ -1,86 +1,127 @@
package main
import (
"encoding/json"
"flag"
"io/ioutil"
"os"
"reflect"
"strings"
"time"
"errors"
"fmt"
)
const (
mbpsToBps = 125000
dialTimeout = 10 * time.Second
mbpsToBps = 125000
DefaultMaxReceiveStreamFlowControlWindow = 33554432
DefaultMaxReceiveConnectionFlowControlWindow = 67108864
DefaultMaxIncomingStreams = 200
DefaultMaxIncomingStreams = 1024
tlsProtocolName = "hysteria"
)
func loadConfig(cfg interface{}, args []string) error {
cfgVal := reflect.ValueOf(cfg).Elem()
fs := flag.NewFlagSet("", flag.ContinueOnError)
fsValMap := make(map[reflect.Value]interface{}, cfgVal.NumField())
for i := 0; i < cfgVal.NumField(); i++ {
structField := cfgVal.Type().Field(i)
tag := structField.Tag
switch structField.Type.Kind() {
case reflect.String:
fsValMap[cfgVal.Field(i)] =
fs.String(jsonTagToFlagName(tag.Get("json")), "", tag.Get("desc"))
case reflect.Int:
fsValMap[cfgVal.Field(i)] =
fs.Int(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc"))
case reflect.Uint64:
fsValMap[cfgVal.Field(i)] =
fs.Uint64(jsonTagToFlagName(tag.Get("json")), 0, tag.Get("desc"))
case reflect.Bool:
var bf optionalBoolFlag
fs.Var(&bf, jsonTagToFlagName(tag.Get("json")), tag.Get("desc"))
fsValMap[cfgVal.Field(i)] = &bf
}
type serverConfig struct {
Listen string `json:"listen"`
CertFile string `json:"cert"`
KeyFile string `json:"key"`
// Optional below
UpMbps int `json:"up_mbps"`
DownMbps int `json:"down_mbps"`
DisableUDP bool `json:"disable_udp"`
ACL string `json:"acl"`
Obfs string `json:"obfs"`
Auth struct {
Mode string `json:"mode"`
Config interface{} `json:"config"`
} `json:"auth"`
ReceiveWindowConn uint64 `json:"recv_window_conn"`
ReceiveWindowClient uint64 `json:"recv_window_client"`
MaxConnClient int `json:"max_conn_client"`
}
func (c *serverConfig) Check() error {
if len(c.Listen) == 0 {
return errors.New("no listen address")
}
configFile := fs.String("config", "", "Configuration file")
// Parse
if err := fs.Parse(args); err != nil {
os.Exit(1)
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
return errors.New("TLS cert or key not provided")
}
// Put together the config
if len(*configFile) > 0 {
cb, err := ioutil.ReadFile(*configFile)
if err != nil {
return err
}
if err := json.Unmarshal(cb, cfg); err != nil {
return err
}
if c.UpMbps < 0 || c.DownMbps < 0 {
return errors.New("invalid speed")
}
// Flags override config from file
for field, val := range fsValMap {
switch v := val.(type) {
case *string:
if len(*v) > 0 {
field.SetString(*v)
}
case *int:
if *v != 0 {
field.SetInt(int64(*v))
}
case *uint64:
if *v != 0 {
field.SetUint(*v)
}
case *optionalBoolFlag:
if v.Exists {
field.SetBool(v.Value)
}
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
return errors.New("invalid receive window size")
}
if c.MaxConnClient < 0 {
return errors.New("invalid max connections per client")
}
return nil
}
func jsonTagToFlagName(tag string) string {
return strings.ReplaceAll(tag, "_", "-")
func (c *serverConfig) String() string {
return fmt.Sprintf("%+v", *c)
}
type clientConfig struct {
Server string `json:"server"`
UpMbps int `json:"up_mbps"`
DownMbps int `json:"down_mbps"`
// Optional below
SOCKS5 struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
DisableUDP bool `json:"disable_udp"`
User string `json:"user"`
Password string `json:"password"`
} `json:"socks5"`
HTTP struct {
Listen string `json:"listen"`
Timeout int `json:"timeout"`
User string `json:"user"`
Password string `json:"password"`
Cert string `json:"cert"`
Key string `json:"key"`
} `json:"http"`
Relay struct {
Listen string `json:"listen"`
Remote string `json:"remote"`
Timeout int `json:"timeout"`
} `json:"relay"`
ACL string `json:"acl"`
Obfs string `json:"obfs"`
Auth []byte `json:"auth"`
AuthString string `json:"auth_str"`
Insecure bool `json:"insecure"`
CustomCA string `json:"ca"`
ReceiveWindowConn uint64 `json:"recv_window_conn"`
ReceiveWindow uint64 `json:"recv_window"`
}
func (c *clientConfig) Check() error {
if len(c.SOCKS5.Listen) == 0 && len(c.HTTP.Listen) == 0 && len(c.Relay.Listen) == 0 {
return errors.New("no SOCKS5, HTTP or relay listen address")
}
if len(c.Relay.Listen) > 0 && len(c.Relay.Remote) == 0 {
return errors.New("no relay remote address")
}
if c.SOCKS5.Timeout != 0 && c.SOCKS5.Timeout <= 4 {
return errors.New("invalid SOCKS5 timeout")
}
if c.HTTP.Timeout != 0 && c.HTTP.Timeout <= 4 {
return errors.New("invalid HTTP timeout")
}
if c.Relay.Timeout != 0 && c.Relay.Timeout <= 4 {
return errors.New("invalid relay timeout")
}
if len(c.Server) == 0 {
return errors.New("no server address")
}
if c.UpMbps <= 0 || c.DownMbps <= 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
return errors.New("invalid receive window size")
}
return nil
}
func (c *clientConfig) String() string {
return fmt.Sprintf("%+v", *c)
}

View File

@@ -1,11 +1,13 @@
package main
import (
"flag"
"fmt"
"github.com/sirupsen/logrus"
"github.com/yosuke-furukawa/json5/encoding/json5"
"io/ioutil"
"os"
"strings"
"github.com/sirupsen/logrus"
)
// Injected when compiling
@@ -15,12 +17,10 @@ var (
appDate = "Unknown"
)
var modeMap = map[string]func(args []string){
"relay server": relayServer,
"relay client": relayClient,
"proxy server": proxyServer,
"proxy client": proxyClient,
}
var (
configPath = flag.String("config", "config.json", "Config file")
showVersion = flag.Bool("version", false, "Show version")
)
func init() {
logrus.SetOutput(os.Stdout)
@@ -50,37 +50,68 @@ func init() {
TimestampFormat: tsFormat,
})
}
flag.Parse()
}
func main() {
if len(os.Args) == 2 && strings.ToLower(strings.TrimSpace(os.Args[1])) == "version" {
if *showVersion {
// Print version and quit
fmt.Printf("%-10s%s\n", "Version:", appVersion)
fmt.Printf("%-10s%s\n", "Commit:", appCommit)
fmt.Printf("%-10s%s\n", "Date:", appDate)
return
}
if len(os.Args) < 3 {
fmt.Println()
fmt.Printf("Usage: %s MODE SUBMODE [OPTIONS]\n\n"+
"Available mode/submode combinations: "+getModes()+"\n"+
"Use -h to see the available options for a mode.\n\n", os.Args[0])
return
cb, err := ioutil.ReadFile(*configPath)
if err != nil {
logrus.WithFields(logrus.Fields{
"file": *configPath,
"error": err,
}).Fatal("Failed to read configuration")
}
modeStr := fmt.Sprintf("%s %s", strings.ToLower(strings.TrimSpace(os.Args[1])),
strings.ToLower(strings.TrimSpace(os.Args[2])))
f := modeMap[modeStr]
if f != nil {
f(os.Args[3:])
mode := flag.Arg(0)
if strings.EqualFold(mode, "server") {
// server mode
c, err := parseServerConfig(cb)
if err != nil {
logrus.WithFields(logrus.Fields{
"file": *configPath,
"error": err,
}).Fatal("Failed to parse server configuration")
}
server(c)
} else if len(mode) == 0 || strings.EqualFold(mode, "client") {
// client mode
c, err := parseClientConfig(cb)
if err != nil {
logrus.WithFields(logrus.Fields{
"file": *configPath,
"error": err,
}).Fatal("Failed to parse client configuration")
}
client(c)
} else {
fmt.Println("Invalid mode:", modeStr)
// invalid
fmt.Println()
fmt.Printf("Usage: %s MODE [OPTIONS]\n\n"+
"Available modes: server, client\n\n", os.Args[0])
}
}
func getModes() string {
modes := make([]string, 0, len(modeMap))
for mode := range modeMap {
modes = append(modes, mode)
func parseServerConfig(cb []byte) (*serverConfig, error) {
var c serverConfig
err := json5.Unmarshal(cb, &c)
if err != nil {
return nil, err
}
return strings.Join(modes, ", ")
return &c, c.Check()
}
func parseClientConfig(cb []byte) (*clientConfig, error) {
var c clientConfig
err := json5.Unmarshal(cb, &c)
if err != nil {
return nil, err
}
return &c, c.Check()
}

View File

@@ -1,99 +0,0 @@
package main
import (
"errors"
"fmt"
)
const proxyTLSProtocol = "hysteria-proxy"
type proxyClientConfig struct {
SOCKS5Addr string `json:"socks5_addr" desc:"SOCKS5 listen address"`
SOCKS5Timeout int `json:"socks5_timeout" desc:"SOCKS5 connection timeout in seconds"`
SOCKS5DisableUDP bool `json:"socks5_disable_udp" desc:"Disable SOCKS5 UDP support"`
SOCKS5User string `json:"socks5_user" desc:"SOCKS5 auth username"`
SOCKS5Password string `json:"socks5_password" desc:"SOCKS5 auth password"`
HTTPAddr string `json:"http_addr" desc:"HTTP listen address"`
HTTPTimeout int `json:"http_timeout" desc:"HTTP connection timeout in seconds"`
HTTPUser string `json:"http_user" desc:"HTTP basic auth username"`
HTTPPassword string `json:"http_password" desc:"HTTP basic auth password"`
HTTPSCert string `json:"https_cert" desc:"HTTPS certificate file"`
HTTPSKey string `json:"https_key" desc:"HTTPS key file"`
ACLFile string `json:"acl" desc:"Access control list"`
ServerAddr string `json:"server" desc:"Server address"`
Username string `json:"username" desc:"Authentication username"`
Password string `json:"password" desc:"Authentication password"`
Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"`
CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"`
UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"`
DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"`
Obfs string `json:"obfs" desc:"Obfuscation key"`
}
func (c *proxyClientConfig) Check() error {
if len(c.SOCKS5Addr) == 0 && len(c.HTTPAddr) == 0 {
return errors.New("no SOCKS5 or HTTP listen address")
}
if c.SOCKS5Timeout != 0 && c.SOCKS5Timeout <= 4 {
return errors.New("invalid SOCKS5 timeout")
}
if c.HTTPTimeout != 0 && c.HTTPTimeout <= 4 {
return errors.New("invalid HTTP timeout")
}
if len(c.ServerAddr) == 0 {
return errors.New("no server address")
}
if c.UpMbps <= 0 || c.DownMbps <= 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
return errors.New("invalid receive window size")
}
return nil
}
func (c *proxyClientConfig) String() string {
return fmt.Sprintf("%+v", *c)
}
type proxyServerConfig struct {
ListenAddr string `json:"listen" desc:"Server listen address"`
DisableUDP bool `json:"disable_udp" desc:"Disable UDP support"`
ACLFile string `json:"acl" desc:"Access control list"`
CertFile string `json:"cert" desc:"TLS certificate file"`
KeyFile string `json:"key" desc:"TLS key file"`
AuthFile string `json:"auth" desc:"Authentication file"`
UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"`
DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"`
MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"`
Obfs string `json:"obfs" desc:"Obfuscation key"`
}
func (c *proxyServerConfig) Check() error {
if len(c.ListenAddr) == 0 {
return errors.New("no listen address")
}
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
return errors.New("TLS cert or key not provided")
}
if c.UpMbps < 0 || c.DownMbps < 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
return errors.New("invalid receive window size")
}
if c.MaxConnClient < 0 {
return errors.New("invalid max connections per client")
}
return nil
}
func (c *proxyServerConfig) String() string {
return fmt.Sprintf("%+v", *c)
}

View File

@@ -1,298 +0,0 @@
package main
import (
"bufio"
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus"
"github.com/tobyxdd/hysteria/pkg/acl"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs"
"io"
"net"
"os"
"strings"
)
func proxyServer(args []string) {
var config proxyServerConfig
err := loadConfig(&config, args)
if err != nil {
logrus.WithField("error", err).Fatal("Unable to load configuration")
}
if err := config.Check(); err != nil {
logrus.WithField("error", err).Fatal("Configuration error")
}
logrus.WithField("config", config.String()).Info("Configuration loaded")
// Load cert
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"cert": config.CertFile,
"key": config.KeyFile,
}).Fatal("Unable to load the certificate")
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{proxyTLSProtocol},
MinVersion: tls.VersionTLS13,
}
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
MaxIncomingStreams: int64(config.MaxConnClient),
KeepAlive: true,
}
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
}
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
if quicConfig.MaxIncomingStreams == 0 {
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
}
if len(config.AuthFile) == 0 {
logrus.Warn("No authentication configured, this server can be used by anyone")
}
var obfuscator core.Obfuscator
if len(config.Obfs) > 0 {
obfuscator = obfs.XORObfuscator(config.Obfs)
}
var aclEngine *acl.Engine
if len(config.ACLFile) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACLFile)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"file": config.ACLFile,
}).Fatal("Unable to parse ACL")
}
aclEngine.DefaultAction = acl.ActionDirect
}
server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
},
obfuscator,
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
if len(config.AuthFile) == 0 {
logrus.WithFields(logrus.Fields{
"addr": addr.String(),
"username": username,
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client connected")
return core.AuthResultSuccess, ""
} else {
// Need auth
ok, err := checkAuth(config.AuthFile, username, password)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err.Error(),
"addr": addr.String(),
"username": username,
}).Error("Client authentication error")
return core.AuthResultInternalError, "Server auth error"
}
if ok {
logrus.WithFields(logrus.Fields{
"addr": addr.String(),
"username": username,
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client authenticated")
return core.AuthResultSuccess, ""
} else {
logrus.WithFields(logrus.Fields{
"addr": addr.String(),
"username": username,
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client rejected due to invalid credential")
return core.AuthResultInvalidCred, "Invalid credential"
}
}
},
func(addr net.Addr, username string, err error) {
logrus.WithFields(logrus.Fields{
"error": err.Error(),
"addr": addr.String(),
"username": username,
}).Info("Client disconnected")
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
packet := reqType == core.ConnectionTypePacket
if packet && config.DisableUDP {
return core.ConnectResultBlocked, "UDP disabled", nil
}
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
return core.ConnectResultFailed, err.Error(), nil
}
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg := acl.ActionDirect, ""
if aclEngine != nil {
action, arg = aclEngine.Lookup(host, ip)
}
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
if !packet {
// TCP
logrus.WithFields(logrus.Fields{
"action": "direct",
"username": username,
"src": addr.String(),
"dst": reqAddr,
}).Debug("New TCP request")
conn, err := net.DialTimeout("tcp", reqAddr, dialTimeout)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"dst": reqAddr,
}).Error("TCP error")
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnectResultSuccess, "", conn
} else {
// UDP
logrus.WithFields(logrus.Fields{
"action": "direct",
"username": username,
"src": addr.String(),
"dst": reqAddr,
}).Debug("New UDP request")
conn, err := net.Dial("udp", reqAddr)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"dst": reqAddr,
}).Error("UDP error")
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnectResultSuccess, "", conn
}
case acl.ActionBlock:
if !packet {
// TCP
logrus.WithFields(logrus.Fields{
"action": "block",
"username": username,
"src": addr.String(),
"dst": reqAddr,
}).Debug("New TCP request")
return core.ConnectResultBlocked, "blocked by ACL", nil
} else {
// UDP
logrus.WithFields(logrus.Fields{
"action": "block",
"username": username,
"src": addr.String(),
"dst": reqAddr,
}).Debug("New UDP request")
return core.ConnectResultBlocked, "blocked by ACL", nil
}
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
if !packet {
// TCP
logrus.WithFields(logrus.Fields{
"action": "hijack",
"username": username,
"src": addr.String(),
"dst": reqAddr,
"rdst": arg,
}).Debug("New TCP request")
conn, err := net.DialTimeout("tcp", hijackAddr, dialTimeout)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"dst": hijackAddr,
}).Error("TCP error")
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnectResultSuccess, "", conn
} else {
// UDP
logrus.WithFields(logrus.Fields{
"action": "hijack",
"username": username,
"src": addr.String(),
"dst": reqAddr,
"rdst": arg,
}).Debug("New UDP request")
conn, err := net.Dial("udp", hijackAddr)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"dst": hijackAddr,
}).Error("UDP error")
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnectResultSuccess, "", conn
}
default:
return core.ConnectResultFailed, "server ACL error", nil
}
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
packet := reqType == core.ConnectionTypePacket
if !packet {
logrus.WithFields(logrus.Fields{
"error": err,
"username": username,
"src": addr.String(),
"dst": reqAddr,
}).Debug("TCP request closed")
} else {
logrus.WithFields(logrus.Fields{
"error": err,
"username": username,
"src": addr.String(),
"dst": reqAddr,
}).Debug("UDP request closed")
}
},
)
if err != nil {
logrus.WithField("error", err).Fatal("Server initialization failed")
}
defer server.Close()
logrus.WithField("addr", config.ListenAddr).Info("Server up and running")
err = server.Serve()
logrus.WithField("error", err).Fatal("Server shutdown")
}
func checkAuth(authFile, username, password string) (bool, error) {
f, err := os.Open(authFile)
if err != nil {
return false, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
pair := strings.Fields(scanner.Text())
if len(pair) != 2 {
// Invalid format
continue
}
if username == pair[0] && password == pair[1] {
return true, nil
}
}
return false, nil
}

View File

@@ -1,119 +0,0 @@
package main
import (
"crypto/tls"
"crypto/x509"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/utils"
"io/ioutil"
"net"
"os/user"
)
func relayClient(args []string) {
var config relayClientConfig
err := loadConfig(&config, args)
if err != nil {
logrus.WithField("error", err).Fatal("Unable to load configuration")
}
if err := config.Check(); err != nil {
logrus.WithField("error", err).Fatal("Configuration error")
}
if len(config.Name) == 0 {
usr, err := user.Current()
if err == nil {
config.Name = usr.Name
}
}
logrus.WithField("config", config.String()).Info("Configuration loaded")
tlsConfig := &tls.Config{
InsecureSkipVerify: config.Insecure,
NextProtos: []string{relayTLSProtocol},
MinVersion: tls.VersionTLS13,
}
// Load CA
if len(config.CustomCAFile) > 0 {
bs, err := ioutil.ReadFile(config.CustomCAFile)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"file": config.CustomCAFile,
}).Fatal("Unable to load CA file")
}
cp := x509.NewCertPool()
if !cp.AppendCertsFromPEM(bs) {
logrus.WithFields(logrus.Fields{
"file": config.CustomCAFile,
}).Fatal("Unable to parse CA file")
}
tlsConfig.RootCAs = cp
}
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindow,
KeepAlive: true,
}
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
}
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
var obfuscator core.Obfuscator
if len(config.Obfs) > 0 {
obfuscator = obfs.XORObfuscator(config.Obfs)
}
client, err := core.NewClient(config.ServerAddr, config.Name, "", tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
}, obfuscator)
if err != nil {
logrus.WithField("error", err).Fatal("Client initialization failed")
}
defer client.Close()
logrus.WithField("addr", config.ServerAddr).Info("Connected")
listener, err := net.Listen("tcp", config.ListenAddr)
if err != nil {
logrus.WithField("error", err).Fatal("TCP listen failed")
}
defer listener.Close()
logrus.WithField("addr", listener.Addr().String()).Info("TCP server listening")
for {
conn, err := listener.Accept()
if err != nil {
logrus.WithField("error", err).Fatal("TCP accept failed")
}
go relayClientHandleConn(conn, client)
}
}
func relayClientHandleConn(conn net.Conn, client *core.Client) {
logrus.WithField("src", conn.RemoteAddr().String()).Debug("New connection")
var closeErr error
defer func() {
_ = conn.Close()
logrus.WithFields(logrus.Fields{
"error": closeErr,
"src": conn.RemoteAddr().String(),
}).Debug("Connection closed")
}()
rwc, err := client.Dial(false, "")
if err != nil {
closeErr = err
return
}
defer rwc.Close()
closeErr = utils.PipePair(conn, rwc, nil, nil)
}

View File

@@ -1,82 +0,0 @@
package main
import (
"errors"
"fmt"
)
const relayTLSProtocol = "hysteria-relay"
type relayClientConfig struct {
ListenAddr string `json:"listen" desc:"TCP listen address"`
ServerAddr string `json:"server" desc:"Server address"`
Name string `json:"name" desc:"Client name presented to the server"`
Insecure bool `json:"insecure" desc:"Ignore TLS certificate errors"`
CustomCAFile string `json:"ca" desc:"Specify a trusted CA file"`
UpMbps int `json:"up_mbps" desc:"Upload speed in Mbps"`
DownMbps int `json:"down_mbps" desc:"Download speed in Mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
ReceiveWindow uint64 `json:"recv_window" desc:"Max receive window size"`
Obfs string `json:"obfs" desc:"Obfuscation key"`
}
func (c *relayClientConfig) Check() error {
if len(c.ListenAddr) == 0 {
return errors.New("no listen address")
}
if len(c.ServerAddr) == 0 {
return errors.New("no server address")
}
if c.UpMbps <= 0 || c.DownMbps <= 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindow != 0 && c.ReceiveWindow < 65536) {
return errors.New("invalid receive window size")
}
return nil
}
func (c *relayClientConfig) String() string {
return fmt.Sprintf("%+v", *c)
}
type relayServerConfig struct {
ListenAddr string `json:"listen" desc:"Server listen address"`
RemoteAddr string `json:"remote" desc:"Remote relay address"`
CertFile string `json:"cert" desc:"TLS certificate file"`
KeyFile string `json:"key" desc:"TLS key file"`
UpMbps int `json:"up_mbps" desc:"Max upload speed per client in Mbps"`
DownMbps int `json:"down_mbps" desc:"Max download speed per client in Mbps"`
ReceiveWindowConn uint64 `json:"recv_window_conn" desc:"Max receive window size per connection"`
ReceiveWindowClient uint64 `json:"recv_window_client" desc:"Max receive window size per client"`
MaxConnClient int `json:"max_conn_client" desc:"Max simultaneous connections allowed per client"`
Obfs string `json:"obfs" desc:"Obfuscation key"`
}
func (c *relayServerConfig) Check() error {
if len(c.ListenAddr) == 0 {
return errors.New("no listen address")
}
if len(c.RemoteAddr) == 0 {
return errors.New("no remote address")
}
if len(c.CertFile) == 0 || len(c.KeyFile) == 0 {
return errors.New("TLS cert or key not provided")
}
if c.UpMbps < 0 || c.DownMbps < 0 {
return errors.New("invalid speed")
}
if (c.ReceiveWindowConn != 0 && c.ReceiveWindowConn < 65536) ||
(c.ReceiveWindowClient != 0 && c.ReceiveWindowClient < 65536) {
return errors.New("invalid receive window size")
}
if c.MaxConnClient < 0 {
return errors.New("invalid max connections per client")
}
return nil
}
func (c *relayServerConfig) String() string {
return fmt.Sprintf("%+v", *c)
}

View File

@@ -1,121 +0,0 @@
package main
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs"
"io"
"net"
)
func relayServer(args []string) {
var config relayServerConfig
err := loadConfig(&config, args)
if err != nil {
logrus.WithField("error", err).Fatal("Unable to load configuration")
}
if err := config.Check(); err != nil {
logrus.WithField("error", err).Fatal("Configuration error")
}
logrus.WithField("config", config.String()).Info("Configuration loaded")
// Load cert
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"cert": config.CertFile,
"key": config.KeyFile,
}).Fatal("Unable to load the certificate")
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{relayTLSProtocol},
MinVersion: tls.VersionTLS13,
}
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
MaxIncomingStreams: int64(config.MaxConnClient),
KeepAlive: true,
}
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
}
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
if quicConfig.MaxIncomingStreams == 0 {
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
}
var obfuscator core.Obfuscator
if len(config.Obfs) > 0 {
obfuscator = obfs.XORObfuscator(config.Obfs)
}
server, err := core.NewServer(config.ListenAddr, tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
},
obfuscator,
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
// No authentication logic in relay, just log username and speed
logrus.WithFields(logrus.Fields{
"addr": addr.String(),
"username": username,
"up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps,
}).Info("Client connected")
return core.AuthResultSuccess, ""
},
func(addr net.Addr, username string, err error) {
logrus.WithFields(logrus.Fields{
"error": err.Error(),
"addr": addr.String(),
"username": username,
}).Info("Client disconnected")
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
packet := reqType == core.ConnectionTypePacket
logrus.WithFields(logrus.Fields{
"username": username,
"src": addr.String(),
"id": id,
}).Debug("New stream")
if packet {
return core.ConnectResultBlocked, "unsupported", nil
}
conn, err := net.DialTimeout("tcp", config.RemoteAddr, dialTimeout)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"dst": config.RemoteAddr,
}).Error("TCP error")
return core.ConnectResultFailed, err.Error(), nil
}
return core.ConnectResultSuccess, "", conn
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
logrus.WithFields(logrus.Fields{
"error": err,
"username": username,
"src": addr.String(),
"id": id,
}).Debug("Stream closed")
},
)
if err != nil {
logrus.WithField("error", err).Fatal("Server initialization failed")
}
defer server.Close()
logrus.WithField("addr", config.ListenAddr).Info("Server up and running")
err = server.Serve()
logrus.WithField("error", err).Fatal("Server shutdown")
}

99
cmd/server.go Normal file
View File

@@ -0,0 +1,99 @@
package main
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus"
"github.com/tobyxdd/hysteria/pkg/acl"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs"
"net"
"strings"
)
func server(config *serverConfig) {
logrus.WithField("config", config.String()).Info("Server configuration loaded")
// Load cert
cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"cert": config.CertFile,
"key": config.KeyFile,
}).Fatal("Failed to load the certificate")
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{tlsProtocolName},
MinVersion: tls.VersionTLS13,
}
// QUIC config
quicConfig := &quic.Config{
MaxReceiveStreamFlowControlWindow: config.ReceiveWindowConn,
MaxReceiveConnectionFlowControlWindow: config.ReceiveWindowClient,
MaxIncomingStreams: int64(config.MaxConnClient),
KeepAlive: true,
}
if quicConfig.MaxReceiveStreamFlowControlWindow == 0 {
quicConfig.MaxReceiveStreamFlowControlWindow = DefaultMaxReceiveStreamFlowControlWindow
}
if quicConfig.MaxReceiveConnectionFlowControlWindow == 0 {
quicConfig.MaxReceiveConnectionFlowControlWindow = DefaultMaxReceiveConnectionFlowControlWindow
}
if quicConfig.MaxIncomingStreams == 0 {
quicConfig.MaxIncomingStreams = DefaultMaxIncomingStreams
}
// Auth
var authFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
if len(config.Auth.Mode) == 0 || strings.EqualFold(config.Auth.Mode, "none") {
logrus.Warn("No authentication configured")
authFunc = func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) {
return true, "Welcome"
}
} else {
// TODO
logrus.WithField("mode", config.Auth.Mode).Fatal("Unsupported authentication mode")
}
// Obfuscator
var obfuscator core.Obfuscator
if len(config.Obfs) > 0 {
obfuscator = obfs.XORObfuscator(config.Obfs)
}
// ACL
var aclEngine *acl.Engine
if len(config.ACL) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACL)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
"file": config.ACL,
}).Fatal("Failed to parse ACL")
}
aclEngine.DefaultAction = acl.ActionDirect
}
// Server
server, err := core.NewServer(config.Listen, tlsConfig, quicConfig,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
}, aclEngine, obfuscator, authFunc, func(addr net.Addr, auth []byte, udp bool, reqAddr string) {
if !udp {
logrus.WithFields(logrus.Fields{
"src": addr.String(),
"dst": reqAddr,
}).Debug("New TCP request")
} else {
// TODO
}
})
if err != nil {
logrus.WithField("error", err).Fatal("Failed to initialize server")
}
defer server.Close()
logrus.WithField("addr", config.Listen).Info("Server up and running")
err = server.Serve()
logrus.WithField("error", err).Fatal("Server shutdown")
}

View File

@@ -1,28 +0,0 @@
package main
import (
"strconv"
)
type optionalBoolFlag struct {
Exists bool
Value bool
}
func (flag *optionalBoolFlag) String() string {
return strconv.FormatBool(flag.Value)
}
func (flag *optionalBoolFlag) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
return err
}
flag.Exists = true
flag.Value = v
return nil
}
func (flag *optionalBoolFlag) IsBoolFlag() bool {
return true
}

2
go.mod
View File

@@ -8,11 +8,13 @@ require (
github.com/golang/protobuf v1.4.2
github.com/hashicorp/golang-lru v0.5.4
github.com/lucas-clemente/quic-go v0.19.3
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/sirupsen/logrus v1.6.0
github.com/txthinking/runnergroup v0.0.0-20200327135940-540a793bb997 // indirect
github.com/txthinking/socks5 v0.0.0-20200327133705-caf148ab5e9d
github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 // indirect
github.com/yosuke-furukawa/json5 v0.1.1
)
replace github.com/lucas-clemente/quic-go => github.com/tobyxdd/quic-go v0.19.4-0.20210127052624-0ecb862c82b5

4
go.sum
View File

@@ -87,6 +87,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 h1:EnfXoSqDfSNJv0VBNqY/88RNnhSGYkrHaO0mmFGbVsc=
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40/go.mod h1:vy1vK6wD6j7xX6O6hXe621WabdtNkou2h7uRtTfRMyg=
github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI=
github.com/mailru/easyjson v0.0.0-20190312143242-1de009706dbe/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/marten-seemann/qpack v0.2.1/go.mod h1:F7Gl5L1jIgN1D11ucXefiuJS9UMVP2opoCp2jDKb7wc=
@@ -166,6 +168,8 @@ github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9 h1:ngJOce33YJJT1PFTfC
github.com/txthinking/x v0.0.0-20200330144832-5ad2416896a9/go.mod h1:WgqbSEmUYSjEV3B1qmee/PpP2NYEz4bL9/+mF1ma+s4=
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
github.com/yosuke-furukawa/json5 v0.1.1 h1:0F9mNwTvOuDNH243hoPqvf+dxa5QsKnZzU20uNsh3ZI=
github.com/yosuke-furukawa/json5 v0.1.1/go.mod h1:sw49aWDqNdRJ6DYUtIQiaA3xyj2IL9tjeNYmX2ixwcU=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE=

View File

@@ -6,44 +6,45 @@ import (
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"github.com/tobyxdd/hysteria/pkg/utils"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/lunixbochs/struc"
"net"
"sync"
"sync/atomic"
"time"
)
var (
ErrClosed = errors.New("client closed")
)
type Client struct {
inboundBytes, outboundBytes uint64 // atomic
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
reconnectMutex sync.Mutex
closed bool
quicSession quic.Session
serverAddr string
username, password string
tlsConfig *tls.Config
quicConfig *quic.Config
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
obfuscator Obfuscator
type Client struct {
serverAddr string
sendBPS, recvBPS uint64
auth []byte
congestionFactory CongestionFactory
obfuscator Obfuscator
tlsConfig *tls.Config
quicConfig *quic.Config
quicSession quic.Session
reconnectMutex sync.Mutex
closed bool
}
func NewClient(serverAddr string, username string, password string, tlsConfig *tls.Config, quicConfig *quic.Config,
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) {
c := &Client{
serverAddr: serverAddr,
username: username,
password: password,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
sendBPS: sendBPS,
recvBPS: recvBPS,
auth: auth,
congestionFactory: congestionFactory,
obfuscator: obfuscator,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
}
if err := c.connectToServer(); err != nil {
return nil, err
@@ -51,58 +52,6 @@ func NewClient(serverAddr string, username string, password string, tlsConfig *t
return c, nil
}
func (c *Client) Dial(packet bool, addr string) (net.Conn, error) {
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
// Send request
req := &pb.ClientConnectRequest{Address: addr}
if packet {
req.Type = pb.ConnectionType_Packet
} else {
req.Type = pb.ConnectionType_Stream
}
err = writeClientConnectRequest(stream, req)
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
resp, err := readServerConnectResponse(stream)
if err != nil {
_ = stream.Close()
return nil, err
}
if resp.Result != pb.ConnectResult_CONN_SUCCESS {
_ = stream.Close()
return nil, fmt.Errorf("server rejected the connection %s (msg: %s)",
resp.Result.String(), resp.Message)
}
connWrap := &utils.QUICStreamWrapperConn{
Orig: stream,
PseudoLocalAddr: localAddr,
PseudoRemoteAddr: remoteAddr,
}
if packet {
return &utils.PacketWrapperConn{Orig: connWrap}, nil
} else {
return connWrap, nil
}
}
func (c *Client) Stats() (uint64, uint64) {
return atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes)
}
func (c *Client) Close() error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic")
c.closed = true
return err
}
func (c *Client) connectToServer() error {
serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
if err != nil {
@@ -124,51 +73,50 @@ func (c *Client) connectToServer() error {
return err
}
// Control stream
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
ctlStream, err := qs.OpenStreamSync(ctx)
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
stream, err := qs.OpenStreamSync(ctx)
ctxCancel()
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return err
}
result, msg, err := c.handleControlStream(qs, ctlStream)
ok, msg, err := c.handleControlStream(qs, stream)
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return err
}
if result != pb.AuthResult_AUTH_SUCCESS {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure")
return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg)
if !ok {
_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
return fmt.Errorf("auth error: %s", msg)
}
// All good
c.quicSession = qs
return nil
}
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (pb.AuthResult, string, error) {
err := writeClientAuthRequest(stream, &pb.ClientAuthRequest{
Credential: &pb.Credential{
Username: c.username,
Password: c.password,
},
Speed: &pb.Speed{
SendBps: c.sendBPS,
ReceiveBps: c.recvBPS,
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, string, error) {
// Send client hello
err := struc.Pack(stream, &clientHello{
Rate: transmissionRate{
SendBPS: c.sendBPS,
RecvBPS: c.recvBPS,
},
Auth: c.auth,
})
if err != nil {
return 0, "", err
return false, "", err
}
// Response
resp, err := readServerAuthResponse(stream)
// Receive server hello
var sh serverHello
err = struc.Unpack(stream, &sh)
if err != nil {
return 0, "", err
return false, "", err
}
// Set the congestion accordingly
if resp.Result == pb.AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
qs.SetCongestionControl(c.congestionFactory(resp.Speed.ReceiveBps))
if sh.OK && c.congestionFactory != nil {
qs.SetCongestionControl(c.congestionFactory(sh.Rate.RecvBPS))
}
return resp.Result, resp.Message, nil
return true, sh.Message, nil
}
func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) {
@@ -196,3 +144,81 @@ func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, err
stream, err = c.quicSession.OpenStream()
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err
}
func (c *Client) DialTCP(addr string) (net.Conn, error) {
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
// Send request
err = struc.Pack(stream, &clientRequest{
UDP: false,
Address: addr,
})
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
var sr serverResponse
err = struc.Unpack(stream, &sr)
if err != nil {
_ = stream.Close()
return nil, err
}
if !sr.OK {
_ = stream.Close()
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
}
return &quicConn{
Orig: stream,
PseudoLocalAddr: localAddr,
PseudoRemoteAddr: remoteAddr,
}, nil
}
func (c *Client) Close() error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
c.closed = true
return err
}
type quicConn struct {
Orig quic.Stream
PseudoLocalAddr net.Addr
PseudoRemoteAddr net.Addr
}
func (w *quicConn) Read(b []byte) (n int, err error) {
return w.Orig.Read(b)
}
func (w *quicConn) Write(b []byte) (n int, err error) {
return w.Orig.Write(b)
}
func (w *quicConn) Close() error {
return w.Orig.Close()
}
func (w *quicConn) LocalAddr() net.Addr {
return w.PseudoLocalAddr
}
func (w *quicConn) RemoteAddr() net.Addr {
return w.PseudoRemoteAddr
}
func (w *quicConn) SetDeadline(t time.Time) error {
return w.Orig.SetDeadline(t)
}
func (w *quicConn) SetReadDeadline(t time.Time) error {
return w.Orig.SetReadDeadline(t)
}
func (w *quicConn) SetWriteDeadline(t time.Time) error {
return w.Orig.SetWriteDeadline(t)
}

View File

@@ -1,104 +0,0 @@
package core
import (
"encoding/binary"
"github.com/golang/protobuf/proto"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"io"
)
const (
closeErrorCodeGeneric = 0
closeErrorCodeProtocolFailure = 1
)
func readDataBlock(r io.Reader) ([]byte, error) {
var sz uint32
if err := binary.Read(r, controlProtocolEndian, &sz); err != nil {
return nil, err
}
buf := make([]byte, sz)
_, err := io.ReadFull(r, buf)
return buf, err
}
func writeDataBlock(w io.Writer, data []byte) error {
sz := uint32(len(data))
if err := binary.Write(w, controlProtocolEndian, &sz); err != nil {
return err
}
_, err := w.Write(data)
return err
}
func readClientAuthRequest(r io.Reader) (*pb.ClientAuthRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req pb.ClientAuthRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientAuthRequest(w io.Writer, req *pb.ClientAuthRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readServerAuthResponse(r io.Reader) (*pb.ServerAuthResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp pb.ServerAuthResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerAuthResponse(w io.Writer, resp *pb.ServerAuthResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readClientConnectRequest(r io.Reader) (*pb.ClientConnectRequest, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var req pb.ClientConnectRequest
err = proto.Unmarshal(bs, &req)
return &req, err
}
func writeClientConnectRequest(w io.Writer, req *pb.ClientConnectRequest) error {
bs, err := proto.Marshal(req)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}
func readServerConnectResponse(r io.Reader) (*pb.ServerConnectResponse, error) {
bs, err := readDataBlock(r)
if err != nil {
return nil, err
}
var resp pb.ServerConnectResponse
err = proto.Unmarshal(bs, &resp)
return &resp, err
}
func writeServerConnectResponse(w io.Writer, resp *pb.ServerConnectResponse) error {
bs, err := proto.Marshal(resp)
if err != nil {
return err
}
return writeDataBlock(w, bs)
}

View File

@@ -1,440 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: control.proto
package pb
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
type AuthResult int32
const (
AuthResult_AUTH_SUCCESS AuthResult = 0
AuthResult_AUTH_INVALID_CRED AuthResult = 1
AuthResult_AUTH_INTERNAL_ERROR AuthResult = 2
)
var AuthResult_name = map[int32]string{
0: "AUTH_SUCCESS",
1: "AUTH_INVALID_CRED",
2: "AUTH_INTERNAL_ERROR",
}
var AuthResult_value = map[string]int32{
"AUTH_SUCCESS": 0,
"AUTH_INVALID_CRED": 1,
"AUTH_INTERNAL_ERROR": 2,
}
func (x AuthResult) String() string {
return proto.EnumName(AuthResult_name, int32(x))
}
func (AuthResult) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{0}
}
type ConnectionType int32
const (
ConnectionType_Stream ConnectionType = 0
ConnectionType_Packet ConnectionType = 1
)
var ConnectionType_name = map[int32]string{
0: "Stream",
1: "Packet",
}
var ConnectionType_value = map[string]int32{
"Stream": 0,
"Packet": 1,
}
func (x ConnectionType) String() string {
return proto.EnumName(ConnectionType_name, int32(x))
}
func (ConnectionType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{1}
}
type ConnectResult int32
const (
ConnectResult_CONN_SUCCESS ConnectResult = 0
ConnectResult_CONN_FAILED ConnectResult = 1
ConnectResult_CONN_BLOCKED ConnectResult = 2
)
var ConnectResult_name = map[int32]string{
0: "CONN_SUCCESS",
1: "CONN_FAILED",
2: "CONN_BLOCKED",
}
var ConnectResult_value = map[string]int32{
"CONN_SUCCESS": 0,
"CONN_FAILED": 1,
"CONN_BLOCKED": 2,
}
func (x ConnectResult) String() string {
return proto.EnumName(ConnectResult_name, int32(x))
}
func (ConnectResult) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{2}
}
type Speed struct {
SendBps uint64 `protobuf:"varint,1,opt,name=send_bps,json=sendBps,proto3" json:"send_bps,omitempty"`
ReceiveBps uint64 `protobuf:"varint,2,opt,name=receive_bps,json=receiveBps,proto3" json:"receive_bps,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Speed) Reset() { *m = Speed{} }
func (m *Speed) String() string { return proto.CompactTextString(m) }
func (*Speed) ProtoMessage() {}
func (*Speed) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{0}
}
func (m *Speed) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Speed.Unmarshal(m, b)
}
func (m *Speed) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Speed.Marshal(b, m, deterministic)
}
func (m *Speed) XXX_Merge(src proto.Message) {
xxx_messageInfo_Speed.Merge(m, src)
}
func (m *Speed) XXX_Size() int {
return xxx_messageInfo_Speed.Size(m)
}
func (m *Speed) XXX_DiscardUnknown() {
xxx_messageInfo_Speed.DiscardUnknown(m)
}
var xxx_messageInfo_Speed proto.InternalMessageInfo
func (m *Speed) GetSendBps() uint64 {
if m != nil {
return m.SendBps
}
return 0
}
func (m *Speed) GetReceiveBps() uint64 {
if m != nil {
return m.ReceiveBps
}
return 0
}
type Credential struct {
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *Credential) Reset() { *m = Credential{} }
func (m *Credential) String() string { return proto.CompactTextString(m) }
func (*Credential) ProtoMessage() {}
func (*Credential) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{1}
}
func (m *Credential) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_Credential.Unmarshal(m, b)
}
func (m *Credential) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_Credential.Marshal(b, m, deterministic)
}
func (m *Credential) XXX_Merge(src proto.Message) {
xxx_messageInfo_Credential.Merge(m, src)
}
func (m *Credential) XXX_Size() int {
return xxx_messageInfo_Credential.Size(m)
}
func (m *Credential) XXX_DiscardUnknown() {
xxx_messageInfo_Credential.DiscardUnknown(m)
}
var xxx_messageInfo_Credential proto.InternalMessageInfo
func (m *Credential) GetUsername() string {
if m != nil {
return m.Username
}
return ""
}
func (m *Credential) GetPassword() string {
if m != nil {
return m.Password
}
return ""
}
type ClientAuthRequest struct {
Credential *Credential `protobuf:"bytes,1,opt,name=credential,proto3" json:"credential,omitempty"`
Speed *Speed `protobuf:"bytes,2,opt,name=speed,proto3" json:"speed,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientAuthRequest) Reset() { *m = ClientAuthRequest{} }
func (m *ClientAuthRequest) String() string { return proto.CompactTextString(m) }
func (*ClientAuthRequest) ProtoMessage() {}
func (*ClientAuthRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{2}
}
func (m *ClientAuthRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientAuthRequest.Unmarshal(m, b)
}
func (m *ClientAuthRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientAuthRequest.Marshal(b, m, deterministic)
}
func (m *ClientAuthRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientAuthRequest.Merge(m, src)
}
func (m *ClientAuthRequest) XXX_Size() int {
return xxx_messageInfo_ClientAuthRequest.Size(m)
}
func (m *ClientAuthRequest) XXX_DiscardUnknown() {
xxx_messageInfo_ClientAuthRequest.DiscardUnknown(m)
}
var xxx_messageInfo_ClientAuthRequest proto.InternalMessageInfo
func (m *ClientAuthRequest) GetCredential() *Credential {
if m != nil {
return m.Credential
}
return nil
}
func (m *ClientAuthRequest) GetSpeed() *Speed {
if m != nil {
return m.Speed
}
return nil
}
type ServerAuthResponse struct {
Result AuthResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.AuthResult" json:"result,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
Speed *Speed `protobuf:"bytes,3,opt,name=speed,proto3" json:"speed,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerAuthResponse) Reset() { *m = ServerAuthResponse{} }
func (m *ServerAuthResponse) String() string { return proto.CompactTextString(m) }
func (*ServerAuthResponse) ProtoMessage() {}
func (*ServerAuthResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{3}
}
func (m *ServerAuthResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerAuthResponse.Unmarshal(m, b)
}
func (m *ServerAuthResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerAuthResponse.Marshal(b, m, deterministic)
}
func (m *ServerAuthResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerAuthResponse.Merge(m, src)
}
func (m *ServerAuthResponse) XXX_Size() int {
return xxx_messageInfo_ServerAuthResponse.Size(m)
}
func (m *ServerAuthResponse) XXX_DiscardUnknown() {
xxx_messageInfo_ServerAuthResponse.DiscardUnknown(m)
}
var xxx_messageInfo_ServerAuthResponse proto.InternalMessageInfo
func (m *ServerAuthResponse) GetResult() AuthResult {
if m != nil {
return m.Result
}
return AuthResult_AUTH_SUCCESS
}
func (m *ServerAuthResponse) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
func (m *ServerAuthResponse) GetSpeed() *Speed {
if m != nil {
return m.Speed
}
return nil
}
type ClientConnectRequest struct {
Type ConnectionType `protobuf:"varint,1,opt,name=type,proto3,enum=core.ConnectionType" json:"type,omitempty"`
Address string `protobuf:"bytes,2,opt,name=address,proto3" json:"address,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ClientConnectRequest) Reset() { *m = ClientConnectRequest{} }
func (m *ClientConnectRequest) String() string { return proto.CompactTextString(m) }
func (*ClientConnectRequest) ProtoMessage() {}
func (*ClientConnectRequest) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{4}
}
func (m *ClientConnectRequest) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ClientConnectRequest.Unmarshal(m, b)
}
func (m *ClientConnectRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ClientConnectRequest.Marshal(b, m, deterministic)
}
func (m *ClientConnectRequest) XXX_Merge(src proto.Message) {
xxx_messageInfo_ClientConnectRequest.Merge(m, src)
}
func (m *ClientConnectRequest) XXX_Size() int {
return xxx_messageInfo_ClientConnectRequest.Size(m)
}
func (m *ClientConnectRequest) XXX_DiscardUnknown() {
xxx_messageInfo_ClientConnectRequest.DiscardUnknown(m)
}
var xxx_messageInfo_ClientConnectRequest proto.InternalMessageInfo
func (m *ClientConnectRequest) GetType() ConnectionType {
if m != nil {
return m.Type
}
return ConnectionType_Stream
}
func (m *ClientConnectRequest) GetAddress() string {
if m != nil {
return m.Address
}
return ""
}
type ServerConnectResponse struct {
Result ConnectResult `protobuf:"varint,1,opt,name=result,proto3,enum=core.ConnectResult" json:"result,omitempty"`
Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
func (m *ServerConnectResponse) Reset() { *m = ServerConnectResponse{} }
func (m *ServerConnectResponse) String() string { return proto.CompactTextString(m) }
func (*ServerConnectResponse) ProtoMessage() {}
func (*ServerConnectResponse) Descriptor() ([]byte, []int) {
return fileDescriptor_0c5120591600887d, []int{5}
}
func (m *ServerConnectResponse) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_ServerConnectResponse.Unmarshal(m, b)
}
func (m *ServerConnectResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_ServerConnectResponse.Marshal(b, m, deterministic)
}
func (m *ServerConnectResponse) XXX_Merge(src proto.Message) {
xxx_messageInfo_ServerConnectResponse.Merge(m, src)
}
func (m *ServerConnectResponse) XXX_Size() int {
return xxx_messageInfo_ServerConnectResponse.Size(m)
}
func (m *ServerConnectResponse) XXX_DiscardUnknown() {
xxx_messageInfo_ServerConnectResponse.DiscardUnknown(m)
}
var xxx_messageInfo_ServerConnectResponse proto.InternalMessageInfo
func (m *ServerConnectResponse) GetResult() ConnectResult {
if m != nil {
return m.Result
}
return ConnectResult_CONN_SUCCESS
}
func (m *ServerConnectResponse) GetMessage() string {
if m != nil {
return m.Message
}
return ""
}
func init() {
proto.RegisterEnum("core.AuthResult", AuthResult_name, AuthResult_value)
proto.RegisterEnum("core.ConnectionType", ConnectionType_name, ConnectionType_value)
proto.RegisterEnum("core.ConnectResult", ConnectResult_name, ConnectResult_value)
proto.RegisterType((*Speed)(nil), "core.Speed")
proto.RegisterType((*Credential)(nil), "core.Credential")
proto.RegisterType((*ClientAuthRequest)(nil), "core.ClientAuthRequest")
proto.RegisterType((*ServerAuthResponse)(nil), "core.ServerAuthResponse")
proto.RegisterType((*ClientConnectRequest)(nil), "core.ClientConnectRequest")
proto.RegisterType((*ServerConnectResponse)(nil), "core.ServerConnectResponse")
}
func init() {
proto.RegisterFile("control.proto", fileDescriptor_0c5120591600887d)
}
var fileDescriptor_0c5120591600887d = []byte{
// 434 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xd1, 0x6e, 0xd3, 0x30,
0x14, 0x86, 0xd7, 0xd2, 0x75, 0xdb, 0x09, 0x1b, 0x99, 0xb7, 0x89, 0xc1, 0x0d, 0x90, 0xab, 0xaa,
0x48, 0x15, 0x1a, 0x4f, 0x90, 0x3a, 0x41, 0x54, 0x54, 0x29, 0x72, 0x3a, 0x2e, 0xb8, 0xa0, 0xca,
0x92, 0x23, 0x56, 0x91, 0xda, 0xc6, 0x76, 0x86, 0x26, 0x5e, 0x1e, 0xc5, 0x71, 0xd2, 0x16, 0x09,
0x89, 0xbb, 0x9e, 0x73, 0x7e, 0xfd, 0x9f, 0xbf, 0x2a, 0x70, 0x9a, 0x0b, 0x6e, 0x94, 0x28, 0x27,
0x52, 0x09, 0x23, 0xc8, 0x20, 0x17, 0x0a, 0x03, 0x0a, 0x87, 0xa9, 0x44, 0x2c, 0xc8, 0x0b, 0x38,
0xd6, 0xc8, 0x8b, 0xd5, 0x9d, 0xd4, 0xd7, 0xbd, 0xd7, 0xbd, 0xd1, 0x80, 0x1d, 0xd5, 0xf3, 0x54,
0x6a, 0xf2, 0x0a, 0x3c, 0x85, 0x39, 0xae, 0x1f, 0xd0, 0x5e, 0xfb, 0xf6, 0x0a, 0x6e, 0x35, 0x95,
0x3a, 0x88, 0x00, 0xa8, 0xc2, 0x02, 0xb9, 0x59, 0x67, 0x25, 0x79, 0x09, 0xc7, 0x95, 0x46, 0xc5,
0xb3, 0x0d, 0xda, 0xa6, 0x13, 0xd6, 0xcd, 0xf5, 0x4d, 0x66, 0x5a, 0xff, 0x12, 0xaa, 0xb0, 0x3d,
0x27, 0xac, 0x9b, 0x83, 0x7b, 0x38, 0xa7, 0xe5, 0x1a, 0xb9, 0x09, 0x2b, 0x73, 0xcf, 0xf0, 0x67,
0x85, 0xda, 0x90, 0x77, 0x00, 0x79, 0x57, 0x6d, 0xeb, 0xbc, 0x1b, 0x7f, 0x52, 0x3f, 0x7d, 0xb2,
0x45, 0xb2, 0x9d, 0x0c, 0x79, 0x03, 0x87, 0xba, 0x36, 0xb2, 0xfd, 0xde, 0x8d, 0xd7, 0x84, 0xad,
0x24, 0x6b, 0x2e, 0xc1, 0x6f, 0x20, 0x29, 0xaa, 0x07, 0x54, 0x0d, 0x49, 0x4b, 0xc1, 0x35, 0x92,
0x11, 0x0c, 0x15, 0xea, 0xaa, 0x34, 0x16, 0x73, 0xd6, 0x62, 0x5c, 0xa6, 0x2a, 0x0d, 0x73, 0x77,
0x72, 0x0d, 0x47, 0x1b, 0xd4, 0x3a, 0xfb, 0x8e, 0x4e, 0xa2, 0x1d, 0xb7, 0xf0, 0x27, 0xff, 0x84,
0x7f, 0x85, 0xcb, 0x46, 0x93, 0x0a, 0xce, 0x31, 0x37, 0xad, 0xe9, 0x08, 0x06, 0xe6, 0x51, 0xa2,
0x83, 0x5f, 0x3a, 0xc7, 0x26, 0xb3, 0x16, 0x7c, 0xf9, 0x28, 0x91, 0xd9, 0x44, 0x8d, 0xcf, 0x8a,
0x42, 0xa1, 0xd6, 0x2d, 0xde, 0x8d, 0xc1, 0x37, 0xb8, 0x6a, 0xc4, 0xba, 0x6e, 0xe7, 0xf6, 0xf6,
0x2f, 0xb7, 0x8b, 0xbd, 0xfa, 0xff, 0xd5, 0x1b, 0x27, 0x00, 0xdb, 0xbf, 0x83, 0xf8, 0xf0, 0x34,
0xbc, 0x5d, 0x7e, 0x5c, 0xa5, 0xb7, 0x94, 0xc6, 0x69, 0xea, 0x1f, 0x90, 0x2b, 0x38, 0xb7, 0x9b,
0x59, 0xf2, 0x25, 0x9c, 0xcf, 0xa2, 0x15, 0x65, 0x71, 0xe4, 0xf7, 0xc8, 0x73, 0xb8, 0x70, 0xeb,
0x65, 0xcc, 0x92, 0x70, 0xbe, 0x8a, 0x19, 0x5b, 0x30, 0xbf, 0x3f, 0x1e, 0xc1, 0xd9, 0xbe, 0x21,
0x01, 0x18, 0xa6, 0x46, 0x61, 0xb6, 0xf1, 0x0f, 0xea, 0xdf, 0x9f, 0xb3, 0xfc, 0x07, 0x1a, 0xbf,
0x37, 0x8e, 0xe0, 0x74, 0xef, 0xb1, 0x35, 0x9c, 0x2e, 0x92, 0x64, 0x07, 0xfe, 0x0c, 0x3c, 0xbb,
0xf9, 0x10, 0xce, 0xe6, 0x16, 0xdb, 0x46, 0xa6, 0xf3, 0x05, 0xfd, 0x14, 0x47, 0x7e, 0xff, 0x6e,
0x68, 0x3f, 0xfd, 0xf7, 0x7f, 0x02, 0x00, 0x00, 0xff, 0xff, 0xd2, 0x2f, 0x8d, 0x6f, 0x0b, 0x03,
0x00, 0x00,
}

View File

@@ -1,50 +0,0 @@
syntax = "proto3";
package core;
message Speed {
uint64 send_bps = 1;
uint64 receive_bps = 2;
}
message Credential {
string username = 1;
string password = 2;
}
enum AuthResult {
AUTH_SUCCESS = 0;
AUTH_INVALID_CRED = 1;
AUTH_INTERNAL_ERROR = 2;
}
message ClientAuthRequest {
Credential credential = 1;
Speed speed = 2;
}
message ServerAuthResponse {
AuthResult result = 1;
string message = 2;
Speed speed = 3;
}
enum ConnectionType {
Stream = 0;
Packet = 1;
}
enum ConnectResult {
CONN_SUCCESS = 0;
CONN_FAILED = 1;
CONN_BLOCKED = 2;
}
message ClientConnectRequest {
ConnectionType type = 1;
string address = 2;
}
message ServerConnectResponse {
ConnectResult result = 1;
string message = 2;
}

View File

@@ -1,3 +0,0 @@
package pb
//go:generate protoc --go_out=. control.proto

45
pkg/core/protocol.go Normal file
View File

@@ -0,0 +1,45 @@
package core
import (
"time"
)
const (
protocolVersion = uint8(1)
protocolTimeout = 10 * time.Second
closeErrorCodeGeneric = 0
closeErrorCodeProtocol = 1
closeErrorCodeAuth = 2
)
type transmissionRate struct {
SendBPS uint64
RecvBPS uint64
}
type clientHello struct {
Rate transmissionRate
AuthLen uint16 `struc:"sizeof=Auth"`
Auth []byte
}
type serverHello struct {
OK bool
Rate transmissionRate
MessageLen uint16 `struc:"sizeof=Message"`
Message string
}
type clientRequest struct {
UDP bool
AddressLen uint16 `struc:"sizeof=Address"`
Address string
}
type serverResponse struct {
OK bool
UDPSessionID uint32
MessageLen uint16 `struc:"sizeof=Message"`
Message string
}

View File

@@ -4,61 +4,32 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"github.com/lunixbochs/struc"
"github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/utils"
"io"
"net"
"sync/atomic"
"time"
)
type AuthResult int32
type ConnectionType int32
type ConnectResult int32
const dialTimeout = 10 * time.Second
const (
AuthResultSuccess AuthResult = iota
AuthResultInvalidCred
AuthResultInternalError
)
const (
ConnectionTypeStream ConnectionType = iota
ConnectionTypePacket
)
const (
ConnectResultSuccess ConnectResult = iota
ConnectResultFailed
ConnectResultBlocked
)
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc func(addr net.Addr, username string, err error)
type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
type RequestClosedFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string, err error)
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
type RequestFunc func(addr net.Addr, auth []byte, udp bool, reqAddr string)
type Server struct {
inboundBytes, outboundBytes uint64 // atomic
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
authFunc AuthFunc
requestFunc RequestFunc
aclEngine *acl.Engine
listener quic.Listener
sendBPS, recvBPS uint64
congestionFactory CongestionFactory
clientAuthFunc ClientAuthFunc
clientDisconnectedFunc ClientDisconnectedFunc
handleRequestFunc HandleRequestFunc
requestClosedFunc RequestClosedFunc
listener quic.Listener
}
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
obfuscator Obfuscator,
clientAuthFunc ClientAuthFunc,
clientDisconnectedFunc ClientDisconnectedFunc,
handleRequestFunc HandleRequestFunc,
requestClosedFunc RequestClosedFunc) (*Server, error) {
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine,
obfuscator Obfuscator, authFunc AuthFunc, requestFunc RequestFunc) (*Server, error) {
packetConn, err := net.ListenPacket("udp", addr)
if err != nil {
return nil, err
@@ -75,14 +46,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
return nil, err
}
s := &Server{
listener: listener,
sendBPS: sendBPS,
recvBPS: recvBPS,
congestionFactory: congestionFactory,
clientAuthFunc: clientAuthFunc,
clientDisconnectedFunc: clientDisconnectedFunc,
handleRequestFunc: handleRequestFunc,
requestClosedFunc: requestClosedFunc,
listener: listener,
sendBPS: sendBPS,
recvBPS: recvBPS,
congestionFactory: congestionFactory,
authFunc: authFunc,
requestFunc: requestFunc,
aclEngine: aclEngine,
}
return s, nil
}
@@ -97,128 +67,156 @@ func (s *Server) Serve() error {
}
}
func (s *Server) Stats() (uint64, uint64) {
return atomic.LoadUint64(&s.inboundBytes), atomic.LoadUint64(&s.outboundBytes)
}
func (s *Server) Close() error {
return s.listener.Close()
}
func (s *Server) handleClient(cs quic.Session) {
// Expect the client to create a control stream to send its own information
ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout)
ctlStream, err := cs.AcceptStream(ctx)
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
stream, err := cs.AcceptStream(ctx)
ctxCancel()
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error")
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}
// Handle the control stream
username, ok, err := s.handleControlStream(cs, ctlStream)
auth, ok, err := s.handleControlStream(cs, stream)
if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return
}
if !ok {
_ = cs.CloseWithError(closeErrorCodeGeneric, "authentication failure")
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
return
}
// Start accepting streams
var closeErr error
for {
stream, err := cs.AcceptStream(context.Background())
if err != nil {
closeErr = err
break
}
go s.handleStream(cs.LocalAddr(), cs.RemoteAddr(), username, stream)
go s.handleStream(cs.RemoteAddr(), auth, stream)
}
s.clientDisconnectedFunc(cs.RemoteAddr(), username, closeErr)
_ = cs.CloseWithError(closeErrorCodeGeneric, "generic")
_ = cs.CloseWithError(closeErrorCodeGeneric, "")
}
// Auth & negotiate speed
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (string, bool, error) {
req, err := readClientAuthRequest(stream)
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) {
var ch clientHello
err := struc.Unpack(stream, &ch)
if err != nil {
return "", false, err
return nil, false, err
}
// Speed
if req.Speed == nil || req.Speed.SendBps == 0 || req.Speed.ReceiveBps == 0 {
return "", false, errors.New("incorrect speed provided by the client")
if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
return nil, false, errors.New("invalid rate from client")
}
serverSendBPS, serverReceiveBPS := req.Speed.ReceiveBps, req.Speed.SendBps
serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
serverSendBPS = s.sendBPS
}
if s.recvBPS > 0 && serverReceiveBPS > s.recvBPS {
serverReceiveBPS = s.recvBPS
if s.recvBPS > 0 && serverRecvBPS > s.recvBPS {
serverRecvBPS = s.recvBPS
}
// Auth
if req.Credential == nil {
return "", false, errors.New("incorrect credential provided by the client")
}
authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password,
serverSendBPS, serverReceiveBPS)
ok, msg := s.authFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS)
// Response
err = writeServerAuthResponse(stream, &pb.ServerAuthResponse{
Result: pb.AuthResult(authResult),
Message: msg,
Speed: &pb.Speed{
SendBps: serverSendBPS,
ReceiveBps: serverReceiveBPS,
err = struc.Pack(stream, &serverHello{
OK: ok,
Rate: transmissionRate{
SendBPS: serverSendBPS,
RecvBPS: serverRecvBPS,
},
Message: msg,
})
if err != nil {
return "", false, err
return nil, false, err
}
// Set the congestion accordingly
if authResult == AuthResultSuccess && s.congestionFactory != nil {
if ok && s.congestionFactory != nil {
cs.SetCongestionControl(s.congestionFactory(serverSendBPS))
}
return req.Credential.Username, authResult == AuthResultSuccess, nil
return ch.Auth, ok, nil
}
func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username string, stream quic.Stream) {
func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) {
defer stream.Close()
// Read request
req, err := readClientConnectRequest(stream)
var req clientRequest
err := struc.Unpack(stream, &req)
if err != nil {
return
}
// Create connection with the handler
result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address)
defer func() {
if conn != nil {
_ = conn.Close()
s.requestFunc(remoteAddr, auth, req.UDP, req.Address)
if !req.UDP {
// TCP connection
s.handleTCP(stream, req.Address)
} else {
// UDP connection
// TODO
}
}
func (s *Server) handleTCP(stream quic.Stream, reqAddr string) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "invalid address",
})
return
}
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg := acl.ActionDirect, ""
if s.aclEngine != nil {
action, arg = s.aclEngine.Lookup(host, ip)
}
var conn net.Conn // Connection to be piped
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
return
}
}()
// Send response
err = writeServerConnectResponse(stream, &pb.ServerConnectResponse{
Result: pb.ConnectResult(result),
Message: msg,
case acl.ActionBlock:
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "blocked by ACL",
})
return
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
return
}
default:
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "ACL error",
})
return
}
// So far so good if we reach here
err = struc.Pack(stream, &serverResponse{
OK: true,
})
if err != nil {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
return
}
if result != ConnectResultSuccess {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address,
fmt.Errorf("handler returned an unsuccessful state %d (msg: %s)", result, msg))
return
}
switch req.Type {
case pb.ConnectionType_Stream:
err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes)
case pb.ConnectionType_Packet:
err = utils.PipePair(&utils.PacketWrapperConn{Orig: &utils.QUICStreamWrapperConn{
Orig: stream,
PseudoLocalAddr: localAddr,
PseudoRemoteAddr: remoteAddr,
}}, conn, &s.outboundBytes, &s.inboundBytes)
default:
err = fmt.Errorf("unsupported connection type %s", req.Type.String())
}
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
_ = utils.Pipe2Way(stream, conn)
}

View File

@@ -1,13 +0,0 @@
package core
import (
"encoding/binary"
"github.com/lucas-clemente/quic-go/congestion"
"time"
)
const controlStreamTimeout = 10 * time.Second
var controlProtocolEndian = binary.BigEndian
type CongestionFactory func(refBPS uint64) congestion.CongestionControl

View File

@@ -42,9 +42,9 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
case acl.ActionDirect:
return net.Dial(network, addr)
case acl.ActionProxy:
return hyClient.Dial(false, addr)
return hyClient.DialTCP(addr)
case acl.ActionBlock:
return nil, errors.New("blocked in ACL")
return nil, errors.New("blocked by ACL")
case acl.ActionHijack:
return net.Dial(network, net.JoinHostPort(arg, port))
default:
@@ -53,7 +53,6 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
},
IdleConnTimeout: idleTimeout,
// TODO: Disable HTTP2 support? ref: https://github.com/elazarl/goproxy/issues/361
//TLSNextProto: make(map[string]func(authority string, c *tls.Conn) http.RoundTripper),
}
proxy.ConnectDial = nil
if basicAuthFunc != nil {

View File

@@ -156,12 +156,8 @@ func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error {
return s.handleTCP(c, r)
} else if r.Cmd == socks5.CmdUDP {
// UDP
if !s.DisableUDP {
return s.handleUDP(c, r)
} else {
_ = sendReply(c, socks5.RepCommandNotSupported)
return ErrUnsupportedCmd
}
_ = sendReply(c, socks5.RepCommandNotSupported)
return ErrUnsupportedCmd
} else {
_ = sendReply(c, socks5.RepCommandNotSupported)
return ErrUnsupportedCmd
@@ -193,7 +189,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
closeErr = pipePair(c, rc, s.TCPDeadline)
return nil
case acl.ActionProxy:
rc, err := s.HyClient.Dial(false, addr)
rc, err := s.HyClient.DialTCP(addr)
if err != nil {
_ = sendReply(c, socks5.RepHostUnreachable)
closeErr = err
@@ -225,132 +221,6 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
}
}
func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error {
s.NewUDPAssociateFunc(c.RemoteAddr())
var closeErr error
defer func() {
s.UDPAssociateClosedFunc(c.RemoteAddr(), closeErr)
}()
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{
IP: s.TCPAddr.IP,
Zone: s.TCPAddr.Zone,
})
if err != nil {
_ = sendReply(c, socks5.RepServerFailure)
closeErr = err
return err
}
defer udpConn.Close()
// Send UDP server addr to the client
atyp, addr, port, err := socks5.ParseAddress(udpConn.LocalAddr().String())
if err != nil {
_ = sendReply(c, socks5.RepServerFailure)
closeErr = err
return err
}
_, _ = socks5.NewReply(socks5.RepSuccess, atyp, addr, port).WriteTo(c)
// Let UDP server do its job, we hold the TCP connection here
go s.udpServer(udpConn)
buf := make([]byte, 1024)
for {
if s.TCPDeadline != 0 {
_ = c.SetDeadline(time.Now().Add(time.Duration(s.TCPDeadline) * time.Second))
}
_, err := c.Read(buf)
if err != nil {
closeErr = err
break
}
}
// As the TCP connection closes, so does the UDP listener
return nil
}
func (s *Server) udpServer(c *net.UDPConn) {
var clientAddr *net.UDPAddr
remoteMap := make(map[string]io.ReadWriteCloser) // Remote addr <-> Remote conn
buf := make([]byte, utils.PipeBufferSize)
var closeErr error
for {
n, caddr, err := c.ReadFromUDP(buf)
if err != nil {
closeErr = err
break
}
d, err := socks5.NewDatagramFromBytes(buf[:n])
if err != nil || d.Frag != 0 {
// Ignore bad packets
continue
}
if clientAddr == nil {
// Whoever sends the first valid packet is our client :P
clientAddr = caddr
} else if caddr.String() != clientAddr.String() {
// We already have a client and you're not it!
continue
}
domain, ip, port, addr := parseDatagramRequestAddress(d)
rc := remoteMap[addr]
if rc == nil {
// Need a new entry
action, arg := acl.ActionProxy, ""
if s.ACLEngine != nil {
action, arg = s.ACLEngine.Lookup(domain, ip)
}
s.NewUDPTunnelFunc(clientAddr, addr, action, arg)
// Handle according to the action
switch action {
case acl.ActionDirect:
rc, err = net.Dial("udp", addr)
if err != nil {
s.UDPTunnelClosedFunc(clientAddr, addr, err)
continue
}
// The other direction
go udpReversePipe(clientAddr, c, rc)
remoteMap[addr] = rc
case acl.ActionProxy:
rc, err = s.HyClient.Dial(true, addr)
if err != nil {
s.UDPTunnelClosedFunc(clientAddr, addr, err)
continue
}
// The other direction
go udpReversePipe(clientAddr, c, rc)
remoteMap[addr] = rc
case acl.ActionBlock:
s.UDPTunnelClosedFunc(clientAddr, addr, errors.New("blocked in ACL"))
continue
case acl.ActionHijack:
rc, err = net.Dial("udp", net.JoinHostPort(arg, port))
if err != nil {
s.UDPTunnelClosedFunc(clientAddr, addr, err)
continue
}
// The other direction
go udpReversePipe(clientAddr, c, rc)
remoteMap[addr] = rc
default:
s.UDPTunnelClosedFunc(clientAddr, addr, fmt.Errorf("unknown action %d", action))
continue
}
}
_, err = rc.Write(d.Data)
if err != nil {
// The connection is no longer valid, close & remove from map
_ = rc.Close()
delete(remoteMap, addr)
s.UDPTunnelClosedFunc(clientAddr, addr, err)
}
}
// Close all remote connections
for raddr, rc := range remoteMap {
_ = rc.Close()
s.UDPTunnelClosedFunc(clientAddr, raddr, closeErr)
}
}
func sendReply(conn *net.TCPConn, rep byte) error {
p := socks5.NewReply(rep, socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00})
_, err := p.WriteTo(conn)
@@ -367,24 +237,15 @@ func parseRequestAddress(r *socks5.Request) (domain string, ip net.IP, port stri
}
}
func parseDatagramRequestAddress(r *socks5.Datagram) (domain string, ip net.IP, port string, addr string) {
p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort)))
if r.Atyp == socks5.ATYPDomain {
d := string(r.DstAddr[1:])
return d, nil, p, net.JoinHostPort(d, p)
} else {
return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p)
}
}
func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, deadline int) error {
deadlineDuration := time.Duration(deadline) * time.Second
errChan := make(chan error, 2)
// TCP to stream
go func() {
buf := make([]byte, utils.PipeBufferSize)
for {
if deadline != 0 {
_ = conn.SetDeadline(time.Now().Add(time.Duration(deadline) * time.Second))
_ = conn.SetDeadline(time.Now().Add(deadlineDuration))
}
rn, err := conn.Read(buf)
if rn > 0 {
@@ -402,22 +263,7 @@ func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, deadline int) error
}()
// Stream to TCP
go func() {
errChan <- utils.Pipe(stream, conn, nil)
errChan <- utils.Pipe(stream, conn)
}()
return <-errChan
}
func udpReversePipe(clientAddr *net.UDPAddr, c *net.UDPConn, rc io.ReadWriteCloser) {
buf := make([]byte, utils.PipeBufferSize)
for {
n, err := rc.Read(buf)
if err != nil {
break
}
d := socks5.NewDatagram(socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}, buf[:n])
_, err = c.WriteTo(d.Bytes(), clientAddr)
if err != nil {
break
}
}
}

View File

@@ -1,96 +0,0 @@
package utils
import (
"encoding/binary"
"fmt"
"github.com/lucas-clemente/quic-go"
"io"
"net"
"time"
)
type PacketWrapperConn struct {
Orig net.Conn
}
func (w *PacketWrapperConn) Read(b []byte) (n int, err error) {
var sz uint32
if err := binary.Read(w.Orig, binary.BigEndian, &sz); err != nil {
return 0, err
}
if int(sz) <= len(b) {
return io.ReadFull(w.Orig, b[:sz])
} else {
return 0, fmt.Errorf("the buffer is too small to hold %d bytes of packet data", sz)
}
}
func (w *PacketWrapperConn) Write(b []byte) (n int, err error) {
sz := uint32(len(b))
if err := binary.Write(w.Orig, binary.BigEndian, &sz); err != nil {
return 0, err
}
return w.Orig.Write(b)
}
func (w *PacketWrapperConn) Close() error {
return w.Orig.Close()
}
func (w *PacketWrapperConn) LocalAddr() net.Addr {
return w.Orig.LocalAddr()
}
func (w *PacketWrapperConn) RemoteAddr() net.Addr {
return w.Orig.RemoteAddr()
}
func (w *PacketWrapperConn) SetDeadline(t time.Time) error {
return w.Orig.SetDeadline(t)
}
func (w *PacketWrapperConn) SetReadDeadline(t time.Time) error {
return w.Orig.SetReadDeadline(t)
}
func (w *PacketWrapperConn) SetWriteDeadline(t time.Time) error {
return w.Orig.SetWriteDeadline(t)
}
type QUICStreamWrapperConn struct {
Orig quic.Stream
PseudoLocalAddr net.Addr
PseudoRemoteAddr net.Addr
}
func (w *QUICStreamWrapperConn) Read(b []byte) (n int, err error) {
return w.Orig.Read(b)
}
func (w *QUICStreamWrapperConn) Write(b []byte) (n int, err error) {
return w.Orig.Write(b)
}
func (w *QUICStreamWrapperConn) Close() error {
return w.Orig.Close()
}
func (w *QUICStreamWrapperConn) LocalAddr() net.Addr {
return w.PseudoLocalAddr
}
func (w *QUICStreamWrapperConn) RemoteAddr() net.Addr {
return w.PseudoRemoteAddr
}
func (w *QUICStreamWrapperConn) SetDeadline(t time.Time) error {
return w.Orig.SetDeadline(t)
}
func (w *QUICStreamWrapperConn) SetReadDeadline(t time.Time) error {
return w.Orig.SetReadDeadline(t)
}
func (w *QUICStreamWrapperConn) SetWriteDeadline(t time.Time) error {
return w.Orig.SetWriteDeadline(t)
}

View File

@@ -2,20 +2,16 @@ package utils
import (
"io"
"sync/atomic"
)
const PipeBufferSize = 65536
func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error {
func Pipe(src, dst io.ReadWriter) error {
buf := make([]byte, PipeBufferSize)
for {
rn, err := src.Read(buf)
if rn > 0 {
wn, err := dst.Write(buf[:rn])
if atomicCounter != nil {
atomic.AddUint64(atomicCounter, uint64(wn))
}
_, err := dst.Write(buf[:rn])
if err != nil {
return err
}
@@ -26,13 +22,13 @@ func Pipe(src, dst io.ReadWriter, atomicCounter *uint64) error {
}
}
func PipePair(rw1, rw2 io.ReadWriter, rw1WriteCounter, rw2WriteCounter *uint64) error {
func Pipe2Way(rw1, rw2 io.ReadWriter) error {
errChan := make(chan error, 2)
go func() {
errChan <- Pipe(rw2, rw1, rw1WriteCounter)
errChan <- Pipe(rw2, rw1)
}()
go func() {
errChan <- Pipe(rw1, rw2, rw2WriteCounter)
errChan <- Pipe(rw1, rw2)
}()
// We only need the first error
return <-errChan