diff --git a/CHANGELOG.md b/CHANGELOG.md index d42d688..68954d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Changes since v3.1.0 +- [#92](https://github.com/pusher/oauth2_proxy/pull/92) Merge websocket proxy feature from openshift/oauth-proxy (@butzist) - [#57](https://github.com/pusher/oauth2_proxy/pull/57) Fall back to using OIDC Subject instead of Email (@aigarius) - [#85](https://github.com/pusher/oauth2_proxy/pull/85) Use non-root user in docker images (@kskewes) - [#68](https://github.com/pusher/oauth2_proxy/pull/68) forward X-Auth-Access-Token header (@davidholsgrove) diff --git a/Gopkg.lock b/Gopkg.lock index bc4c28f..6d27a3b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -59,11 +59,11 @@ [[projects]] branch = "master" - digest = "1:9408fb9c637c103010e5147469c232ce6b68edc840879cc730a2a15918e6cae8" + digest = "1:15c0562bca5d78ac087fb39c211071dc124e79fb18f8b7c3f8a0bc7ffcb2a38e" name = "github.com/mreiferson/go-options" packages = ["."] pruneopts = "" - revision = "77551d20752b54535462404ad9d877ebdb26e53d" + revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" [[projects]] digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" @@ -87,11 +87,22 @@ [[projects]] digest = "1:3926a4ec9a4ff1a072458451aa2d9b98acd059a45b38f7335d31e06c3d6a0159" name = "github.com/stretchr/testify" - packages = ["assert"] + packages = [ + "assert", + "require", + ] pruneopts = "" revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" version = "v1.1.4" +[[projects]] + branch = "master" + digest = "1:39630a0e2844fc4297c27caacb394a9fd342f869292284a62f856877adab65bc" + name = "github.com/yhat/wsutil" + packages = ["."] + pruneopts = "" + revision = "1d66fa95c997864ba4d8479f56609620fe542928" + [[projects]] branch = "master" digest = "1:f6a006d27619a4d93bf9b66fe1999b8c8d1fa62bdc63af14f10fbe6fcaa2aa1a" @@ -112,6 +123,7 @@ packages = [ "context", "context/ctxhttp", + "websocket", ] pruneopts = "" revision = "9dfe39835686865bff950a07b394c12a98ddc811" @@ -192,7 +204,10 @@ "github.com/mbland/hmacauth", "github.com/mreiferson/go-options", "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/require", + "github.com/yhat/wsutil", "golang.org/x/crypto/bcrypt", + "golang.org/x/net/websocket", "golang.org/x/oauth2", "golang.org/x/oauth2/google", "google.golang.org/api/admin/directory/v1", diff --git a/README.md b/README.md index f833b37..67ec599 100644 --- a/README.md +++ b/README.md @@ -242,6 +242,7 @@ Usage of oauth2_proxy: -profile-url string: Profile access endpoint -provider string: OAuth provider (default "google") -proxy-prefix string: the url root path that this proxy should be nested under (e.g. //sign_in) (default "/oauth2") + -proxy-websockets: enables WebSocket proxying (default true) -redeem-url string: Token redemption endpoint -redirect-url string: the OAuth Redirect URL. ie: "https://internalapp.yourcompany.com/oauth2/callback" -request-logging: Log requests to stdout (default true) diff --git a/main.go b/main.go index 93f799c..1992fc0 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,7 @@ import ( "time" "github.com/BurntSushi/toml" - "github.com/mreiferson/go-options" + options "github.com/mreiferson/go-options" ) func main() { @@ -62,6 +62,7 @@ func main() { flagSet.String("custom-templates-dir", "", "path to custom html templates") flagSet.String("footer", "", "custom footer string. Use \"-\" to disable default footer.") flagSet.String("proxy-prefix", "/oauth2", "the url root path that this proxy should be nested under (e.g. //sign_in)") + flagSet.Bool("proxy-websockets", true, "enables WebSocket proxying") flagSet.String("cookie-name", "_oauth2_proxy", "the name of the cookie that the oauth_proxy creates") flagSet.String("cookie-secret", "", "the seed string for secure cookies (optionally base64 encoded)") diff --git a/oauthproxy.go b/oauthproxy.go index 941e934..7b6ba88 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -17,6 +17,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/providers" + "github.com/yhat/wsutil" ) const ( @@ -95,9 +96,10 @@ type OAuthProxy struct { // UpstreamProxy represents an upstream server to proxy to type UpstreamProxy struct { - upstream string - handler http.Handler - auth hmacauth.HmacAuth + upstream string + handler http.Handler + wsHandler http.Handler + auth hmacauth.HmacAuth } // ServeHTTP proxies requests to the upstream provider while signing the @@ -108,7 +110,12 @@ func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Header.Set("GAP-Auth", w.Header().Get("GAP-Auth")) u.auth.SignRequest(r) } - u.handler.ServeHTTP(w, r) + if u.wsHandler != nil && r.Header.Get("Connection") == "Upgrade" && r.Header.Get("Upgrade") == "websocket" { + u.wsHandler.ServeHTTP(w, r) + } else { + u.handler.ServeHTTP(w, r) + } + } // NewReverseProxy creates a new reverse proxy for proxying requests to upstream @@ -145,6 +152,26 @@ func NewFileServer(path string, filesystemPath string) (proxy http.Handler) { return http.StripPrefix(path, http.FileServer(http.Dir(filesystemPath))) } +// NewWebSocketOrRestReverseProxy creates a reverse proxy for REST or websocket based on url +func NewWebSocketOrRestReverseProxy(u *url.URL, opts *Options, auth hmacauth.HmacAuth) (restProxy http.Handler) { + u.Path = "" + proxy := NewReverseProxy(u, opts.FlushInterval) + if !opts.PassHostHeader { + setProxyUpstreamHostHeader(proxy, u) + } else { + setProxyDirector(proxy) + } + + // this should give us a wss:// scheme if the url is https:// based. + var wsProxy *wsutil.ReverseProxy + if opts.ProxyWebSockets { + wsScheme := "ws" + strings.TrimPrefix(u.Scheme, "http") + wsURL := &url.URL{Scheme: wsScheme, Host: u.Host} + wsProxy = wsutil.NewSingleHostReverseProxy(wsURL) + } + return &UpstreamProxy{u.Host, proxy, wsProxy, auth} +} + // NewOAuthProxy creates a new instance of OOuthProxy from the options provided func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { serveMux := http.NewServeMux() @@ -157,23 +184,17 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { path := u.Path switch u.Scheme { case httpScheme, httpsScheme: - u.Path = "" log.Printf("mapping path %q => upstream %q", path, u) - proxy := NewReverseProxy(u, opts.FlushInterval) - if !opts.PassHostHeader { - setProxyUpstreamHostHeader(proxy, u) - } else { - setProxyDirector(proxy) - } - serveMux.Handle(path, - &UpstreamProxy{u.Host, proxy, auth}) + proxy := NewWebSocketOrRestReverseProxy(u, opts, auth) + serveMux.Handle(path, proxy) + case "file": if u.Fragment != "" { path = u.Fragment } log.Printf("mapping path %q => file system %q", path, u.Path) proxy := NewFileServer(path, u.Path) - serveMux.Handle(path, &UpstreamProxy{path, proxy, nil}) + serveMux.Handle(path, &UpstreamProxy{path, proxy, nil, nil}) default: panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme)) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 468f7b2..757180a 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/websocket" ) func init() { @@ -26,6 +27,83 @@ func init() { } +type WebSocketOrRestHandler struct { + restHandler http.Handler + wsHandler http.Handler +} + +func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + h.wsHandler.ServeHTTP(w, r) + } else { + h.restHandler.ServeHTTP(w, r) + } +} + +func TestWebSocketProxy(t *testing.T) { + handler := WebSocketOrRestHandler{ + restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + hostname, _, _ := net.SplitHostPort(r.Host) + w.Write([]byte(hostname)) + }), + wsHandler: websocket.Handler(func(ws *websocket.Conn) { + defer ws.Close() + var data []byte + err := websocket.Message.Receive(ws, &data) + if err != nil { + t.Fatalf("err %s", err) + return + } + err = websocket.Message.Send(ws, data) + if err != nil { + t.Fatalf("err %s", err) + } + return + }), + } + backend := httptest.NewServer(&handler) + defer backend.Close() + + backendURL, _ := url.Parse(backend.URL) + + options := NewOptions() + var auth hmacauth.HmacAuth + options.PassHostHeader = true + proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, options, auth) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + frontendURL, _ := url.Parse(frontend.URL) + frontendWSURL := "ws://" + frontendURL.Host + "/" + + ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/") + if err != nil { + t.Fatalf("err %s", err) + } + request := []byte("hello, world!") + err = websocket.Message.Send(ws, request) + if err != nil { + t.Fatalf("err %s", err) + } + var response = make([]byte, 1024) + websocket.Message.Receive(ws, &response) + if err != nil { + t.Fatalf("err %s", err) + } + if g, e := string(request), string(response); g != e { + t.Errorf("got body %q; expected %q", g, e) + } + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + res, _ := http.DefaultClient.Do(getReq) + bodyBytes, _ := ioutil.ReadAll(res.Body) + backendHostname, _, _ := net.SplitHostPort(backendURL.Host) + if g, e := string(bodyBytes), backendHostname; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + func TestNewReverseProxy(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(200) diff --git a/options.go b/options.go index 666b517..efa8596 100644 --- a/options.go +++ b/options.go @@ -21,14 +21,15 @@ import ( // Options holds Configuration Options that can be set by Command Line Flag, // or Config File type Options struct { - ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` - HTTPAddress string `flag:"http-address" cfg:"http_address"` - HTTPSAddress string `flag:"https-address" cfg:"https_address"` - RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` - ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` - ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` - TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` - TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` + ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` + ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` + HTTPAddress string `flag:"http-address" cfg:"http_address"` + HTTPSAddress string `flag:"https-address" cfg:"https_address"` + RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` + ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` + ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` + TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` + TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` @@ -105,6 +106,7 @@ type SignatureData struct { func NewOptions() *Options { return &Options{ ProxyPrefix: "/oauth2", + ProxyWebSockets: true, HTTPAddress: "127.0.0.1:4180", HTTPSAddress: ":443", DisplayHtpasswdForm: true,