oauthproxy: fix #284 -skip-provider-button for /sign_in route
This commit is contained in:
parent
3c51c914ac
commit
b640a69d63
@ -482,7 +482,11 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
|||||||
p.SaveSession(rw, req, session)
|
p.SaveSession(rw, req, session)
|
||||||
http.Redirect(rw, req, redirect, 302)
|
http.Redirect(rw, req, redirect, 302)
|
||||||
} else {
|
} else {
|
||||||
p.SignInPage(rw, req, 200)
|
if p.SkipProviderButton {
|
||||||
|
p.OAuthStart(rw, req)
|
||||||
|
} else {
|
||||||
|
p.SignInPage(rw, req, http.StatusOK)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,9 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"github.com/18F/hmacauth"
|
|
||||||
"github.com/bitly/oauth2_proxy/providers"
|
|
||||||
"github.com/bmizerany/assert"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
@ -17,6 +14,10 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/18F/hmacauth"
|
||||||
|
"github.com/bitly/oauth2_proxy/providers"
|
||||||
|
"github.com/bmizerany/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@ -359,26 +360,30 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SignInPageTest struct {
|
type SignInPageTest struct {
|
||||||
opts *Options
|
opts *Options
|
||||||
proxy *OAuthProxy
|
proxy *OAuthProxy
|
||||||
sign_in_regexp *regexp.Regexp
|
sign_in_regexp *regexp.Regexp
|
||||||
|
sign_in_provider_regexp *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
|
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
|
||||||
|
const signInSkipProvider = `>Found<`
|
||||||
|
|
||||||
func NewSignInPageTest() *SignInPageTest {
|
func NewSignInPageTest(skipProvider bool) *SignInPageTest {
|
||||||
var sip_test SignInPageTest
|
var sip_test SignInPageTest
|
||||||
|
|
||||||
sip_test.opts = NewOptions()
|
sip_test.opts = NewOptions()
|
||||||
sip_test.opts.CookieSecret = "foobar"
|
sip_test.opts.CookieSecret = "foobar"
|
||||||
sip_test.opts.ClientID = "bazquux"
|
sip_test.opts.ClientID = "bazquux"
|
||||||
sip_test.opts.ClientSecret = "xyzzyplugh"
|
sip_test.opts.ClientSecret = "xyzzyplugh"
|
||||||
|
sip_test.opts.SkipProviderButton = skipProvider
|
||||||
sip_test.opts.Validate()
|
sip_test.opts.Validate()
|
||||||
|
|
||||||
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool {
|
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool {
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
|
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
|
||||||
|
sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider)
|
||||||
|
|
||||||
return &sip_test
|
return &sip_test
|
||||||
}
|
}
|
||||||
@ -391,7 +396,7 @@ func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
||||||
sip_test := NewSignInPageTest()
|
sip_test := NewSignInPageTest(false)
|
||||||
const endpoint = "/some/random/endpoint"
|
const endpoint = "/some/random/endpoint"
|
||||||
|
|
||||||
code, body := sip_test.GetEndpoint(endpoint)
|
code, body := sip_test.GetEndpoint(endpoint)
|
||||||
@ -409,7 +414,7 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
||||||
sip_test := NewSignInPageTest()
|
sip_test := NewSignInPageTest(false)
|
||||||
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
|
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
|
|
||||||
@ -423,6 +428,34 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSignInPageSkipProvider(t *testing.T) {
|
||||||
|
sip_test := NewSignInPageTest(true)
|
||||||
|
const endpoint = "/some/random/endpoint"
|
||||||
|
|
||||||
|
code, body := sip_test.GetEndpoint(endpoint)
|
||||||
|
assert.Equal(t, 302, code)
|
||||||
|
|
||||||
|
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Did not find pattern in body: " +
|
||||||
|
signInSkipProvider + "\nBody:\n" + body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignInPageSkipProviderDirect(t *testing.T) {
|
||||||
|
sip_test := NewSignInPageTest(true)
|
||||||
|
const endpoint = "/sign_in"
|
||||||
|
|
||||||
|
code, body := sip_test.GetEndpoint(endpoint)
|
||||||
|
assert.Equal(t, 302, code)
|
||||||
|
|
||||||
|
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
|
||||||
|
if match == nil {
|
||||||
|
t.Fatal("Did not find pattern in body: " +
|
||||||
|
signInSkipProvider + "\nBody:\n" + body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type ProcessCookieTest struct {
|
type ProcessCookieTest struct {
|
||||||
opts *Options
|
opts *Options
|
||||||
proxy *OAuthProxy
|
proxy *OAuthProxy
|
||||||
|
Loading…
Reference in New Issue
Block a user