Transport WIP

This commit is contained in:
Toby 2021-04-26 17:34:08 -07:00
parent b3d149a72f
commit 4da73888f4
6 changed files with 42 additions and 19 deletions

View File

@ -3,6 +3,7 @@ package acl
import ( import (
"bufio" "bufio"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/tobyxdd/hysteria/pkg/core"
"net" "net"
"os" "os"
"strings" "strings"
@ -14,6 +15,7 @@ type Engine struct {
DefaultAction Action DefaultAction Action
Entries []Entry Entries []Entry
Cache *lru.ARCCache Cache *lru.ARCCache
Transport core.Transport
} }
type cacheEntry struct { type cacheEntry struct {
@ -21,7 +23,7 @@ type cacheEntry struct {
Arg string Arg string
} }
func LoadFromFile(filename string) (*Engine, error) { func LoadFromFile(filename string, transport core.Transport) (*Engine, error) {
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -49,6 +51,7 @@ func LoadFromFile(filename string) (*Engine, error) {
DefaultAction: ActionProxy, DefaultAction: ActionProxy,
Entries: entries, Entries: entries,
Cache: cache, Cache: cache,
Transport: transport,
}, nil }, nil
} }
@ -56,7 +59,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
ip, zone := parseIPZone(host) ip, zone := parseIPZone(host)
if ip == nil { if ip == nil {
// Domain // Domain
ipAddr, err := net.ResolveIPAddr("ip", host) ipAddr, err := e.Transport.OutResolveIPAddr(host)
if v, ok := e.Cache.Get(host); ok { if v, ok := e.Cache.Get(host); ok {
// Cache hit // Cache hit
ce := v.(cacheEntry) ce := v.(cacheEntry)

View File

@ -23,6 +23,7 @@ var (
type CongestionFactory func(refBPS uint64) congestion.CongestionControl type CongestionFactory func(refBPS uint64) congestion.CongestionControl
type Client struct { type Client struct {
transport Transport
serverAddr string serverAddr string
sendBPS, recvBPS uint64 sendBPS, recvBPS uint64
auth []byte auth []byte
@ -40,9 +41,10 @@ type Client struct {
udpSessionMap map[uint32]chan *udpMessage udpSessionMap map[uint32]chan *udpMessage
} }
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, transport Transport,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) { sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) {
c := &Client{ c := &Client{
transport: transport,
serverAddr: serverAddr, serverAddr: serverAddr,
sendBPS: sendBPS, sendBPS: sendBPS,
recvBPS: recvBPS, recvBPS: recvBPS,
@ -59,11 +61,11 @@ func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig
} }
func (c *Client) connectToServer() error { func (c *Client) connectToServer() error {
serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) serverUDPAddr, err := c.transport.QUICResolveUDPAddr(c.serverAddr)
if err != nil { if err != nil {
return err return err
} }
udpConn, err := net.ListenUDP("udp", nil) udpConn, err := c.transport.QUICListenUDP(nil)
if err != nil { if err != nil {
return err return err
} }

View File

@ -20,6 +20,7 @@ type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32)
type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error) type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error)
type Server struct { type Server struct {
transport Transport
sendBPS, recvBPS uint64 sendBPS, recvBPS uint64
congestionFactory CongestionFactory congestionFactory CongestionFactory
disableUDP bool disableUDP bool
@ -36,15 +37,15 @@ type Server struct {
listener quic.Listener listener quic.Listener
} }
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, transport Transport,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc,
udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry) (*Server, error) { udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry) (*Server, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr) udpAddr, err := transport.QUICResolveUDPAddr(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
udpConn, err := net.ListenUDP("udp", udpAddr) udpConn, err := transport.QUICListenUDP(udpAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -72,6 +73,7 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
} }
s := &Server{ s := &Server{
listener: listener, listener: listener,
transport: transport,
sendBPS: sendBPS, sendBPS: sendBPS,
recvBPS: recvBPS, recvBPS: recvBPS,
congestionFactory: congestionFactory, congestionFactory: congestionFactory,
@ -129,7 +131,7 @@ func (s *Server) handleClient(cs quic.Session) {
return return
} }
// Start accepting streams and messages // Start accepting streams and messages
sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine, sc := newServerClient(cs, s.transport, auth, s.disableUDP, s.aclEngine,
s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.upCounterVec, s.downCounterVec) s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.upCounterVec, s.downCounterVec)
sc.Run() sc.Run()
_ = cs.CloseWithError(closeErrorCodeGeneric, "") _ = cs.CloseWithError(closeErrorCodeGeneric, "")

View File

@ -18,6 +18,7 @@ const udpBufferSize = 65535
type serverClient struct { type serverClient struct {
CS quic.Session CS quic.Session
Transport Transport
Auth []byte Auth []byte
ClientAddr net.Addr ClientAddr net.Addr
DisableUDP bool DisableUDP bool
@ -34,12 +35,13 @@ type serverClient struct {
nextUDPSessionID uint32 nextUDPSessionID uint32
} }
func newServerClient(cs quic.Session, auth []byte, disableUDP bool, ACLEngine *acl.Engine, func newServerClient(cs quic.Session, transport Transport, auth []byte, disableUDP bool, ACLEngine *acl.Engine,
CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc, CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc,
CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc, CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc,
UpCounterVec, DownCounterVec *prometheus.CounterVec) *serverClient { UpCounterVec, DownCounterVec *prometheus.CounterVec) *serverClient {
sc := &serverClient{ sc := &serverClient{
CS: cs, CS: cs,
Transport: transport,
Auth: auth, Auth: auth,
ClientAddr: cs.RemoteAddr(), ClientAddr: cs.RemoteAddr(),
DisableUDP: disableUDP, DisableUDP: disableUDP,
@ -118,7 +120,7 @@ func (c *serverClient) handleMessage(msg []byte) {
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host)
} else { } else {
ipAddr, err = net.ResolveIPAddr("ip", udpMsg.Host) ipAddr, err = c.Transport.OutResolveIPAddr(udpMsg.Host)
} }
if err != nil { if err != nil {
return return
@ -137,7 +139,7 @@ func (c *serverClient) handleMessage(msg []byte) {
// Do nothing // Do nothing
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port))) hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port)))
addr, err := net.ResolveUDPAddr("udp", hijackAddr) addr, err := c.Transport.OutResolveUDPAddr(hijackAddr)
if err == nil { if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, addr) _, _ = conn.WriteToUDP(udpMsg.Data, addr)
if c.UpCounter != nil { if c.UpCounter != nil {
@ -158,7 +160,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host)
} else { } else {
ipAddr, err = net.ResolveIPAddr("ip", host) ipAddr, err = c.Transport.OutResolveIPAddr(host)
} }
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
@ -173,7 +175,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
var conn net.Conn // Connection to be piped var conn net.Conn // Connection to be piped
switch action { switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ conn, err = c.Transport.OutDialTCP(nil, &net.TCPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
@ -194,7 +196,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
return return
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port))) hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port)))
conn, err = net.Dial("tcp", hijackAddr) conn, err = c.Transport.OutDial("tcp", hijackAddr)
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,
@ -234,7 +236,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
func (c *serverClient) handleUDP(stream quic.Stream) { func (c *serverClient) handleUDP(stream quic.Stream) {
// Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it // Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it
conn, err := net.ListenUDP("udp", nil) conn, err := c.Transport.OutListenUDP(nil)
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,

14
pkg/core/transport.go Normal file
View File

@ -0,0 +1,14 @@
package core
import "net"
type Transport interface {
QUICResolveUDPAddr(address string) (*net.UDPAddr, error)
QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error)
OutResolveIPAddr(address string) (*net.IPAddr, error)
OutResolveUDPAddr(address string) (*net.UDPAddr, error)
OutDial(network, address string) (net.Conn, error)
OutDialTCP(laddr, raddr *net.TCPAddr) (*net.TCPConn, error)
OutListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error)
}

