diff --git a/README.md b/README.md index a074e32..38e822a 100644 --- a/README.md +++ b/README.md @@ -247,7 +247,7 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452 }, "tun": { "name": "tun-hy", // TUN interface name - "timeout": 300, // UDP timeout in seconds + "timeout": 300, // Timeout in seconds "address": "192.0.2.2", // TUN interface address, not applicable for Linux "gateway": "192.0.2.1", // TUN interface gateway, not applicable for Linux "mask": "255.255.255.252", // TUN interface mask, not applicable for Linux diff --git a/README.zh.md b/README.zh.md index 1ab9216..44290d7 100644 --- a/README.zh.md +++ b/README.zh.md @@ -234,7 +234,7 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452 }, "tun": { "name": "tun-hy", // TUN 接口名称 - "timeout": 300, // UDP 超时秒数 + "timeout": 300, // 超时秒数 "address": "192.0.2.2", // TUN 接口地址(不适用于 Linux) "gateway": "192.0.2.1", // TUN 接口网关(不适用于 Linux) "mask": "255.255.255.252", // TUN 接口子网掩码(不适用于 Linux) diff --git a/cmd/client.go b/cmd/client.go index a32f0ce..13fbc9f 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -20,6 +20,7 @@ import ( "io/ioutil" "net" "net/http" + "strings" "time" ) @@ -191,11 +192,40 @@ func client(config *clientConfig) { if len(config.TUN.Name) != 0 { go func() { - tunServer, err := tun.NewServer(client, time.Duration(config.TUN.Timeout)*time.Second, + tunServer, err := tun.NewServer(client, transport.DefaultTransport, + time.Duration(config.TUN.Timeout)*time.Second, config.TUN.Name, config.TUN.Address, config.TUN.Gateway, config.TUN.Mask, config.TUN.DNS, config.TUN.Persist) if err != nil { logrus.WithField("error", err).Fatal("Failed to initialize TUN server") } + tunServer.RequestFunc = func(addr net.Addr, reqAddr string, action acl.Action, arg string) { + logrus.WithFields(logrus.Fields{ + "action": actionToString(action, arg), + "src": addr.String(), + "dst": reqAddr, + }).Debugf("TUN %s request", strings.ToUpper(addr.Network())) + } + tunServer.ErrorFunc = func(addr net.Addr, reqAddr string, err error) { + if err != nil { + if err == io.EOF { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + "dst": reqAddr, + }).Debugf("TUN %s EOF", strings.ToUpper(addr.Network())) + } else if err == core.ErrClosed && strings.HasPrefix(addr.Network(), "udp") { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + "dst": reqAddr, + }).Debugf("TUN %s closed for timeout", strings.ToUpper(addr.Network())) + } else { + logrus.WithFields(logrus.Fields{ + "error": err, + "src": addr.String(), + "dst": reqAddr, + }).Infof("TUN %s error", strings.ToUpper(addr.Network())) + } + } + } errChan <- tunServer.ListenAndServe() }() } diff --git a/pkg/tun/server.go b/pkg/tun/server.go index a4d8863..b66cc8b 100644 --- a/pkg/tun/server.go +++ b/pkg/tun/server.go @@ -3,18 +3,26 @@ package tun import ( tun2socks "github.com/eycorsican/go-tun2socks/core" "github.com/eycorsican/go-tun2socks/tun" + "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/core" + "github.com/tobyxdd/hysteria/pkg/transport" "io" + "net" "sync" "time" ) type Server struct { - HyClient *core.Client - Timeout time.Duration - TunDev io.ReadWriteCloser + HyClient *core.Client + Timeout time.Duration + TunDev io.ReadWriteCloser + Transport transport.Transport + ACLEngine *acl.Engine - udpConnMap map[tun2socks.UDPConn]*UDPConnInfo + RequestFunc func(addr net.Addr, reqAddr string, action acl.Action, arg string) + ErrorFunc func(addr net.Addr, reqAddr string, err error) + + udpConnMap map[tun2socks.UDPConn]*udpConnInfo udpConnMapLock sync.RWMutex } @@ -22,24 +30,27 @@ const ( MTU = 1500 ) -func NewServerWithTunDev(hyClient *core.Client, timeout time.Duration, +func NewServerWithTunDev(hyClient *core.Client, transport transport.Transport, + timeout time.Duration, tunDev io.ReadWriteCloser) (*Server, error) { s := &Server{ HyClient: hyClient, + Transport: transport, Timeout: timeout, TunDev: tunDev, - udpConnMap: make(map[tun2socks.UDPConn]*UDPConnInfo), + udpConnMap: make(map[tun2socks.UDPConn]*udpConnInfo), } return s, nil } -func NewServer(hyClient *core.Client, timeout time.Duration, +func NewServer(hyClient *core.Client, transport transport.Transport, + timeout time.Duration, name, address, gateway, mask string, dnsServers []string, persist bool) (*Server, error) { tunDev, err := tun.OpenTunDevice(name, address, gateway, mask, dnsServers, persist) if err != nil { return nil, err } - return NewServerWithTunDev(hyClient, timeout, tunDev) + return NewServerWithTunDev(hyClient, transport, timeout, tunDev) } func (s *Server) ListenAndServe() error { diff --git a/pkg/tun/tcp.go b/pkg/tun/tcp.go index bb82b36..d0e8ce7 100644 --- a/pkg/tun/tcp.go +++ b/pkg/tun/tcp.go @@ -1,73 +1,74 @@ package tun import ( - "io" + "errors" + "fmt" + "github.com/tobyxdd/hysteria/pkg/acl" + "github.com/tobyxdd/hysteria/pkg/utils" "net" + "strconv" ) func (s *Server) Handle(conn net.Conn, target *net.TCPAddr) error { - hyConn, err := s.HyClient.DialTCP(target.String()) - if err != nil { - return err + action, arg := acl.ActionProxy, "" + var resErr error + if s.ACLEngine != nil { + action, arg, _, resErr = s.ACLEngine.ResolveAndMatch(target.IP.String()) } - go s.relay(conn, hyConn) - return nil -} - -type direction byte - -const ( - directionUplink direction = iota - directionDownlink -) - -type duplexConn interface { - net.Conn - CloseRead() error - CloseWrite() error -} - -func (s *Server) relay(clientConn, relayConn net.Conn) { - uplinkDone := make(chan struct{}) - - halfCloseConn := func(dir direction, interrupt bool) { - clientDuplexConn, ok1 := clientConn.(duplexConn) - relayDuplexConn, ok2 := relayConn.(duplexConn) - if !interrupt && ok1 && ok2 { - switch dir { - case directionUplink: - clientDuplexConn.CloseRead() - relayDuplexConn.CloseWrite() - case directionDownlink: - clientDuplexConn.CloseWrite() - relayDuplexConn.CloseRead() - } - } else { - clientConn.Close() - relayConn.Close() - } + if s.RequestFunc != nil { + s.RequestFunc(conn.LocalAddr(), target.String(), action, arg) } - - // Uplink - go func() { - var err error - _, err = io.Copy(relayConn, clientConn) - if err != nil { - halfCloseConn(directionUplink, true) - } else { - halfCloseConn(directionUplink, false) + var closeErr error + defer func() { + if s.ErrorFunc != nil && closeErr != nil { + s.ErrorFunc(conn.LocalAddr(), target.String(), closeErr) } - uplinkDone <- struct{}{} }() - - // Downlink - var err error - _, err = io.Copy(clientConn, relayConn) - if err != nil { - halfCloseConn(directionDownlink, true) - } else { - halfCloseConn(directionDownlink, false) + switch action { + case acl.ActionDirect: + if resErr != nil { + closeErr = resErr + return resErr + } + rc, err := s.Transport.LocalDialTCP(nil, target) + if err != nil { + closeErr = err + return err + } + go s.relayTCP(conn, rc) + return nil + case acl.ActionProxy: + rc, err := s.HyClient.DialTCP(target.String()) + if err != nil { + closeErr = err + return err + } + go s.relayTCP(conn, rc) + return nil + case acl.ActionBlock: + closeErr = errors.New("blocked in ACL") + // caller will abort the connection when err != nil + return closeErr + case acl.ActionHijack: + rc, err := s.Transport.LocalDial("tcp", net.JoinHostPort(arg, strconv.Itoa(target.Port))) + if err != nil { + closeErr = err + return err + } + go s.relayTCP(conn, rc) + return nil + default: + closeErr = fmt.Errorf("unknown action %d", action) + // caller will abort the connection when err != nil + return closeErr } - - <-uplinkDone +} + +func (s *Server) relayTCP(clientConn, relayConn net.Conn) { + closeErr := utils.PipePairWithTimeout(clientConn, relayConn, s.Timeout) + if s.ErrorFunc != nil { + s.ErrorFunc(clientConn.LocalAddr(), relayConn.RemoteAddr().String(), closeErr) + } + relayConn.Close() + clientConn.Close() } diff --git a/pkg/tun/udp.go b/pkg/tun/udp.go index 9c737a4..9c74e27 100644 --- a/pkg/tun/udp.go +++ b/pkg/tun/udp.go @@ -1,21 +1,28 @@ package tun import ( + "bytes" + "errors" "fmt" tun2socks "github.com/eycorsican/go-tun2socks/core" + "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/core" - "log" + "io" "net" + "strconv" "sync/atomic" "time" ) -type UDPConnInfo struct { +const udpBufferSize = 65535 + +type udpConnInfo struct { hyConn core.UDPConn + target string expire atomic.Value } -func (s *Server) fetchUDPInput(conn tun2socks.UDPConn, ci *UDPConnInfo) { +func (s *Server) fetchUDPInput(conn tun2socks.UDPConn, ci *udpConnInfo) { defer func() { s.closeUDPConn(conn) }() @@ -34,24 +41,85 @@ func (s *Server) fetchUDPInput(conn tun2socks.UDPConn, ci *UDPConnInfo) { }() } + var err error + for { - bs, from, err := ci.hyConn.ReadFrom() + var bs []byte + var from string + bs, from, err = ci.hyConn.ReadFrom() if err != nil { break } ci.expire.Store(time.Now().Add(s.Timeout)) udpAddr, _ := net.ResolveUDPAddr("udp", from) - _, _ = conn.WriteFrom(bs, udpAddr) + _, err = conn.WriteFrom(bs, udpAddr) + if err != nil { + break + } + } + + if s.ErrorFunc != nil { + s.ErrorFunc(conn.LocalAddr(), ci.target, err) } } func (s *Server) Connect(conn tun2socks.UDPConn, target *net.UDPAddr) error { - c, err := s.HyClient.DialUDP() - if err != nil { - return err + action, arg := acl.ActionProxy, "" + var resErr error + if s.ACLEngine != nil { + action, arg, _, resErr = s.ACLEngine.ResolveAndMatch(target.IP.String()) } - ci := UDPConnInfo{ - hyConn: c, + if s.RequestFunc != nil { + s.RequestFunc(conn.LocalAddr(), target.String(), action, arg) + } + var hyConn core.UDPConn + var closeErr error + defer func() { + if s.ErrorFunc != nil && closeErr != nil { + s.ErrorFunc(conn.LocalAddr(), target.String(), closeErr) + } + }() + switch action { + case acl.ActionDirect: + if resErr != nil { + closeErr = resErr + return resErr + } + var relayConn net.Conn + relayConn, closeErr = s.Transport.LocalDial("udp", target.String()) + if closeErr != nil { + return closeErr + } + hyConn = &delegatedUDPConn{ + underlayConn: relayConn, + delegatedRemoteAddr: target.String(), + } + case acl.ActionProxy: + hyConn, closeErr = s.HyClient.DialUDP() + if closeErr != nil { + return closeErr + } + case acl.ActionBlock: + closeErr = errors.New("blocked in ACL") + return closeErr + case acl.ActionHijack: + hijackAddr := net.JoinHostPort(arg, strconv.Itoa(target.Port)) + var relayConn net.Conn + relayConn, closeErr = s.Transport.LocalDial("udp", hijackAddr) + if closeErr != nil { + return closeErr + } + hyConn = &delegatedUDPConn{ + underlayConn: relayConn, + delegatedRemoteAddr: target.String(), + } + default: + closeErr = fmt.Errorf("unknown action %d", action) + return nil + } + ci := udpConnInfo{ + hyConn: hyConn, + target: net.JoinHostPort(target.IP.String(), strconv.Itoa(target.Port)), } ci.expire.Store(time.Now().Add(s.Timeout)) s.udpConnMapLock.Lock() @@ -66,8 +134,9 @@ func (s *Server) ReceiveTo(conn tun2socks.UDPConn, data []byte, addr *net.UDPAdd ci, ok := s.udpConnMap[conn] s.udpConnMapLock.RUnlock() if !ok { - log.Printf("not connected: %s <-> %s\n", conn.LocalAddr().String(), addr.String()) - return fmt.Errorf("not connected: %s <-> %s", conn.LocalAddr().String(), addr.String()) + err := errors.New("previous connection closed for timeout") + s.ErrorFunc(conn.LocalAddr(), addr.String(), err) + return err } ci.expire.Store(time.Now().Add(s.Timeout)) _ = ci.hyConn.WriteTo(data, addr.String()) @@ -83,3 +152,29 @@ func (s *Server) closeUDPConn(conn tun2socks.UDPConn) { delete(s.udpConnMap, conn) } } + +type delegatedUDPConn struct { + underlayConn net.Conn + delegatedRemoteAddr string +} + +func (c *delegatedUDPConn) ReadFrom() (bs []byte, addr string, err error) { + buf := make([]byte, udpBufferSize) + n, err := c.underlayConn.Read(buf) + if n > 0 { + bs = append(bs, buf[0:n]...) + } + if err != nil || err == io.EOF { + addr = c.delegatedRemoteAddr + } + return +} + +func (c *delegatedUDPConn) WriteTo(bs []byte, addr string) error { + _, err := io.Copy(c.underlayConn, bytes.NewReader(bs)) + return err +} + +func (c *delegatedUDPConn) Close() error { + return c.underlayConn.Close() +}