Merge pull request #82 from 18F/sign-in-redirect

Redirect to / when /oauth2/sign_in accessed
This commit is contained in:
Jehiah Czebotar 2015-04-06 23:20:26 -04:00
commit b0f0409f2b
2 changed files with 73 additions and 1 deletions

View File

@ -307,6 +307,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
p.ClearCookie(rw, req) p.ClearCookie(rw, req)
rw.WriteHeader(code) rw.WriteHeader(code)
redirect_url := req.URL.RequestURI()
if redirect_url == signInPath {
redirect_url = "/"
}
t := struct { t := struct {
ProviderName string ProviderName string
SignInMessage string SignInMessage string
@ -317,7 +322,7 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code
ProviderName: p.provider.Data().ProviderName, ProviderName: p.provider.Data().ProviderName,
SignInMessage: p.SignInMessage, SignInMessage: p.SignInMessage,
CustomLogin: p.displayCustomLoginForm(), CustomLogin: p.displayCustomLoginForm(),
Redirect: req.URL.RequestURI(), Redirect: redirect_url,
Version: VERSION, Version: VERSION,
} }
p.templates.ExecuteTemplate(rw, "sign_in.html", t) p.templates.ExecuteTemplate(rw, "sign_in.html", t)

View File

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"regexp"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -237,3 +238,69 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
assert.Equal(t, 200, code) assert.Equal(t, 200, code)
assert.Equal(t, "No access token found.", payload) assert.Equal(t, "No access token found.", payload)
} }
type SignInPageTest struct {
opts *Options
proxy *OauthProxy
sign_in_regexp *regexp.Regexp
}
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
func NewSignInPageTest() *SignInPageTest {
var sip_test SignInPageTest
sip_test.opts = NewOptions()
sip_test.opts.Upstreams = append(sip_test.opts.Upstreams, "unused")
sip_test.opts.CookieSecret = "foobar"
sip_test.opts.ClientID = "bazquux"
sip_test.opts.ClientSecret = "xyzzyplugh"
sip_test.opts.Validate()
sip_test.proxy = NewOauthProxy(sip_test.opts, func(email string) bool {
return true
})
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
return &sip_test
}
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
rw := httptest.NewRecorder()
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
sip_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
sip_test := NewSignInPageTest()
const endpoint = "/some/random/endpoint"
code, body := sip_test.GetEndpoint(endpoint)
assert.Equal(t, 403, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
if match == nil {
t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body)
}
if match[1] != endpoint {
t.Fatal(`expected redirect to "` + endpoint +
`", but was "` + match[1] + `"`)
}
}
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
sip_test := NewSignInPageTest()
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
assert.Equal(t, 200, code)
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
if match == nil {
t.Fatal("Did not find pattern in body: " +
signInRedirectPattern + "\nBody:\n" + body)
}
if match[1] != "/" {
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
}
}