diff --git a/oauthproxy.go b/oauthproxy.go index 971005d..daef9c5 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -754,6 +754,8 @@ func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { } else { p.SignInPage(rw, req, http.StatusForbidden) } + } else if status == http.StatusUnauthorized { + p.ErrorJSON(rw, status) } else { p.serveMux.ServeHTTP(rw, req) } @@ -826,6 +828,11 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int } if session == nil { + // Check if is an ajax request and return unauthorized to avoid a redirect + // to the login page + if p.isAjax(req) { + return http.StatusUnauthorized + } return http.StatusForbidden } @@ -894,3 +901,24 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, } return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) } + +// isAjax checks if a request is an ajax request +func (p *OAuthProxy) isAjax(req *http.Request) bool { + acceptValues, ok := req.Header["accept"] + if !ok { + acceptValues = req.Header["Accept"] + } + const ajaxReq = "application/json" + for _, v := range acceptValues { + if v == ajaxReq { + return true + } + } + return false +} + +// ErrorJSON returns the error code witht an application/json mime type +func (p *OAuthProxy) ErrorJSON(rw http.ResponseWriter, code int) { + rw.Header().Set("Content-Type", "application/json") + rw.WriteHeader(code) +} diff --git a/oauthproxy_test.go b/oauthproxy_test.go index a5dd545..d6b5357 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -871,3 +871,67 @@ func TestGetRedirect(t *testing.T) { }) } } + +type ajaxRequestTest struct { + opts *Options + proxy *OAuthProxy +} + +func newAjaxRequestTest() *ajaxRequestTest { + test := &ajaxRequestTest{} + test.opts = NewOptions() + test.opts.CookieSecret = "foobar" + test.opts.ClientID = "bazquux" + test.opts.ClientSecret = "xyzzyplugh" + test.opts.Validate() + test.proxy = NewOAuthProxy(test.opts, func(email string) bool { + return true + }) + return test +} + +func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) { + rw := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader("")) + if err != nil { + return 0, nil, err + } + req.Header = header + test.proxy.ServeHTTP(rw, req) + return rw.Code, rw.Header(), nil +} + +func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) { + test := newAjaxRequestTest() + const endpoint = "/test" + + code, rh, err := test.getEndpoint(endpoint, header) + assert.NoError(t, err) + assert.Equal(t, http.StatusUnauthorized, code) + mime := rh.Get("Content-Type") + assert.Equal(t, "application/json", mime) +} +func TestAjaxUnauthorizedRequest1(t *testing.T) { + header := make(http.Header) + header.Add("accept", "application/json") + + testAjaxUnauthorizedRequest(t, header) +} + +func TestAjaxUnauthorizedRequest2(t *testing.T) { + header := make(http.Header) + header.Add("Accept", "application/json") + + testAjaxUnauthorizedRequest(t, header) +} + +func TestAjaxForbiddendRequest(t *testing.T) { + test := newAjaxRequestTest() + const endpoint = "/test" + header := make(http.Header) + code, rh, err := test.getEndpoint(endpoint, header) + assert.NoError(t, err) + assert.Equal(t, http.StatusForbidden, code) + mime := rh.Get("Content-Type") + assert.NotEqual(t, "application/json", mime) +}