From 48bf9b964a63c43eef65f9c364d0d27e9e08eb4b Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 16 Aug 2024 15:46:30 -0700 Subject: [PATCH 1/3] fix: sniffing handled HTTP host header incorrectly --- extras/sniff/sniff.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/extras/sniff/sniff.go b/extras/sniff/sniff.go index e0c94d4..b2d7419 100644 --- a/extras/sniff/sniff.go +++ b/extras/sniff/sniff.go @@ -112,11 +112,21 @@ func (h *Sniffer) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) { tr := &teeReader{Stream: stream, Pre: pre} req, _ := http.ReadRequest(bufio.NewReader(tr)) if req != nil && req.Host != "" { - _, port, err := net.SplitHostPort(*reqAddr) + // req.Host may already contain the port. + // If it does, just overwrite the whole address with req.Host. + // Otherwise, use the port in reqAddr. + _, _, err := net.SplitHostPort(req.Host) if err != nil { - return nil, err + // Not host:port format, append the port from reqAddr + _, port, err := net.SplitHostPort(*reqAddr) + if err != nil { + return nil, err + } + *reqAddr = net.JoinHostPort(req.Host, port) + } else { + // Already host:port format + *reqAddr = req.Host } - *reqAddr = net.JoinHostPort(req.Host, port) } return tr.Buffer(), nil } else if h.isTLS(pre) { From f014c005463cb5a77696485703d25c1d9566c287 Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 16 Aug 2024 15:51:42 -0700 Subject: [PATCH 2/3] fix: add a test case --- extras/sniff/sniff_test.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/extras/sniff/sniff_test.go b/extras/sniff/sniff_test.go index a22784e..731cc6b 100644 --- a/extras/sniff/sniff_test.go +++ b/extras/sniff/sniff_test.go @@ -70,6 +70,18 @@ func TestSnifferTCP(t *testing.T) { assert.Equal(t, *buf, putback) assert.Equal(t, "example.com:80", reqAddr) + // Test HTTP with Host as host:port + *buf = []byte("GET / HTTP/1.1\r\n" + + "Host: example.com:8080\r\n" + + "User-Agent: test-agent\r\n" + + "Accept: */*\r\n\r\n") + index = 0 + reqAddr = "222.222.222.222:10086" + putback, err = sniffer.TCP(stream, &reqAddr) + assert.NoError(t, err) + assert.Equal(t, *buf, putback) + assert.Equal(t, "example.com:8080", reqAddr) + // Test TLS *buf, err = base64.StdEncoding.DecodeString("FgMBARcBAAETAwPJL2jlt1OAo+Rslkjv/aqKiTthKMaCKg2Gvd+uALDbDCDdY+UIk8ouadEB9fC3j52Y1i7SJZqGIgBRIS6kKieYrAAoEwITAcAswCvAMMAvwCTAI8AowCfACsAJwBTAEwCdAJwAPQA8ADUALwEAAKIAAAAOAAwAAAlpcGluZm8uaW8ABQAFAQAAAAAAKwAJCAMEAwMDAgMBAA0AGgAYCAQIBQgGBAEFAQIBBAMFAwIDAgIGAQYDACMAAAAKAAgABgAdABcAGAAQAAsACQhodHRwLzEuMQAzACYAJAAdACBguQbqNJNyamYxYcrBFpBP7pWv5TgZsP9gwGtMYNKVBQAxAAAAFwAA/wEAAQAALQACAQE=") assert.NoError(t, err) From 55c3a064cca4a8e2bb83cf054cd6e901d681302c Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 16 Aug 2024 20:48:14 -0700 Subject: [PATCH 3/3] fix: never overwrite the port --- extras/sniff/sniff.go | 22 +++++++++------------- extras/sniff/sniff_test.go | 2 +- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/extras/sniff/sniff.go b/extras/sniff/sniff.go index b2d7419..9994b8a 100644 --- a/extras/sniff/sniff.go +++ b/extras/sniff/sniff.go @@ -112,21 +112,17 @@ func (h *Sniffer) TCP(stream quic.Stream, reqAddr *string) ([]byte, error) { tr := &teeReader{Stream: stream, Pre: pre} req, _ := http.ReadRequest(bufio.NewReader(tr)) if req != nil && req.Host != "" { - // req.Host may already contain the port. - // If it does, just overwrite the whole address with req.Host. - // Otherwise, use the port in reqAddr. - _, _, err := net.SplitHostPort(req.Host) + // req.Host can be host:port, in which case we need to extract the host part + host, _, err := net.SplitHostPort(req.Host) if err != nil { - // Not host:port format, append the port from reqAddr - _, port, err := net.SplitHostPort(*reqAddr) - if err != nil { - return nil, err - } - *reqAddr = net.JoinHostPort(req.Host, port) - } else { - // Already host:port format - *reqAddr = req.Host + // No port, just use the whole string + host = req.Host } + _, port, err := net.SplitHostPort(*reqAddr) + if err != nil { + return nil, err + } + *reqAddr = net.JoinHostPort(host, port) } return tr.Buffer(), nil } else if h.isTLS(pre) { diff --git a/extras/sniff/sniff_test.go b/extras/sniff/sniff_test.go index 731cc6b..445660b 100644 --- a/extras/sniff/sniff_test.go +++ b/extras/sniff/sniff_test.go @@ -80,7 +80,7 @@ func TestSnifferTCP(t *testing.T) { putback, err = sniffer.TCP(stream, &reqAddr) assert.NoError(t, err) assert.Equal(t, *buf, putback) - assert.Equal(t, "example.com:8080", reqAddr) + assert.Equal(t, "example.com:10086", reqAddr) // Test TLS *buf, err = base64.StdEncoding.DecodeString("FgMBARcBAAETAwPJL2jlt1OAo+Rslkjv/aqKiTthKMaCKg2Gvd+uALDbDCDdY+UIk8ouadEB9fC3j52Y1i7SJZqGIgBRIS6kKieYrAAoEwITAcAswCvAMMAvwCTAI8AowCfACsAJwBTAEwCdAJwAPQA8ADUALwEAAKIAAAAOAAwAAAlpcGluZm8uaW8ABQAFAQAAAAAAKwAJCAMEAwMDAgMBAA0AGgAYCAQIBQgGBAEFAQIBBAMFAwIDAgIGAQYDACMAAAAKAAgABgAdABcAGAAQAAsACQhodHRwLzEuMQAzACYAJAAdACBguQbqNJNyamYxYcrBFpBP7pWv5TgZsP9gwGtMYNKVBQAxAAAAFwAA/wEAAQAALQACAQE=")