Use core.Transport

This commit is contained in:
Toby 2021-04-27 20:14:43 -07:00
parent 4da73888f4
commit 5ac95d987a
11 changed files with 120 additions and 49 deletions

View File

@ -75,7 +75,7 @@ func client(config *clientConfig) {
var aclEngine *acl.Engine var aclEngine *acl.Engine
if len(config.ACL) > 0 { if len(config.ACL) > 0 {
var err error var err error
aclEngine, err = acl.LoadFromFile(config.ACL) aclEngine, err = acl.LoadFromFile(config.ACL, core.DefaultTransport)
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"error": err, "error": err,
@ -84,7 +84,7 @@ func client(config *clientConfig) {
} }
} }
// Client // Client
client, err := core.NewClient(config.Server, auth, tlsConfig, quicConfig, client, err := core.NewClient(config.Server, auth, tlsConfig, quicConfig, core.DefaultTransport,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl { func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
@ -105,7 +105,7 @@ func client(config *clientConfig) {
return config.SOCKS5.User == user && config.SOCKS5.Password == password return config.SOCKS5.User == user && config.SOCKS5.Password == password
} }
} }
socks5server, err := socks5.NewServer(client, config.SOCKS5.Listen, authFunc, socks5server, err := socks5.NewServer(client, core.DefaultTransport, config.SOCKS5.Listen, authFunc,
time.Duration(config.SOCKS5.Timeout)*time.Second, aclEngine, config.SOCKS5.DisableUDP, time.Duration(config.SOCKS5.Timeout)*time.Second, aclEngine, config.SOCKS5.DisableUDP,
func(addr net.Addr, reqAddr string, action acl.Action, arg string) { func(addr net.Addr, reqAddr string, action acl.Action, arg string) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -161,7 +161,8 @@ func client(config *clientConfig) {
return config.HTTP.User == user && config.HTTP.Password == password return config.HTTP.User == user && config.HTTP.Password == password
} }
} }
proxy, err := hyHTTP.NewProxyHTTPServer(client, time.Duration(config.HTTP.Timeout)*time.Second, aclEngine, proxy, err := hyHTTP.NewProxyHTTPServer(client, core.DefaultTransport,
time.Duration(config.HTTP.Timeout)*time.Second, aclEngine,
func(reqAddr string, action acl.Action, arg string) { func(reqAddr string, action acl.Action, arg string) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"action": actionToString(action, arg), "action": actionToString(action, arg),
@ -184,7 +185,8 @@ func client(config *clientConfig) {
if len(config.TCPRelay.Listen) > 0 { if len(config.TCPRelay.Listen) > 0 {
go func() { go func() {
rl, err := relay.NewTCPRelay(client, config.TCPRelay.Listen, config.TCPRelay.Remote, rl, err := relay.NewTCPRelay(client, core.DefaultTransport,
config.TCPRelay.Listen, config.TCPRelay.Remote,
time.Duration(config.TCPRelay.Timeout)*time.Second, time.Duration(config.TCPRelay.Timeout)*time.Second,
func(addr net.Addr) { func(addr net.Addr) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -213,7 +215,8 @@ func client(config *clientConfig) {
if len(config.UDPRelay.Listen) > 0 { if len(config.UDPRelay.Listen) > 0 {
go func() { go func() {
rl, err := relay.NewUDPRelay(client, config.UDPRelay.Listen, config.UDPRelay.Remote, rl, err := relay.NewUDPRelay(client, core.DefaultTransport,
config.UDPRelay.Listen, config.UDPRelay.Remote,
time.Duration(config.UDPRelay.Timeout)*time.Second, time.Duration(config.UDPRelay.Timeout)*time.Second,
func(addr net.Addr) { func(addr net.Addr) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
@ -242,8 +245,8 @@ func client(config *clientConfig) {
if len(config.TCPTProxy.Listen) > 0 { if len(config.TCPTProxy.Listen) > 0 {
go func() { go func() {
rl, err := tproxy.NewTCPTProxy(client, config.TCPTProxy.Listen, rl, err := tproxy.NewTCPTProxy(client, core.DefaultTransport,
time.Duration(config.TCPTProxy.Timeout)*time.Second, aclEngine, config.TCPTProxy.Listen, time.Duration(config.TCPTProxy.Timeout)*time.Second, aclEngine,
func(addr, reqAddr net.Addr, action acl.Action, arg string) { func(addr, reqAddr net.Addr, action acl.Action, arg string) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"action": actionToString(action, arg), "action": actionToString(action, arg),
@ -275,8 +278,8 @@ func client(config *clientConfig) {
if len(config.UDPTProxy.Listen) > 0 { if len(config.UDPTProxy.Listen) > 0 {
go func() { go func() {
rl, err := tproxy.NewUDPTProxy(client, config.UDPTProxy.Listen, rl, err := tproxy.NewUDPTProxy(client, core.DefaultTransport,
time.Duration(config.UDPTProxy.Timeout)*time.Second, aclEngine, config.UDPTProxy.Listen, time.Duration(config.UDPTProxy.Timeout)*time.Second, aclEngine,
func(addr net.Addr) { func(addr net.Addr) {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"src": addr.String(), "src": addr.String(),

View File

@ -104,7 +104,7 @@ func server(config *serverConfig) {
// ACL // ACL
var aclEngine *acl.Engine var aclEngine *acl.Engine
if len(config.ACL) > 0 { if len(config.ACL) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACL) aclEngine, err = acl.LoadFromFile(config.ACL, core.DefaultTransport)
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"error": err, "error": err,
@ -123,7 +123,7 @@ func server(config *serverConfig) {
logrus.WithField("error", err).Fatal("Prometheus HTTP server error") logrus.WithField("error", err).Fatal("Prometheus HTTP server error")
}() }()
} }
server, err := core.NewServer(config.Listen, tlsConfig, quicConfig, server, err := core.NewServer(config.Listen, tlsConfig, quicConfig, core.DefaultTransport,
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl { func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))

View File

@ -59,7 +59,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
ip, zone := parseIPZone(host) ip, zone := parseIPZone(host)
if ip == nil { if ip == nil {
// Domain // Domain
ipAddr, err := e.Transport.OutResolveIPAddr(host) ipAddr, err := e.Transport.LocalResolveIPAddr(host)
if v, ok := e.Cache.Get(host); ok { if v, ok := e.Cache.Get(host); ok {
// Cache hit // Cache hit
ce := v.(cacheEntry) ce := v.(cacheEntry)

View File

@ -120,7 +120,7 @@ func (c *serverClient) handleMessage(msg []byte) {
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host)
} else { } else {
ipAddr, err = c.Transport.OutResolveIPAddr(udpMsg.Host) ipAddr, err = c.Transport.LocalResolveIPAddr(udpMsg.Host)
} }
if err != nil { if err != nil {
return return
@ -139,7 +139,7 @@ func (c *serverClient) handleMessage(msg []byte) {
// Do nothing // Do nothing
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port))) hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port)))
addr, err := c.Transport.OutResolveUDPAddr(hijackAddr) addr, err := c.Transport.LocalResolveUDPAddr(hijackAddr)
if err == nil { if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, addr) _, _ = conn.WriteToUDP(udpMsg.Data, addr)
if c.UpCounter != nil { if c.UpCounter != nil {
@ -160,7 +160,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host)
} else { } else {
ipAddr, err = c.Transport.OutResolveIPAddr(host) ipAddr, err = c.Transport.LocalResolveIPAddr(host)
} }
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
@ -175,7 +175,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
var conn net.Conn // Connection to be piped var conn net.Conn // Connection to be piped
switch action { switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = c.Transport.OutDialTCP(nil, &net.TCPAddr{ conn, err = c.Transport.LocalDialTCP(nil, &net.TCPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
@ -196,7 +196,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
return return
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port))) hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port)))
conn, err = c.Transport.OutDial("tcp", hijackAddr) conn, err = c.Transport.LocalDial("tcp", hijackAddr)
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,
@ -236,7 +236,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
func (c *serverClient) handleUDP(stream quic.Stream) { func (c *serverClient) handleUDP(stream quic.Stream) {
// Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it // Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it
conn, err := c.Transport.OutListenUDP(nil) conn, err := c.Transport.LocalListenUDP(nil)
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,

View File

@ -1,14 +1,69 @@
package core package core
import "net" import (
"net"
"time"
)
type Transport interface { type Transport interface {
QUICResolveUDPAddr(address string) (*net.UDPAddr, error) QUICResolveUDPAddr(address string) (*net.UDPAddr, error)
QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error)
OutResolveIPAddr(address string) (*net.IPAddr, error) LocalResolveIPAddr(address string) (*net.IPAddr, error)
OutResolveUDPAddr(address string) (*net.UDPAddr, error) LocalResolveTCPAddr(address string) (*net.TCPAddr, error)
OutDial(network, address string) (net.Conn, error) LocalResolveUDPAddr(address string) (*net.UDPAddr, error)
OutDialTCP(laddr, raddr *net.TCPAddr) (*net.TCPConn, error) LocalDial(network, address string) (net.Conn, error)
OutListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) LocalDialTCP(laddr, raddr *net.TCPAddr) (*net.TCPConn, error)
LocalListenTCP(laddr *net.TCPAddr) (*net.TCPListener, error)
LocalListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error)
}
var DefaultTransport Transport = &defaultTransport{
Timeout: 8 * time.Second,
}
type defaultTransport struct {
Timeout time.Duration
}
func (t *defaultTransport) QUICResolveUDPAddr(address string) (*net.UDPAddr, error) {
return net.ResolveUDPAddr("udp", address)
}
func (t *defaultTransport) QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP("udp", laddr)
}
func (t *defaultTransport) LocalResolveIPAddr(address string) (*net.IPAddr, error) {
return net.ResolveIPAddr("ip", address)
}
func (t *defaultTransport) LocalResolveTCPAddr(address string) (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", address)
}
func (t *defaultTransport) LocalResolveUDPAddr(address string) (*net.UDPAddr, error) {
return net.ResolveUDPAddr("udp", address)
}
func (t *defaultTransport) LocalDial(network, address string) (net.Conn, error) {
dialer := &net.Dialer{Timeout: t.Timeout}
return dialer.Dial(network, address)
}
func (t *defaultTransport) LocalDialTCP(laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
dialer := &net.Dialer{Timeout: t.Timeout, LocalAddr: laddr}
conn, err := dialer.Dial("tcp", raddr.String())
if err != nil {
return nil, err
}
return conn.(*net.TCPConn), nil
}
func (t *defaultTransport) LocalListenTCP(laddr *net.TCPAddr) (*net.TCPListener, error) {
return net.ListenTCP("tcp", laddr)
}
func (t *defaultTransport) LocalListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) {
return net.ListenUDP("udp", laddr)
} }

View File

@ -44,7 +44,7 @@ func NewProxyHTTPServer(hyClient *core.Client, transport core.Transport, idleTim
if resErr != nil { if resErr != nil {
return nil, resErr return nil, resErr
} }
return transport.OutDialTCP(nil, &net.TCPAddr{ return transport.LocalDialTCP(nil, &net.TCPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
@ -54,7 +54,7 @@ func NewProxyHTTPServer(hyClient *core.Client, transport core.Transport, idleTim
case acl.ActionBlock: case acl.ActionBlock:
return nil, errors.New("blocked by ACL") return nil, errors.New("blocked by ACL")
case acl.ActionHijack: case acl.ActionHijack:
return transport.OutDial(network, net.JoinHostPort(arg, strconv.Itoa(int(port)))) return transport.LocalDial(network, net.JoinHostPort(arg, strconv.Itoa(int(port))))
default: default:
return nil, fmt.Errorf("unknown action %d", action) return nil, fmt.Errorf("unknown action %d", action)
} }

View File

@ -9,6 +9,7 @@ import (
type TCPRelay struct { type TCPRelay struct {
HyClient *core.Client HyClient *core.Client
Transport core.Transport
ListenAddr *net.TCPAddr ListenAddr *net.TCPAddr
Remote string Remote string
Timeout time.Duration Timeout time.Duration
@ -17,14 +18,15 @@ type TCPRelay struct {
ErrorFunc func(addr net.Addr, err error) ErrorFunc func(addr net.Addr, err error)
} }
func NewTCPRelay(hyClient *core.Client, listen, remote string, timeout time.Duration, func NewTCPRelay(hyClient *core.Client, transport core.Transport, listen, remote string, timeout time.Duration,
connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*TCPRelay, error) { connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*TCPRelay, error) {
tAddr, err := net.ResolveTCPAddr("tcp", listen) tAddr, err := transport.LocalResolveTCPAddr(listen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := &TCPRelay{ r := &TCPRelay{
HyClient: hyClient, HyClient: hyClient,
Transport: transport,
ListenAddr: tAddr, ListenAddr: tAddr,
Remote: remote, Remote: remote,
Timeout: timeout, Timeout: timeout,
@ -35,7 +37,7 @@ func NewTCPRelay(hyClient *core.Client, listen, remote string, timeout time.Dura
} }
func (r *TCPRelay) ListenAndServe() error { func (r *TCPRelay) ListenAndServe() error {
listener, err := net.ListenTCP("tcp", r.ListenAddr) listener, err := r.Transport.LocalListenTCP(r.ListenAddr)
if err != nil { if err != nil {
return err return err
} }

View File

@ -15,6 +15,7 @@ var ErrTimeout = errors.New("inactivity timeout")
type UDPRelay struct { type UDPRelay struct {
HyClient *core.Client HyClient *core.Client
Transport core.Transport
ListenAddr *net.UDPAddr ListenAddr *net.UDPAddr
Remote string Remote string
Timeout time.Duration Timeout time.Duration
@ -23,14 +24,15 @@ type UDPRelay struct {
ErrorFunc func(addr net.Addr, err error) ErrorFunc func(addr net.Addr, err error)
} }
func NewUDPRelay(hyClient *core.Client, listen, remote string, timeout time.Duration, func NewUDPRelay(hyClient *core.Client, transport core.Transport, listen, remote string, timeout time.Duration,
connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPRelay, error) { connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPRelay, error) {
uAddr, err := net.ResolveUDPAddr("udp", listen) uAddr, err := transport.LocalResolveUDPAddr(listen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := &UDPRelay{ r := &UDPRelay{
HyClient: hyClient, HyClient: hyClient,
Transport: transport,
ListenAddr: uAddr, ListenAddr: uAddr,
Remote: remote, Remote: remote,
Timeout: timeout, Timeout: timeout,
@ -49,7 +51,7 @@ type connEntry struct {
} }
func (r *UDPRelay) ListenAndServe() error { func (r *UDPRelay) ListenAndServe() error {
conn, err := net.ListenUDP("udp", r.ListenAddr) conn, err := r.Transport.LocalListenUDP(r.ListenAddr)
if err != nil { if err != nil {
return err return err
} }

View File

@ -25,6 +25,7 @@ var (
type Server struct { type Server struct {
HyClient *core.Client HyClient *core.Client
Transport core.Transport
AuthFunc func(username, password string) bool AuthFunc func(username, password string) bool
Method byte Method byte
TCPAddr *net.TCPAddr TCPAddr *net.TCPAddr
@ -40,12 +41,13 @@ type Server struct {
tcpListener *net.TCPListener tcpListener *net.TCPListener
} }
func NewServer(hyClient *core.Client, addr string, authFunc func(username, password string) bool, tcpTimeout time.Duration, func NewServer(hyClient *core.Client, transport core.Transport, addr string,
authFunc func(username, password string) bool, tcpTimeout time.Duration,
aclEngine *acl.Engine, disableUDP bool, aclEngine *acl.Engine, disableUDP bool,
tcpReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string), tcpReqFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string),
tcpErrorFunc func(addr net.Addr, reqAddr string, err error), tcpErrorFunc func(addr net.Addr, reqAddr string, err error),
udpAssocFunc func(addr net.Addr), udpErrorFunc func(addr net.Addr, err error)) (*Server, error) { udpAssocFunc func(addr net.Addr), udpErrorFunc func(addr net.Addr, err error)) (*Server, error) {
tAddr, err := net.ResolveTCPAddr("tcp", addr) tAddr, err := transport.LocalResolveTCPAddr(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -55,6 +57,7 @@ func NewServer(hyClient *core.Client, addr string, authFunc func(username, passw
} }
s := &Server{ s := &Server{
HyClient: hyClient, HyClient: hyClient,
Transport: transport,
AuthFunc: authFunc, AuthFunc: authFunc,
Method: m, Method: m,
TCPAddr: tAddr, TCPAddr: tAddr,
@ -114,7 +117,7 @@ func (s *Server) negotiate(c *net.TCPConn) error {
func (s *Server) ListenAndServe() error { func (s *Server) ListenAndServe() error {
var err error var err error
s.tcpListener, err = net.ListenTCP("tcp", s.TCPAddr) s.tcpListener, err = s.Transport.LocalListenTCP(s.TCPAddr)
if err != nil { if err != nil {
return err return err
} }
@ -183,7 +186,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
closeErr = resErr closeErr = resErr
return resErr return resErr
} }
rc, err := net.DialTCP("tcp", nil, &net.TCPAddr{ rc, err := s.Transport.LocalDialTCP(nil, &net.TCPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
@ -213,7 +216,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
closeErr = errors.New("blocked in ACL") closeErr = errors.New("blocked in ACL")
return nil return nil
case acl.ActionHijack: case acl.ActionHijack:
rc, err := net.Dial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port)))) rc, err := s.Transport.LocalDial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port))))
if err != nil { if err != nil {
_ = sendReply(c, socks5.RepHostUnreachable) _ = sendReply(c, socks5.RepHostUnreachable)
closeErr = err closeErr = err
@ -237,7 +240,7 @@ func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error {
s.UDPErrorFunc(c.RemoteAddr(), closeErr) s.UDPErrorFunc(c.RemoteAddr(), closeErr)
}() }()
// Start local UDP server // Start local UDP server
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{ udpConn, err := s.Transport.LocalListenUDP(&net.UDPAddr{
IP: s.TCPAddr.IP, IP: s.TCPAddr.IP,
Zone: s.TCPAddr.Zone, Zone: s.TCPAddr.Zone,
}) })
@ -250,7 +253,7 @@ func (s *Server) handleUDP(c *net.TCPConn, r *socks5.Request) error {
// Local UDP relay conn for ACL Direct // Local UDP relay conn for ACL Direct
var localRelayConn *net.UDPConn var localRelayConn *net.UDPConn
if s.ACLEngine != nil { if s.ACLEngine != nil {
localRelayConn, err = net.ListenUDP("udp", nil) localRelayConn, err = s.Transport.LocalListenUDP(nil)
if err != nil { if err != nil {
_ = sendReply(c, socks5.RepServerFailure) _ = sendReply(c, socks5.RepServerFailure)
closeErr = err closeErr = err
@ -371,7 +374,7 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn,
// Do nothing // Do nothing
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port))) hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port)))
rAddr, err := net.ResolveUDPAddr("udp", hijackAddr) rAddr, err := s.Transport.LocalResolveUDPAddr(hijackAddr)
if err == nil { if err == nil {
_, _ = localRelayConn.WriteToUDP(d.Data, rAddr) _, _ = localRelayConn.WriteToUDP(d.Data, rAddr)
} }

View File

@ -14,6 +14,7 @@ import (
type TCPTProxy struct { type TCPTProxy struct {
HyClient *core.Client HyClient *core.Client
Transport core.Transport
ListenAddr *net.TCPAddr ListenAddr *net.TCPAddr
Timeout time.Duration Timeout time.Duration
ACLEngine *acl.Engine ACLEngine *acl.Engine
@ -22,15 +23,17 @@ type TCPTProxy struct {
ErrorFunc func(addr, reqAddr net.Addr, err error) ErrorFunc func(addr, reqAddr net.Addr, err error)
} }
func NewTCPTProxy(hyClient *core.Client, listen string, timeout time.Duration, aclEngine *acl.Engine, func NewTCPTProxy(hyClient *core.Client, transport core.Transport, listen string, timeout time.Duration,
aclEngine *acl.Engine,
connFunc func(addr, reqAddr net.Addr, action acl.Action, arg string), connFunc func(addr, reqAddr net.Addr, action acl.Action, arg string),
errorFunc func(addr, reqAddr net.Addr, err error)) (*TCPTProxy, error) { errorFunc func(addr, reqAddr net.Addr, err error)) (*TCPTProxy, error) {
tAddr, err := net.ResolveTCPAddr("tcp", listen) tAddr, err := transport.LocalResolveTCPAddr(listen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := &TCPTProxy{ r := &TCPTProxy{
HyClient: hyClient, HyClient: hyClient,
Transport: transport,
ListenAddr: tAddr, ListenAddr: tAddr,
Timeout: timeout, Timeout: timeout,
ACLEngine: aclEngine, ACLEngine: aclEngine,
@ -79,7 +82,7 @@ func (r *TCPTProxy) ListenAndServe() error {
closeErr = resErr closeErr = resErr
return return
} }
rc, err := net.DialTCP("tcp", nil, &net.TCPAddr{ rc, err := r.Transport.LocalDialTCP(nil, &net.TCPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
@ -104,7 +107,7 @@ func (r *TCPTProxy) ListenAndServe() error {
closeErr = errors.New("blocked in ACL") closeErr = errors.New("blocked in ACL")
return return
case acl.ActionHijack: case acl.ActionHijack:
rc, err := net.Dial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port)))) rc, err := r.Transport.LocalDial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port))))
if err != nil { if err != nil {
closeErr = err closeErr = err
return return

View File

@ -19,6 +19,7 @@ var ErrTimeout = errors.New("inactivity timeout")
type UDPTProxy struct { type UDPTProxy struct {
HyClient *core.Client HyClient *core.Client
Transport core.Transport
ListenAddr *net.UDPAddr ListenAddr *net.UDPAddr
Timeout time.Duration Timeout time.Duration
ACLEngine *acl.Engine ACLEngine *acl.Engine
@ -27,14 +28,16 @@ type UDPTProxy struct {
ErrorFunc func(addr net.Addr, err error) ErrorFunc func(addr net.Addr, err error)
} }
func NewUDPTProxy(hyClient *core.Client, listen string, timeout time.Duration, aclEngine *acl.Engine, func NewUDPTProxy(hyClient *core.Client, transport core.Transport, listen string, timeout time.Duration,
aclEngine *acl.Engine,
connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPTProxy, error) { connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPTProxy, error) {
uAddr, err := net.ResolveUDPAddr("udp", listen) uAddr, err := transport.LocalResolveUDPAddr(listen)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := &UDPTProxy{ r := &UDPTProxy{
HyClient: hyClient, HyClient: hyClient,
Transport: transport,
ListenAddr: uAddr, ListenAddr: uAddr,
Timeout: timeout, Timeout: timeout,
ACLEngine: aclEngine, ACLEngine: aclEngine,
@ -84,7 +87,7 @@ func (r *UDPTProxy) sendPacket(entry *connEntry, dstAddr *net.UDPAddr, data []by
return nil return nil
case acl.ActionHijack: case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port))) hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port)))
rAddr, err := net.ResolveUDPAddr("udp", hijackAddr) rAddr, err := r.Transport.LocalResolveUDPAddr(hijackAddr)
if err != nil { if err != nil {
return err return err
} }
@ -126,7 +129,7 @@ func (r *UDPTProxy) ListenAndServe() error {
} }
var localConn *net.UDPConn var localConn *net.UDPConn
if r.ACLEngine != nil { if r.ACLEngine != nil {
localConn, err = net.ListenUDP("udp", nil) localConn, err = r.Transport.LocalListenUDP(nil)
if err != nil { if err != nil {
r.ErrorFunc(srcAddr, err) r.ErrorFunc(srcAddr, err)
continue continue