View File

@ -16,7 +16,7 @@ import (
"github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/core"
) )
func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEngine *acl.Engine, func NewProxyHTTPServer(hyClient *core.Client, transport core.Transport, idleTimeout time.Duration, aclEngine *acl.Engine,
newDialFunc func(reqAddr string, action acl.Action, arg string), newDialFunc func(reqAddr string, action acl.Action, arg string),
basicAuthFunc func(user, password string) bool) (*goproxy.ProxyHttpServer, error) { basicAuthFunc func(user, password string) bool) (*goproxy.ProxyHttpServer, error) {
proxy := goproxy.NewProxyHttpServer() proxy := goproxy.NewProxyHttpServer()
@ -44,7 +44,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
if resErr != nil { if resErr != nil {
return nil, resErr return nil, resErr
} }
return net.DialTCP(network, nil, &net.TCPAddr{ return transport.OutDialTCP(nil, &net.TCPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
@ -54,7 +54,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
case acl.ActionBlock: case acl.ActionBlock:
return nil, errors.New("blocked by ACL") return nil, errors.New("blocked by ACL")
case acl.ActionHijack: case acl.ActionHijack:
return net.Dial(network, net.JoinHostPort(arg, strconv.Itoa(int(port)))) return transport.OutDial(network, net.JoinHostPort(arg, strconv.Itoa(int(port))))
default: default:
return nil, fmt.Errorf("unknown action %d", action) return nil, fmt.Errorf("unknown action %d", action)
} }