diff --git a/CHANGELOG.md b/CHANGELOG.md index b9d0385..c36c201 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Changes since v3.1.0 - [#110](https://github.com/pusher/oauth2_proxy/pull/110) Added GCP healthcheck option (@timothy-spencer) +- [#112](https://github.com/pusher/oauth2_proxy/pull/112) Improve websocket support (@gyson) - [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi) - Use JSON to encode session state to be stored in browser cookies - Implement legacy decode function to support existing cookies generated by older versions diff --git a/logging_handler.go b/logging_handler.go index d47ae58..4502ed3 100644 --- a/logging_handler.go +++ b/logging_handler.go @@ -4,6 +4,8 @@ package main import ( + "bufio" + "errors" "fmt" "io" "net" @@ -32,6 +34,14 @@ func (l *responseLogger) Header() http.Header { return l.w.Header() } +// Support Websocket +func (l *responseLogger) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { + if hj, ok := l.w.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, errors.New("http.Hijacker is not available on writer") +} + // ExtractGAPMetadata extracts and removes GAP headers from the ResponseWriter's // Header func (l *responseLogger) ExtractGAPMetadata() { diff --git a/logging_handler_test.go b/logging_handler_test.go index de0efcc..ea5d968 100644 --- a/logging_handler_test.go +++ b/logging_handler_test.go @@ -24,6 +24,11 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { for _, test := range tests { buf := bytes.NewBuffer(nil) handler := func(w http.ResponseWriter, req *http.Request) { + _, ok := w.(http.Hijacker) + if !ok { + t.Error("http.Hijacker is not available") + } + w.Write([]byte("test")) } diff --git a/oauthproxy.go b/oauthproxy.go index 561dce3..24fea21 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -110,7 +110,7 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) u.auth.SignRequest(r) } - if u.wsHandler != nil && r.Header.Get("Connection") == "Upgrade" && r.Header.Get("Upgrade") == "websocket" { + if u.wsHandler != nil && strings.ToLower(r.Header.Get("Connection")) == "upgrade" && r.Header.Get("Upgrade") == "websocket" { u.wsHandler.ServeHTTP(w, r) } else { u.handler.ServeHTTP(w, r)