From c12db0ebf797e3489c1bef93878ae1bdb6850125 Mon Sep 17 00:00:00 2001 From: Cosmin Cojocar Date: Wed, 30 Jan 2019 11:13:12 +0100 Subject: [PATCH 1/2] Returns HTTP unauthorized for ajax requests instead of redirecting to the sing-in page --- oauthproxy.go | 28 ++++++++++++++++++++ oauthproxy_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/oauthproxy.go b/oauthproxy.go index 971005d..daef9c5 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -754,6 +754,8 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } else { p.SignInPage(rw, req, http.StatusForbidden) } + } else if status == http.StatusUnauthorized { + p.ErrorJSON(rw, status) } else { p.serveMux.ServeHTTP(rw, req) } @@ -826,6 +828,11 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int } if session == nil { + // Check if is an ajax request and return unauthorized to avoid a redirect + // to the login page + if p.isAjax(req) { + return http.StatusUnauthorized + } return http.StatusForbidden } @@ -894,3 +901,24 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, } return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) } + +// isAjax checks if a request is an ajax request +func (p *OAuthProxy) isAjax(req *http.Request) bool { + acceptValues, ok := req.Header["accept"] + if !ok { + acceptValues = req.Header["Accept"] + } + const ajaxReq = "application/json" + for _, v := range acceptValues { + if v == ajaxReq { + return true + } + } + return false +} + +// ErrorJSON returns the error code witht an application/json mime type +func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(code) +} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index a5dd545..d6b5357 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -871,3 +871,67 @@ func TestGetRedirect(t *testing.T) { }) } } + +type ajaxRequestTest struct { + opts *Options + proxy *OAuthProxy +} + +func newAjaxRequestTest() *ajaxRequestTest { + test := &ajaxRequestTest{} + test.opts = NewOptions() + test.opts.CookieSecret = "foobar" + test.opts.ClientID = "bazquux" + test.opts.ClientSecret = "xyzzyplugh" + test.opts.Validate() + test.proxy = NewOAuthProxy(test.opts, func(email string) bool { + return true + }) + return test +} + +func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { + rw := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader("")) + if err != nil { + return 0, nil, err + } + req.Header = header + test.proxy.ServeHTTP(rw, req) + return rw.Code, rw.Header(), nil +} + +func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { + test := newAjaxRequestTest() + const endpoint = "/test" + + code, rh, err := test.getEndpoint(endpoint, header) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, code) + mime := rh.Get("Content-Type") + assert.Equal(t, "application/json", mime) +} +func TestAjaxUnauthorizedRequest1(t *testing.T) { + header := make(http.Header) + header.Add("accept", "application/json") + + testAjaxUnauthorizedRequest(t, header) +} + +func TestAjaxUnauthorizedRequest2(t *testing.T) { + header := make(http.Header) + header.Add("Accept", "application/json") + + testAjaxUnauthorizedRequest(t, header) +} + +func TestAjaxForbiddendRequest(t *testing.T) { + test := newAjaxRequestTest() + const endpoint = "/test" + header := make(http.Header) + code, rh, err := test.getEndpoint(endpoint, header) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, code) + mime := rh.Get("Content-Type") + assert.NotEqual(t, "application/json", mime) +} From 33261944228724bb4212d52d357fbe9842a7b00f Mon Sep 17 00:00:00 2001 From: Cosmin Cojocar Date: Thu, 31 Jan 2019 16:22:30 +0100 Subject: [PATCH 2/2] Extract the application/json mime type into a const --- oauthproxy.go | 6 ++++-- oauthproxy_test.go | 12 ++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index daef9c5..9c82d4f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -30,6 +30,8 @@ const ( // Cookies are limited to 4kb including the length of the cookie name, // the cookie name can be up to 256 bytes maxCookieLength = 3840 + + applicationJSON = "application/json" ) // SignatureHeaders contains the headers to be signed by the hmac algorithm @@ -908,7 +910,7 @@ func (p *OAuthProxy) isAjax(req *http.Request) bool { if !ok { acceptValues = req.Header["Accept"] } - const ajaxReq = "application/json" + const ajaxReq = applicationJSON for _, v := range acceptValues { if v == ajaxReq { return true @@ -919,6 +921,6 @@ func (p *OAuthProxy) isAjax(req *http.Request) bool { // ErrorJSON returns the error code witht an application/json mime type func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { - rw.Header().Set("Content-Type", "application/json") + rw.Header().Set("Content-Type", applicationJSON) rw.WriteHeader(code) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index d6b5357..927f886 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -903,35 +903,35 @@ func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (i func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { test := newAjaxRequestTest() - const endpoint = "/test" + endpoint := "/test" code, rh, err := test.getEndpoint(endpoint, header) assert.NoError(t, err) assert.Equal(t, http.StatusUnauthorized, code) mime := rh.Get("Content-Type") - assert.Equal(t, "application/json", mime) + assert.Equal(t, applicationJSON, mime) } func TestAjaxUnauthorizedRequest1(t *testing.T) { header := make(http.Header) - header.Add("accept", "application/json") + header.Add("accept", applicationJSON) testAjaxUnauthorizedRequest(t, header) } func TestAjaxUnauthorizedRequest2(t *testing.T) { header := make(http.Header) - header.Add("Accept", "application/json") + header.Add("Accept", applicationJSON) testAjaxUnauthorizedRequest(t, header) } func TestAjaxForbiddendRequest(t *testing.T) { test := newAjaxRequestTest() - const endpoint = "/test" + endpoint := "/test" header := make(http.Header) code, rh, err := test.getEndpoint(endpoint, header) assert.NoError(t, err) assert.Equal(t, http.StatusForbidden, code) mime := rh.Get("Content-Type") - assert.NotEqual(t, "application/json", mime) + assert.NotEqual(t, applicationJSON, mime) }