diff --git a/oauthproxy.go b/oauthproxy.go index 971005d..9c82d4f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -30,6 +30,8 @@ const ( // Cookies are limited to 4kb including the length of the cookie name, // the cookie name can be up to 256 bytes maxCookieLength = 3840 + + applicationJSON = "application/json" ) // SignatureHeaders contains the headers to be signed by the hmac algorithm @@ -754,6 +756,8 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } else { p.SignInPage(rw, req, http.StatusForbidden) } + } else if status == http.StatusUnauthorized { + p.ErrorJSON(rw, status) } else { p.serveMux.ServeHTTP(rw, req) } @@ -826,6 +830,11 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int } if session == nil { + // Check if is an ajax request and return unauthorized to avoid a redirect + // to the login page + if p.isAjax(req) { + return http.StatusUnauthorized + } return http.StatusForbidden } @@ -894,3 +903,24 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, } return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) } + +// isAjax checks if a request is an ajax request +func (p *OAuthProxy) isAjax(req *http.Request) bool { + acceptValues, ok := req.Header["accept"] + if !ok { + acceptValues = req.Header["Accept"] + } + const ajaxReq = applicationJSON + for _, v := range acceptValues { + if v == ajaxReq { + return true + } + } + return false +} + +// ErrorJSON returns the error code witht an application/json mime type +func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { + rw.Header().Set("Content-Type", applicationJSON) + rw.WriteHeader(code) +} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index a5dd545..927f886 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -871,3 +871,67 @@ func TestGetRedirect(t *testing.T) { }) } } + +type ajaxRequestTest struct { + opts *Options + proxy *OAuthProxy +} + +func newAjaxRequestTest() *ajaxRequestTest { + test := &ajaxRequestTest{} + test.opts = NewOptions() + test.opts.CookieSecret = "foobar" + test.opts.ClientID = "bazquux" + test.opts.ClientSecret = "xyzzyplugh" + test.opts.Validate() + test.proxy = NewOAuthProxy(test.opts, func(email string) bool { + return true + }) + return test +} + +func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { + rw := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader("")) + if err != nil { + return 0, nil, err + } + req.Header = header + test.proxy.ServeHTTP(rw, req) + return rw.Code, rw.Header(), nil +} + +func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { + test := newAjaxRequestTest() + endpoint := "/test" + + code, rh, err := test.getEndpoint(endpoint, header) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, code) + mime := rh.Get("Content-Type") + assert.Equal(t, applicationJSON, mime) +} +func TestAjaxUnauthorizedRequest1(t *testing.T) { + header := make(http.Header) + header.Add("accept", applicationJSON) + + testAjaxUnauthorizedRequest(t, header) +} + +func TestAjaxUnauthorizedRequest2(t *testing.T) { + header := make(http.Header) + header.Add("Accept", applicationJSON) + + testAjaxUnauthorizedRequest(t, header) +} + +func TestAjaxForbiddendRequest(t *testing.T) { + test := newAjaxRequestTest() + endpoint := "/test" + header := make(http.Header) + code, rh, err := test.getEndpoint(endpoint, header) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, code) + mime := rh.Get("Content-Type") + assert.NotEqual(t, applicationJSON, mime) +}