From b640a69d638445ad0f6a28c68e2e5fc2569bb446 Mon Sep 17 00:00:00 2001 From: Alan Braithwaite Date: Wed, 21 Jun 2017 15:02:34 -0700 Subject: [PATCH] oauthproxy: fix #284 -skip-provider-button for /sign_in route --- oauthproxy.go | 6 +++++- oauthproxy_test.go | 51 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index dd2b58e..91608ac 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -482,7 +482,11 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { p.SaveSession(rw, req, session) http.Redirect(rw, req, redirect, 302) } else { - p.SignInPage(rw, req, 200) + if p.SkipProviderButton { + p.OAuthStart(rw, req) + } else { + p.SignInPage(rw, req, http.StatusOK) + } } } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index a0bcc5c..43e165a 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -3,9 +3,6 @@ package main import ( "crypto" "encoding/base64" - "github.com/18F/hmacauth" - "github.com/bitly/oauth2_proxy/providers" - "github.com/bmizerany/assert" "io" "io/ioutil" "log" @@ -17,6 +14,10 @@ import ( "strings" "testing" "time" + + "github.com/18F/hmacauth" + "github.com/bitly/oauth2_proxy/providers" + "github.com/bmizerany/assert" ) func init() { @@ -359,26 +360,30 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { } type SignInPageTest struct { - opts *Options - proxy *OAuthProxy - sign_in_regexp *regexp.Regexp + opts *Options + proxy *OAuthProxy + sign_in_regexp *regexp.Regexp + sign_in_provider_regexp *regexp.Regexp } const signInRedirectPattern = `` +const signInSkipProvider = `>Found<` -func NewSignInPageTest() *SignInPageTest { +func NewSignInPageTest(skipProvider bool) *SignInPageTest { var sip_test SignInPageTest sip_test.opts = NewOptions() sip_test.opts.CookieSecret = "foobar" sip_test.opts.ClientID = "bazquux" sip_test.opts.ClientSecret = "xyzzyplugh" + sip_test.opts.SkipProviderButton = skipProvider sip_test.opts.Validate() sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { return true }) sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) + sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider) return &sip_test } @@ -391,7 +396,7 @@ func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { } func TestSignInPageIncludesTargetRedirect(t *testing.T) { - sip_test := NewSignInPageTest() + sip_test := NewSignInPageTest(false) const endpoint = "/some/random/endpoint" code, body := sip_test.GetEndpoint(endpoint) @@ -409,7 +414,7 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) { } func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { - sip_test := NewSignInPageTest() + sip_test := NewSignInPageTest(false) code, body := sip_test.GetEndpoint("/oauth2/sign_in") assert.Equal(t, 200, code) @@ -423,6 +428,34 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { } } +func TestSignInPageSkipProvider(t *testing.T) { + sip_test := NewSignInPageTest(true) + const endpoint = "/some/random/endpoint" + + code, body := sip_test.GetEndpoint(endpoint) + assert.Equal(t, 302, code) + + match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) + if match == nil { + t.Fatal("Did not find pattern in body: " + + signInSkipProvider + "\nBody:\n" + body) + } +} + +func TestSignInPageSkipProviderDirect(t *testing.T) { + sip_test := NewSignInPageTest(true) + const endpoint = "/sign_in" + + code, body := sip_test.GetEndpoint(endpoint) + assert.Equal(t, 302, code) + + match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) + if match == nil { + t.Fatal("Did not find pattern in body: " + + signInSkipProvider + "\nBody:\n" + body) + } +} + type ProcessCookieTest struct { opts *Options proxy *OAuthProxy