Simplify code

This commit is contained in:
Toby 2020-10-02 18:23:47 -07:00
parent 2df70dafca
commit 05a34f8f92
16 changed files with 99 additions and 146 deletions

View File

@ -92,7 +92,7 @@ func proxyServer(args []string) {
"up": sSend / mbpsToBps, "up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps, "down": sRecv / mbpsToBps,
}).Info("Client connected") }).Info("Client connected")
return core.AuthSuccess, "" return core.AuthResultSuccess, ""
} else { } else {
// Need auth // Need auth
ok, err := checkAuth(config.AuthFile, username, password) ok, err := checkAuth(config.AuthFile, username, password)
@ -102,7 +102,7 @@ func proxyServer(args []string) {
"addr": addr.String(), "addr": addr.String(),
"username": username, "username": username,
}).Error("Client authentication error") }).Error("Client authentication error")
return core.AuthInternalError, "Server auth error" return core.AuthResultInternalError, "Server auth error"
} }
if ok { if ok {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -111,7 +111,7 @@ func proxyServer(args []string) {
"up": sSend / mbpsToBps, "up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps, "down": sRecv / mbpsToBps,
}).Info("Client authenticated") }).Info("Client authenticated")
return core.AuthSuccess, "" return core.AuthResultSuccess, ""
} else { } else {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"addr": addr.String(), "addr": addr.String(),
@ -119,7 +119,7 @@ func proxyServer(args []string) {
"up": sSend / mbpsToBps, "up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps, "down": sRecv / mbpsToBps,
}).Info("Client rejected due to invalid credential") }).Info("Client rejected due to invalid credential")
return core.AuthInvalidCred, "Invalid credential" return core.AuthResultInvalidCred, "Invalid credential"
} }
} }
}, },
@ -130,13 +130,14 @@ func proxyServer(args []string) {
"username": username, "username": username,
}).Info("Client disconnected") }).Info("Client disconnected")
}, },
func(addr net.Addr, username string, id int, packet bool, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) { func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
packet := reqType == core.ConnectionTypePacket
if packet && config.DisableUDP { if packet && config.DisableUDP {
return core.ConnBlocked, "UDP disabled", nil return core.ConnectResultBlocked, "UDP disabled", nil
} }
host, port, err := net.SplitHostPort(reqAddr) host, port, err := net.SplitHostPort(reqAddr)
if err != nil { if err != nil {
return core.ConnFailed, err.Error(), nil return core.ConnectResultFailed, err.Error(), nil
} }
ip := net.ParseIP(host) ip := net.ParseIP(host)
if ip != nil { if ip != nil {
@ -163,9 +164,9 @@ func proxyServer(args []string) {
"error": err, "error": err,
"dst": reqAddr, "dst": reqAddr,
}).Error("TCP error") }).Error("TCP error")
return core.ConnFailed, err.Error(), nil return core.ConnectResultFailed, err.Error(), nil
} }
return core.ConnSuccess, "", conn return core.ConnectResultSuccess, "", conn
} else { } else {
// UDP // UDP
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -180,9 +181,9 @@ func proxyServer(args []string) {
"error": err, "error": err,
"dst": reqAddr, "dst": reqAddr,
}).Error("UDP error") }).Error("UDP error")
return core.ConnFailed, err.Error(), nil return core.ConnectResultFailed, err.Error(), nil
} }
return core.ConnSuccess, "", conn return core.ConnectResultSuccess, "", conn
} }
case acl.ActionBlock: case acl.ActionBlock:
if !packet { if !packet {
@ -193,7 +194,7 @@ func proxyServer(args []string) {
"src": addr.String(), "src": addr.String(),
"dst": reqAddr, "dst": reqAddr,
}).Debug("New TCP request") }).Debug("New TCP request")
return core.ConnBlocked, "blocked by ACL", nil return core.ConnectResultBlocked, "blocked by ACL", nil
} else { } else {
// UDP // UDP
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -202,7 +203,7 @@ func proxyServer(args []string) {
"src": addr.String(), "src": addr.String(),
"dst": reqAddr, "dst": reqAddr,
}).Debug("New UDP request") }).Debug("New UDP request")
return core.ConnBlocked, "blocked by ACL", nil return core.ConnectResultBlocked, "blocked by ACL", nil
} }
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port) hijackAddr := net.JoinHostPort(arg, port)
@ -221,9 +222,9 @@ func proxyServer(args []string) {
"error": err, "error": err,
"dst": hijackAddr, "dst": hijackAddr,
}).Error("TCP error") }).Error("TCP error")
return core.ConnFailed, err.Error(), nil return core.ConnectResultFailed, err.Error(), nil
} }
return core.ConnSuccess, "", conn return core.ConnectResultSuccess, "", conn
} else { } else {
// UDP // UDP
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -239,15 +240,16 @@ func proxyServer(args []string) {
"error": err, "error": err,
"dst": hijackAddr, "dst": hijackAddr,
}).Error("UDP error") }).Error("UDP error")
return core.ConnFailed, err.Error(), nil return core.ConnectResultFailed, err.Error(), nil
} }
return core.ConnSuccess, "", conn return core.ConnectResultSuccess, "", conn
} }
default: default:
return core.ConnFailed, "server ACL error", nil return core.ConnectResultFailed, "server ACL error", nil
} }
}, },
func(addr net.Addr, username string, id int, packet bool, reqAddr string, err error) { func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
packet := reqType == core.ConnectionTypePacket
if !packet { if !packet {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"error": err, "error": err,

View File

@ -6,10 +6,10 @@ import (
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/tobyxdd/hysteria/internal/utils"
hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion"
"github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/obfs" "github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/utils"
"io/ioutil" "io/ioutil"
"net" "net"
"os/user" "os/user"
@ -99,7 +99,7 @@ func relayClient(args []string) {
} }
} }
func relayClientHandleConn(conn net.Conn, client core.Client) { func relayClientHandleConn(conn net.Conn, client *core.Client) {
logrus.WithField("src", conn.RemoteAddr().String()).Debug("New connection") logrus.WithField("src", conn.RemoteAddr().String()).Debug("New connection")
var closeErr error var closeErr error
defer func() { defer func() {

View File

@ -72,7 +72,7 @@ func relayServer(args []string) {
"up": sSend / mbpsToBps, "up": sSend / mbpsToBps,
"down": sRecv / mbpsToBps, "down": sRecv / mbpsToBps,
}).Info("Client connected") }).Info("Client connected")
return core.AuthSuccess, "" return core.AuthResultSuccess, ""
}, },
func(addr net.Addr, username string, err error) { func(addr net.Addr, username string, err error) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -81,14 +81,15 @@ func relayServer(args []string) {
"username": username, "username": username,
}).Info("Client disconnected") }).Info("Client disconnected")
}, },
func(addr net.Addr, username string, id int, packet bool, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) { func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
packet := reqType == core.ConnectionTypePacket
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"username": username, "username": username,
"src": addr.String(), "src": addr.String(),
"id": id, "id": id,
}).Debug("New stream") }).Debug("New stream")
if packet { if packet {
return core.ConnBlocked, "unsupported", nil return core.ConnectResultBlocked, "unsupported", nil
} }
conn, err := net.DialTimeout("tcp", config.RemoteAddr, dialTimeout) conn, err := net.DialTimeout("tcp", config.RemoteAddr, dialTimeout)
if err != nil { if err != nil {
@ -96,11 +97,11 @@ func relayServer(args []string) {
"error": err, "error": err,
"dst": config.RemoteAddr, "dst": config.RemoteAddr,
}).Error("TCP error") }).Error("TCP error")
return core.ConnFailed, err.Error(), nil return core.ConnectResultFailed, err.Error(), nil
} }
return core.ConnSuccess, "", conn return core.ConnectResultSuccess, "", conn
}, },
func(addr net.Addr, username string, id int, packet bool, reqAddr string, err error) { func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"error": err, "error": err,
"username": username, "username": username,

View File

@ -6,7 +6,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils" "github.com/tobyxdd/hysteria/pkg/core/pb"
"github.com/tobyxdd/hysteria/pkg/utils"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -56,11 +57,11 @@ func (c *Client) Dial(packet bool, addr string) (net.Conn, error) {
return nil, err return nil, err
} }
// Send request // Send request
req := &ClientConnectRequest{Address: addr} req := &pb.ClientConnectRequest{Address: addr}
if packet { if packet {
req.Type = ConnectionType_Packet req.Type = pb.ConnectionType_Packet
} else { } else {
req.Type = ConnectionType_Stream req.Type = pb.ConnectionType_Stream
} }
err = writeClientConnectRequest(stream, req) err = writeClientConnectRequest(stream, req)
if err != nil { if err != nil {
@ -73,7 +74,7 @@ func (c *Client) Dial(packet bool, addr string) (net.Conn, error) {
_ = stream.Close() _ = stream.Close()
return nil, err return nil, err
} }
if resp.Result != ConnectResult_CONN_SUCCESS { if resp.Result != pb.ConnectResult_CONN_SUCCESS {
_ = stream.Close() _ = stream.Close()
return nil, fmt.Errorf("server rejected the connection %s (msg: %s)", return nil, fmt.Errorf("server rejected the connection %s (msg: %s)",
resp.Result.String(), resp.Message) resp.Result.String(), resp.Message)
@ -135,7 +136,7 @@ func (c *Client) connectToServer() error {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error")
return err return err
} }
if result != AuthResult_AUTH_SUCCESS { if result != pb.AuthResult_AUTH_SUCCESS {
_ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure") _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "authentication failure")
return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg) return fmt.Errorf("authentication failure %s (msg: %s)", result.String(), msg)
} }
@ -144,13 +145,13 @@ func (c *Client) connectToServer() error {
return nil return nil
} }
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthResult, string, error) { func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (pb.AuthResult, string, error) {
err := writeClientAuthRequest(stream, &ClientAuthRequest{ err := writeClientAuthRequest(stream, &pb.ClientAuthRequest{
Credential: &Credential{ Credential: &pb.Credential{
Username: c.username, Username: c.username,
Password: c.password, Password: c.password,
}, },
Speed: &Speed{ Speed: &pb.Speed{
SendBps: c.sendBPS, SendBps: c.sendBPS,
ReceiveBps: c.recvBPS, ReceiveBps: c.recvBPS,
}, },
@ -164,7 +165,7 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (AuthR
return 0, "", err return 0, "", err
} }
// Set the congestion accordingly // Set the congestion accordingly
if resp.Result == AuthResult_AUTH_SUCCESS && c.congestionFactory != nil { if resp.Result == pb.AuthResult_AUTH_SUCCESS && c.congestionFactory != nil {
qs.SetCongestion(c.congestionFactory(resp.Speed.ReceiveBps)) qs.SetCongestion(c.congestionFactory(resp.Speed.ReceiveBps))
} }
return resp.Result, resp.Message, nil return resp.Result, resp.Message, nil

View File

@ -3,6 +3,7 @@ package core
import ( import (
"encoding/binary" "encoding/binary"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/tobyxdd/hysteria/pkg/core/pb"
"io" "io"
) )
@ -30,17 +31,17 @@ func writeDataBlock(w io.Writer, data []byte) error {
return err return err
} }
func readClientAuthRequest(r io.Reader) (*ClientAuthRequest, error) { func readClientAuthRequest(r io.Reader) (*pb.ClientAuthRequest, error) {
bs, err := readDataBlock(r) bs, err := readDataBlock(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var req ClientAuthRequest var req pb.ClientAuthRequest
err = proto.Unmarshal(bs, &req) err = proto.Unmarshal(bs, &req)
return &req, err return &req, err
} }
func writeClientAuthRequest(w io.Writer, req *ClientAuthRequest) error { func writeClientAuthRequest(w io.Writer, req *pb.ClientAuthRequest) error {
bs, err := proto.Marshal(req) bs, err := proto.Marshal(req)
if err != nil { if err != nil {
return err return err
@ -48,17 +49,17 @@ func writeClientAuthRequest(w io.Writer, req *ClientAuthRequest) error {
return writeDataBlock(w, bs) return writeDataBlock(w, bs)
} }
func readServerAuthResponse(r io.Reader) (*ServerAuthResponse, error) { func readServerAuthResponse(r io.Reader) (*pb.ServerAuthResponse, error) {
bs, err := readDataBlock(r) bs, err := readDataBlock(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var resp ServerAuthResponse var resp pb.ServerAuthResponse
err = proto.Unmarshal(bs, &resp) err = proto.Unmarshal(bs, &resp)
return &resp, err return &resp, err
} }
func writeServerAuthResponse(w io.Writer, resp *ServerAuthResponse) error { func writeServerAuthResponse(w io.Writer, resp *pb.ServerAuthResponse) error {
bs, err := proto.Marshal(resp) bs, err := proto.Marshal(resp)
if err != nil { if err != nil {
return err return err
@ -66,17 +67,17 @@ func writeServerAuthResponse(w io.Writer, resp *ServerAuthResponse) error {
return writeDataBlock(w, bs) return writeDataBlock(w, bs)
} }
func readClientConnectRequest(r io.Reader) (*ClientConnectRequest, error) { func readClientConnectRequest(r io.Reader) (*pb.ClientConnectRequest, error) {
bs, err := readDataBlock(r) bs, err := readDataBlock(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var req ClientConnectRequest var req pb.ClientConnectRequest
err = proto.Unmarshal(bs, &req) err = proto.Unmarshal(bs, &req)
return &req, err return &req, err
} }
func writeClientConnectRequest(w io.Writer, req *ClientConnectRequest) error { func writeClientConnectRequest(w io.Writer, req *pb.ClientConnectRequest) error {
bs, err := proto.Marshal(req) bs, err := proto.Marshal(req)
if err != nil { if err != nil {
return err return err
@ -84,17 +85,17 @@ func writeClientConnectRequest(w io.Writer, req *ClientConnectRequest) error {
return writeDataBlock(w, bs) return writeDataBlock(w, bs)
} }
func readServerConnectResponse(r io.Reader) (*ServerConnectResponse, error) { func readServerConnectResponse(r io.Reader) (*pb.ServerConnectResponse, error) {
bs, err := readDataBlock(r) bs, err := readDataBlock(r)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var resp ServerConnectResponse var resp pb.ServerConnectResponse
err = proto.Unmarshal(bs, &resp) err = proto.Unmarshal(bs, &resp)
return &resp, err return &resp, err
} }
func writeServerConnectResponse(w io.Writer, resp *ServerConnectResponse) error { func writeServerConnectResponse(w io.Writer, resp *pb.ServerConnectResponse) error {
bs, err := proto.Marshal(resp) bs, err := proto.Marshal(resp)
if err != nil { if err != nil {
return err return err

View File

@ -1,74 +0,0 @@
package core
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/core"
"io"
"net"
)
type AuthResult int32
const (
AuthSuccess = AuthResult(iota)
AuthInvalidCred
AuthInternalError
)
type ConnectResult int32
const (
ConnSuccess = ConnectResult(iota)
ConnFailed
ConnBlocked
)
type CongestionFactory core.CongestionFactory
type Obfuscator core.Obfuscator
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc core.ClientDisconnectedFunc
type HandleRequestFunc func(addr net.Addr, username string, id int, packet bool, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
type RequestClosedFunc func(addr net.Addr, username string, id int, packet bool, reqAddr string, err error)
type Server interface {
Serve() error
Stats() (inbound uint64, outbound uint64)
Close() error
}
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
obfuscator Obfuscator,
clientAuthFunc ClientAuthFunc,
clientDisconnectedFunc ClientDisconnectedFunc,
handleRequestFunc HandleRequestFunc,
requestClosedFunc RequestClosedFunc) (Server, error) {
return core.NewServer(addr, tlsConfig, quicConfig, sendBPS, recvBPS, core.CongestionFactory(congestionFactory),
core.Obfuscator(obfuscator),
func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (core.AuthResult, string) {
r, msg := clientAuthFunc(addr, username, password, sSend, sRecv)
return core.AuthResult(r), msg
},
core.ClientDisconnectedFunc(clientDisconnectedFunc),
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string) (core.ConnectResult, string, io.ReadWriteCloser) {
r, msg, conn := handleRequestFunc(addr, username, id, reqType == core.ConnectionType_Packet, reqAddr)
return core.ConnectResult(r), msg, conn
},
func(addr net.Addr, username string, id int, reqType core.ConnectionType, reqAddr string, err error) {
requestClosedFunc(addr, username, id, reqType == core.ConnectionType_Packet, reqAddr, err)
})
}
type Client interface {
Dial(packet bool, addr string) (net.Conn, error)
Stats() (inbound uint64, outbound uint64)
Close() error
}
func NewClient(serverAddr string, username string, password string,
tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS uint64, recvBPS uint64,
congestionFactory CongestionFactory, obfuscator Obfuscator) (Client, error) {
return core.NewClient(serverAddr, username, password, tlsConfig, quicConfig, sendBPS, recvBPS,
core.CongestionFactory(congestionFactory), core.Obfuscator(obfuscator))
}

View File

@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT. // Code generated by protoc-gen-go. DO NOT EDIT.
// source: control.proto // source: control.proto
package core package pb
import ( import (
fmt "fmt" fmt "fmt"

View File

@ -1,3 +1,3 @@
package core package pb
//go:generate protoc --go_out=. control.proto //go:generate protoc --go_out=. control.proto

View File

@ -6,12 +6,34 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/tobyxdd/hysteria/internal/utils" "github.com/tobyxdd/hysteria/pkg/core/pb"
"github.com/tobyxdd/hysteria/pkg/utils"
"io" "io"
"net" "net"
"sync/atomic" "sync/atomic"
) )
type AuthResult int32
type ConnectionType int32
type ConnectResult int32
const (
AuthResultSuccess AuthResult = iota
AuthResultInvalidCred
AuthResultInternalError
)
const (
ConnectionTypeStream ConnectionType = iota
ConnectionTypePacket
)
const (
ConnectResultSuccess ConnectResult = iota
ConnectResultFailed
ConnectResultBlocked
)
type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string) type ClientAuthFunc func(addr net.Addr, username string, password string, sSend uint64, sRecv uint64) (AuthResult, string)
type ClientDisconnectedFunc func(addr net.Addr, username string, err error) type ClientDisconnectedFunc func(addr net.Addr, username string, err error)
type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser) type HandleRequestFunc func(addr net.Addr, username string, id int, reqType ConnectionType, reqAddr string) (ConnectResult, string, io.ReadWriteCloser)
@ -140,10 +162,10 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (strin
authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password, authResult, msg := s.clientAuthFunc(cs.RemoteAddr(), req.Credential.Username, req.Credential.Password,
serverSendBPS, serverReceiveBPS) serverSendBPS, serverReceiveBPS)
// Response // Response
err = writeServerAuthResponse(stream, &ServerAuthResponse{ err = writeServerAuthResponse(stream, &pb.ServerAuthResponse{
Result: authResult, Result: pb.AuthResult(authResult),
Message: msg, Message: msg,
Speed: &Speed{ Speed: &pb.Speed{
SendBps: serverSendBPS, SendBps: serverSendBPS,
ReceiveBps: serverReceiveBPS, ReceiveBps: serverReceiveBPS,
}, },
@ -152,10 +174,10 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) (strin
return "", false, err return "", false, err
} }
// Set the congestion accordingly // Set the congestion accordingly
if authResult == AuthResult_AUTH_SUCCESS && s.congestionFactory != nil { if authResult == AuthResultSuccess && s.congestionFactory != nil {
cs.SetCongestion(s.congestionFactory(serverSendBPS)) cs.SetCongestion(s.congestionFactory(serverSendBPS))
} }
return req.Credential.Username, authResult == AuthResult_AUTH_SUCCESS, nil return req.Credential.Username, authResult == AuthResultSuccess, nil
} }
func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username string, stream quic.Stream) { func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username string, stream quic.Stream) {
@ -166,30 +188,30 @@ func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username
return return
} }
// Create connection with the handler // Create connection with the handler
result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address) result, msg, conn := s.handleRequestFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address)
defer func() { defer func() {
if conn != nil { if conn != nil {
_ = conn.Close() _ = conn.Close()
} }
}() }()
// Send response // Send response
err = writeServerConnectResponse(stream, &ServerConnectResponse{ err = writeServerConnectResponse(stream, &pb.ServerConnectResponse{
Result: result, Result: pb.ConnectResult(result),
Message: msg, Message: msg,
}) })
if err != nil { if err != nil {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address, err) s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
return return
} }
if result != ConnectResult_CONN_SUCCESS { if result != ConnectResultSuccess {
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address, s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address,
fmt.Errorf("handler returned an unsuccessful state %s (msg: %s)", result.String(), msg)) fmt.Errorf("handler returned an unsuccessful state %d (msg: %s)", result, msg))
return return
} }
switch req.Type { switch req.Type {
case ConnectionType_Stream: case pb.ConnectionType_Stream:
err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes) err = utils.PipePair(stream, conn, &s.outboundBytes, &s.inboundBytes)
case ConnectionType_Packet: case pb.ConnectionType_Packet:
err = utils.PipePair(&utils.PacketWrapperConn{Orig: &utils.QUICStreamWrapperConn{ err = utils.PipePair(&utils.PacketWrapperConn{Orig: &utils.QUICStreamWrapperConn{
Orig: stream, Orig: stream,
PseudoLocalAddr: localAddr, PseudoLocalAddr: localAddr,
@ -198,5 +220,5 @@ func (s *Server) handleStream(localAddr net.Addr, remoteAddr net.Addr, username
default: default:
err = fmt.Errorf("unsupported connection type %s", req.Type.String()) err = fmt.Errorf("unsupported connection type %s", req.Type.String())
} }
s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), req.Type, req.Address, err) s.requestClosedFunc(remoteAddr, username, int(stream.StreamID()), ConnectionType(req.Type), req.Address, err)
} }

View File

@ -14,7 +14,7 @@ import (
"github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/core"
) )
func NewProxyHTTPServer(hyClient core.Client, idleTimeout time.Duration, aclEngine *acl.Engine, func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEngine *acl.Engine,
newDialFunc func(reqAddr string, action acl.Action, arg string), newDialFunc func(reqAddr string, action acl.Action, arg string),
basicAuthFunc func(user, password string) bool) (*goproxy.ProxyHttpServer, error) { basicAuthFunc func(user, password string) bool) (*goproxy.ProxyHttpServer, error) {
proxy := goproxy.NewProxyHttpServer() proxy := goproxy.NewProxyHttpServer()

View File

@ -4,9 +4,9 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/tobyxdd/hysteria/internal/utils"
"github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/core"
"github.com/tobyxdd/hysteria/pkg/utils"
"io" "io"
"strconv" "strconv"
) )
@ -23,7 +23,7 @@ var (
) )
type Server struct { type Server struct {
HyClient core.Client HyClient *core.Client
AuthFunc func(username, password string) bool AuthFunc func(username, password string) bool
Method byte Method byte
TCPAddr *net.TCPAddr TCPAddr *net.TCPAddr
@ -41,7 +41,7 @@ type Server struct {
tcpListener *net.TCPListener tcpListener *net.TCPListener
} }
func NewServer(hyClient core.Client, addr string, authFunc func(username, password string) bool, tcpDeadline int, func NewServer(hyClient *core.Client, addr string, authFunc func(username, password string) bool, tcpDeadline int,
aclEngine *acl.Engine, disableUDP bool, aclEngine *acl.Engine, disableUDP bool,
newReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string), newReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string),
reqClosedFunc func(addr net.Addr, reqAddr string, err error), reqClosedFunc func(addr net.Addr, reqAddr string, err error),