From 5b54edd09ad07ebd4f1eed323cf91005551938ca Mon Sep 17 00:00:00 2001 From: tobyxdd Date: Wed, 7 Jun 2023 22:37:09 -0700 Subject: [PATCH] feat: traffic logger (wip, disconnect not done) --- .../integration_tests/trafficlogger_test.go | 150 ++++++++++++++++++ core/internal/integration_tests/utils_test.go | 59 +++++++ core/server/config.go | 16 +- core/server/copy.go | 58 +++++++ core/server/server.go | 37 +++-- 5 files changed, 302 insertions(+), 18 deletions(-) create mode 100644 core/internal/integration_tests/trafficlogger_test.go create mode 100644 core/server/copy.go diff --git a/core/internal/integration_tests/trafficlogger_test.go b/core/internal/integration_tests/trafficlogger_test.go new file mode 100644 index 0000000..6b41123 --- /dev/null +++ b/core/internal/integration_tests/trafficlogger_test.go @@ -0,0 +1,150 @@ +package integration_tests + +import ( + "io" + "net" + "sync/atomic" + "testing" + + "github.com/apernet/hysteria/core/client" + "github.com/apernet/hysteria/core/server" +) + +type testTrafficLogger struct { + Tx, Rx uint64 +} + +func (l *testTrafficLogger) Log(id string, tx, rx uint64) bool { + atomic.AddUint64(&l.Tx, tx) + atomic.AddUint64(&l.Rx, rx) + return true +} + +func (l *testTrafficLogger) Get() (tx, rx uint64) { + return atomic.LoadUint64(&l.Tx), atomic.LoadUint64(&l.Rx) +} + +func (l *testTrafficLogger) Reset() { + atomic.StoreUint64(&l.Tx, 0) + atomic.StoreUint64(&l.Rx, 0) +} + +// TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly. +func TestServerTrafficLogger(t *testing.T) { + tl := &testTrafficLogger{} + + // Create server + udpAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 14514} + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + t.Fatal("error creating server:", err) + } + s, err := server.NewServer(&server.Config{ + TLSConfig: serverTLSConfig(), + Conn: udpConn, + Authenticator: &pwAuthenticator{ + Password: "password", + ID: "nobody", + }, + TrafficLogger: tl, + }) + if err != nil { + t.Fatal("error creating server:", err) + } + defer s.Close() + go s.Serve() + + // Create TCP double echo server + // We use double echo to test that the traffic logger is correctly logging both directions. + echoTCPAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 14515} + echoListener, err := net.ListenTCP("tcp", echoTCPAddr) + if err != nil { + t.Fatal("error creating TCP echo server:", err) + } + tEchoServer := &tcpDoubleEchoServer{Listener: echoListener} + defer tEchoServer.Close() + go tEchoServer.Serve() + + // Create client + c, err := client.NewClient(&client.Config{ + ServerAddr: udpAddr, + Auth: "password", + TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, + }) + if err != nil { + t.Fatal("error creating client:", err) + } + defer c.Close() + + // Dial TCP + tConn, err := c.DialTCP(echoTCPAddr.String()) + if err != nil { + t.Fatal("error dialing TCP:", err) + } + defer tConn.Close() + + // Send and receive TCP data + sData := []byte("1234") + _, err = tConn.Write(sData) + if err != nil { + t.Fatal("error writing to TCP:", err) + } + rData := make([]byte, len(sData)*2) + _, err = io.ReadFull(tConn, rData) + if err != nil { + t.Fatal("error reading from TCP:", err) + } + expected := string(sData) + string(sData) + if string(rData) != expected { + t.Fatalf("expected %q, got %q", expected, string(rData)) + } + + // Check traffic logger + tx, rx := tl.Get() + if tx != uint64(len(sData)) || rx != uint64(len(rData)) { + t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(rData), tx, rx) + } + tl.Reset() + + // Create UDP double echo server + echoUDPAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 55555} + echoConn, err := net.ListenUDP("udp", echoUDPAddr) + if err != nil { + t.Fatal("error creating UDP echo server:", err) + } + uEchoServer := &udpDoubleEchoServer{Conn: echoConn} + defer uEchoServer.Close() + go uEchoServer.Serve() + + // Listen UDP + uConn, err := c.ListenUDP() + if err != nil { + t.Fatal("error listening UDP:", err) + } + defer uConn.Close() + + // Send and receive UDP data + sData = []byte("gucci gang") + err = uConn.Send(sData, echoUDPAddr.String()) + if err != nil { + t.Fatal("error sending UDP:", err) + } + for i := 0; i < 2; i++ { + rData, rAddr, err := uConn.Receive() + if err != nil { + t.Fatal("error receiving UDP:", err) + } + if string(rData) != string(sData) { + t.Fatalf("expected %q, got %q", string(sData), string(rData)) + } + if rAddr != echoUDPAddr.String() { + t.Fatalf("expected %q, got %q", echoUDPAddr.String(), rAddr) + } + } + + // Check traffic logger + tx, rx = tl.Get() + if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) { + t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx) + } +} diff --git a/core/internal/integration_tests/utils_test.go b/core/internal/integration_tests/utils_test.go index fb90dcb..ae7e9b4 100644 --- a/core/internal/integration_tests/utils_test.go +++ b/core/internal/integration_tests/utils_test.go @@ -61,6 +61,37 @@ func (s *tcpEchoServer) Close() error { return s.Listener.Close() } +// tcpDoubleEchoServer is a TCP server that echoes twice what it reads from the connection. +// It will never actively close the connection. +type tcpDoubleEchoServer struct { + Listener net.Listener +} + +func (s *tcpDoubleEchoServer) Serve() error { + for { + conn, err := s.Listener.Accept() + if err != nil { + return err + } + go func() { + buf := make([]byte, 1024) + for { + n, err := conn.Read(buf) + if err != nil { + _ = conn.Close() + return + } + _, _ = conn.Write(buf[:n]) + _, _ = conn.Write(buf[:n]) + } + }() + } +} + +func (s *tcpDoubleEchoServer) Close() error { + return s.Listener.Close() +} + type sinkEvent struct { Data []byte Err error @@ -140,6 +171,34 @@ func (s *udpEchoServer) Close() error { return s.Conn.Close() } +// udpDoubleEchoServer is a UDP server that echoes twice what it reads from the connection. +// It will never actively close the connection. +type udpDoubleEchoServer struct { + Conn net.PacketConn +} + +func (s *udpDoubleEchoServer) Serve() error { + buf := make([]byte, 65536) + for { + n, addr, err := s.Conn.ReadFrom(buf) + if err != nil { + return err + } + _, err = s.Conn.WriteTo(buf[:n], addr) + if err != nil { + return err + } + _, err = s.Conn.WriteTo(buf[:n], addr) + if err != nil { + return err + } + } +} + +func (s *udpDoubleEchoServer) Close() error { + return s.Conn.Close() +} + type connectEvent struct { Addr net.Addr ID string diff --git a/core/server/config.go b/core/server/config.go index edc62c6..a676f6a 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -26,8 +26,8 @@ type Config struct { DisableUDP bool Authenticator Authenticator EventLogger EventLogger - // TODO: TrafficLogger - MasqHandler http.Handler + TrafficLogger TrafficLogger + MasqHandler http.Handler } // fill fills the fields that are not set by the user with default values when possible, @@ -170,3 +170,15 @@ type EventLogger interface { UDPRequest(addr net.Addr, id string, sessionID uint32) UDPError(addr net.Addr, id string, sessionID uint32, err error) } + +// TrafficLogger is an interface that provides traffic logging logic. +// Tx/Rx in this context refers to the server-remote (proxy target) perspective. +// Tx is the bytes sent from the server to the remote. +// Rx is the bytes received by the server from the remote. +// Apart from logging, the Log function can also return false to signal +// that the client should be disconnected. This can be used to implement +// bandwidth limits or post-connection authentication, for example. +// The implementation of this interface must be thread-safe. +type TrafficLogger interface { + Log(id string, tx, rx uint64) bool +} diff --git a/core/server/copy.go b/core/server/copy.go new file mode 100644 index 0000000..5f470d1 --- /dev/null +++ b/core/server/copy.go @@ -0,0 +1,58 @@ +package server + +import "io" + +func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error { + buf := make([]byte, 32*1024) + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if nw > 0 { + log(uint64(nw)) + } + if ew != nil { + return ew + } + } + if er != nil { + if er == io.EOF { + // EOF should not be considered as an error + return nil + } + return er + } + } +} + +func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error { + errChan := make(chan error, 2) + go func() { + errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) { + l.Log(id, 0, n) + }) + }() + go func() { + errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) { + l.Log(id, n, 0) + }) + }() + // Block until one of the two goroutines returns + return <-errChan +} + +// copyTwoWay is the "fast-path" version of copyTwoWayWithLogger that does not log traffic. +// It uses the built-in io.Copy instead of our own copyBufferLog. +func copyTwoWay(serverRw, remoteRw io.ReadWriter) error { + errChan := make(chan error, 2) + go func() { + _, err := io.Copy(serverRw, remoteRw) + errChan <- err + }() + go func() { + _, err := io.Copy(remoteRw, serverRw) + errChan <- err + }() + // Block until one of the two goroutines returns + return <-errChan +} diff --git a/core/server/server.go b/core/server/server.go index a3681ed..90550bf 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -154,21 +154,25 @@ func (m *udpSessionManager) Add() (uint32, UDPConn, func(), error) { }, nil } -func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) { +// Feed feeds a UDP message to the session manager. +// If the message itself is a complete message, or it's the last fragment of a message, +// it will be sent to the UDP connection. +// The function will then return the number of bytes sent and any error occurred. +func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) (int, error) { m.mutex.RLock() defer m.mutex.RUnlock() entry, ok := m.m[msg.SessionID] if !ok { // No such session, drop the message - return + return 0, nil } dfMsg := entry.D.Feed(msg) if dfMsg == nil { // Not a complete message yet - return + return 0, nil } - _, _ = entry.Conn.WriteTo(dfMsg.Data, dfMsg.Addr) + return entry.Conn.WriteTo(dfMsg.Data, dfMsg.Addr) } func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -263,17 +267,12 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { } _ = protocol.WriteTCPResponse(stream, true, "") // Start proxying - copyErrChan := make(chan error, 2) - go func() { - _, err := io.Copy(tConn, stream) - copyErrChan <- err - }() - go func() { - _, err := io.Copy(stream, tConn) - copyErrChan <- err - }() - // Block until one of the copy goroutines exits - err = <-copyErrChan + if h.config.TrafficLogger != nil { + err = copyTwoWayWithLogger(h.authID, stream, tConn, h.config.TrafficLogger) + } else { + // Use the fast path if no traffic logger is set + err = copyTwoWay(stream, tConn) + } if h.config.EventLogger != nil { h.config.EventLogger.TCPError(h.conn.RemoteAddr(), h.authID, reqAddr, err) } @@ -316,6 +315,9 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) { for { udpN, rAddr, err := conn.ReadFrom(udpBuf) if udpN > 0 { + if h.config.TrafficLogger != nil { + h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) + } // Try no frag first msg := protocol.UDPMessage{ SessionID: sessionID, @@ -379,7 +381,10 @@ func (h *h3sHandler) handleUDPMessage(msg []byte) { if err != nil { return } - h.udpSM.Feed(udpMsg) + n, _ := h.udpSM.Feed(udpMsg) + if n > 0 && h.config.TrafficLogger != nil { + h.config.TrafficLogger.Log(h.authID, uint64(n), 0) + } } func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {