From 5f747bb7685a2c8d043ac3034266d7008e06780d Mon Sep 17 00:00:00 2001 From: Mike Bland Date: Mon, 6 Apr 2015 22:10:03 -0400 Subject: [PATCH] Redirect to / when /oauth2/sign_in accessed Without this change, clicking the sign-in button on /oauth2/sign_in will always redirect back to /oauth2/sign_in, essentially creating an infinite loop. --- oauthproxy.go | 7 ++++- oauthproxy_test.go | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/oauthproxy.go b/oauthproxy.go index db4e2e7..0b5141e 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -307,6 +307,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code p.ClearCookie(rw, req) rw.WriteHeader(code) + redirect_url := req.URL.RequestURI() + if redirect_url == signInPath { + redirect_url = "/" + } + t := struct { ProviderName string SignInMessage string @@ -317,7 +322,7 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code ProviderName: p.provider.Data().ProviderName, SignInMessage: p.SignInMessage, CustomLogin: p.displayCustomLoginForm(), - Redirect: req.URL.RequestURI(), + Redirect: redirect_url, Version: VERSION, } p.templates.ExecuteTemplate(rw, "sign_in.html", t) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index bae7e6b..3712551 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "strings" "testing" "time" @@ -237,3 +238,69 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { assert.Equal(t, 200, code) assert.Equal(t, "No access token found.", payload) } + +type SignInPageTest struct { + opts *Options + proxy *OauthProxy + sign_in_regexp *regexp.Regexp +} + +const signInRedirectPattern = `` + +func NewSignInPageTest() *SignInPageTest { + var sip_test SignInPageTest + + sip_test.opts = NewOptions() + sip_test.opts.Upstreams = append(sip_test.opts.Upstreams, "unused") + sip_test.opts.CookieSecret = "foobar" + sip_test.opts.ClientID = "bazquux" + sip_test.opts.ClientSecret = "xyzzyplugh" + sip_test.opts.Validate() + + sip_test.proxy = NewOauthProxy(sip_test.opts, func(email string) bool { + return true + }) + sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) + + return &sip_test +} + +func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { + rw := httptest.NewRecorder() + req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) + sip_test.proxy.ServeHTTP(rw, req) + return rw.Code, rw.Body.String() +} + +func TestSignInPageIncludesTargetRedirect(t *testing.T) { + sip_test := NewSignInPageTest() + const endpoint = "/some/random/endpoint" + + code, body := sip_test.GetEndpoint(endpoint) + assert.Equal(t, 403, code) + + match := sip_test.sign_in_regexp.FindStringSubmatch(body) + if match == nil { + t.Fatal("Did not find pattern in body: " + + signInRedirectPattern + "\nBody:\n" + body) + } + if match[1] != endpoint { + t.Fatal(`expected redirect to "` + endpoint + + `", but was "` + match[1] + `"`) + } +} + +func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { + sip_test := NewSignInPageTest() + code, body := sip_test.GetEndpoint("/oauth2/sign_in") + assert.Equal(t, 200, code) + + match := sip_test.sign_in_regexp.FindStringSubmatch(body) + if match == nil { + t.Fatal("Did not find pattern in body: " + + signInRedirectPattern + "\nBody:\n" + body) + } + if match[1] != "/" { + t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`) + } +}