diff --git a/core/internal/integration_tests/trafficlogger_test.go b/core/internal/integration_tests/trafficlogger_test.go index 6b41123..3d79f8f 100644 --- a/core/internal/integration_tests/trafficlogger_test.go +++ b/core/internal/integration_tests/trafficlogger_test.go @@ -12,24 +12,31 @@ import ( type testTrafficLogger struct { Tx, Rx uint64 + Block atomic.Bool } func (l *testTrafficLogger) Log(id string, tx, rx uint64) bool { atomic.AddUint64(&l.Tx, tx) atomic.AddUint64(&l.Rx, rx) - return true + return !l.Block.Load() } func (l *testTrafficLogger) Get() (tx, rx uint64) { return atomic.LoadUint64(&l.Tx), atomic.LoadUint64(&l.Rx) } +func (l *testTrafficLogger) SetBlock(block bool) { + l.Block.Store(block) +} + func (l *testTrafficLogger) Reset() { atomic.StoreUint64(&l.Tx, 0) atomic.StoreUint64(&l.Rx, 0) } // TestServerTrafficLogger tests that the server's TrafficLogger interface is working correctly. +// More specifically, it tests that the server is correctly logging traffic in both directions, +// and that it is correctly disconnecting clients when the traffic logger returns false. func TestServerTrafficLogger(t *testing.T) { tl := &testTrafficLogger{} @@ -147,4 +154,20 @@ func TestServerTrafficLogger(t *testing.T) { if tx != uint64(len(sData)) || rx != uint64(len(sData)*2) { t.Fatalf("expected TrafficLogger Tx=%d, Rx=%d, got Tx=%d, Rx=%d", len(sData), len(sData)*2, tx, rx) } + + // Check the disconnect client functionality + tl.SetBlock(true) + + // Send and receive TCP data again + sData = []byte("1234") + _, err = tConn.Write(sData) + if err != nil { + t.Fatal("error writing to TCP:", err) + } + // This should fail instantly without reading any data + // io.Copy should return nil as EOF is treated as a non-error though + n, err := io.Copy(io.Discard, tConn) + if n != 0 || err != nil { + t.Fatal("expected 0 bytes read and nil error, got", n, err) + } } diff --git a/core/server/config.go b/core/server/config.go index a676f6a..baeea67 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -180,5 +180,5 @@ type EventLogger interface { // bandwidth limits or post-connection authentication, for example. // The implementation of this interface must be thread-safe. type TrafficLogger interface { - Log(id string, tx, rx uint64) bool + Log(id string, tx, rx uint64) (ok bool) } diff --git a/core/server/copy.go b/core/server/copy.go index 5f470d1..25831e5 100644 --- a/core/server/copy.go +++ b/core/server/copy.go @@ -1,16 +1,22 @@ package server -import "io" +import ( + "errors" + "io" +) -func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error { +var errDisconnect = errors.New("traffic logger requested disconnect") + +func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64) bool) error { buf := make([]byte, 32*1024) for { nr, er := src.Read(buf) if nr > 0 { - nw, ew := dst.Write(buf[0:nr]) - if nw > 0 { - log(uint64(nw)) + if !log(uint64(nr)) { + // Log returns false, which means that the client should be disconnected + return errDisconnect } + _, ew := dst.Write(buf[0:nr]) if ew != nil { return ew } @@ -28,13 +34,13 @@ func copyBufferLog(dst io.Writer, src io.Reader, log func(n uint64)) error { func copyTwoWayWithLogger(id string, serverRw, remoteRw io.ReadWriter, l TrafficLogger) error { errChan := make(chan error, 2) go func() { - errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) { - l.Log(id, 0, n) + errChan <- copyBufferLog(serverRw, remoteRw, func(n uint64) bool { + return l.Log(id, 0, n) }) }() go func() { - errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) { - l.Log(id, n, 0) + errChan <- copyBufferLog(remoteRw, serverRw, func(n uint64) bool { + return l.Log(id, n, 0) }) }() // Block until one of the two goroutines returns diff --git a/core/server/server.go b/core/server/server.go index 90550bf..ecfbcef 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -279,6 +279,10 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { // Cleanup _ = tConn.Close() _ = stream.Close() + // Disconnect the client if TrafficLogger requested + if err == errDisconnect { + _ = h.conn.CloseWithError(0, "") + } } func (h *h3sHandler) handleUDPRequest(stream quic.Stream) { @@ -316,7 +320,12 @@ func (h *h3sHandler) handleUDPRequest(stream quic.Stream) { udpN, rAddr, err := conn.ReadFrom(udpBuf) if udpN > 0 { if h.config.TrafficLogger != nil { - h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) + ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) + if !ok { + // TrafficLogger requested to disconnect the client + _ = h.conn.CloseWithError(0, "") + return + } } // Try no frag first msg := protocol.UDPMessage{ @@ -371,20 +380,30 @@ func (h *h3sHandler) udpLoop() { if err != nil { return } - h.handleUDPMessage(msg) + ok := h.handleUDPMessage(msg) + if !ok { + // TrafficLogger requested to disconnect the client + _ = h.conn.CloseWithError(0, "") + return + } } } // client -> remote direction -func (h *h3sHandler) handleUDPMessage(msg []byte) { +// Returns a bool indicating whether the receiving loop should continue +func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) { udpMsg, err := protocol.ParseUDPMessage(msg) if err != nil { - return + return true } - n, _ := h.udpSM.Feed(udpMsg) - if n > 0 && h.config.TrafficLogger != nil { - h.config.TrafficLogger.Log(h.authID, uint64(n), 0) + if h.config.TrafficLogger != nil { + ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0) + if !ok { + return false + } } + _, _ = h.udpSM.Feed(udpMsg) + return true } func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) {