diff --git a/extras/masq/server.go b/extras/masq/server.go index cb7439b..5600f7c 100644 --- a/extras/masq/server.go +++ b/extras/masq/server.go @@ -1,8 +1,10 @@ package masq import ( + "bufio" "crypto/tls" "fmt" + "net" "net/http" ) @@ -28,10 +30,7 @@ func (s *MasqTCPServer) ListenAndServeHTTP(addr string) error { } return } - s.Handler.ServeHTTP(&altSvcHijackResponseWriter{ - Port: s.QUICPort, - ResponseWriter: w, - }, r) + s.Handler.ServeHTTP(newAltSvcHijackResponseWriter(w, s.QUICPort), r) })) } @@ -39,16 +38,15 @@ func (s *MasqTCPServer) ListenAndServeHTTPS(addr string) error { server := &http.Server{ Addr: addr, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - s.Handler.ServeHTTP(&altSvcHijackResponseWriter{ - Port: s.QUICPort, - ResponseWriter: w, - }, r) + s.Handler.ServeHTTP(newAltSvcHijackResponseWriter(w, s.QUICPort), r) }), TLSConfig: s.TLSConfig, } return server.ListenAndServeTLS("", "") } +var _ http.ResponseWriter = (*altSvcHijackResponseWriter)(nil) + // altSvcHijackResponseWriter makes sure that the Alt-Svc's port // is always set with our own value, no matter what the handler sets. type altSvcHijackResponseWriter struct { @@ -60,3 +58,30 @@ func (w *altSvcHijackResponseWriter) WriteHeader(statusCode int) { w.Header().Set("Alt-Svc", fmt.Sprintf(`h3=":%d"; ma=2592000`, w.Port)) w.ResponseWriter.WriteHeader(statusCode) } + +var _ http.Hijacker = (*altSvcHijackResponseWriterHijacker)(nil) + +// altSvcHijackResponseWriterHijacker is a wrapper around altSvcHijackResponseWriter +// that also implements http.Hijacker. This is needed for WebSocket support. +type altSvcHijackResponseWriterHijacker struct { + altSvcHijackResponseWriter +} + +func (w *altSvcHijackResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return w.ResponseWriter.(http.Hijacker).Hijack() +} + +func newAltSvcHijackResponseWriter(w http.ResponseWriter, port int) http.ResponseWriter { + if _, ok := w.(http.Hijacker); ok { + return &altSvcHijackResponseWriterHijacker{ + altSvcHijackResponseWriter: altSvcHijackResponseWriter{ + Port: port, + ResponseWriter: w, + }, + } + } + return &altSvcHijackResponseWriter{ + Port: port, + ResponseWriter: w, + } +}