oauth2_proxy/oauthproxy_test.go
Lukasz Leszczuk 90208c7fe4 Fix tests
2019-09-06 16:54:41 +02:00

1339 lines
40 KiB
Go

package main
import (
"context"
"crypto"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/logger"
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
"github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
// init sets the project logger's flags to logger.Lshortfile when the test
// package is initialized.
func init() {
logger.SetFlags(logger.Lshortfile)
}
// WebSocketOrRestHandler is a test backend that answers WebSocket upgrade
// requests and ordinary HTTP requests with two separate handlers.
type WebSocketOrRestHandler struct {
	restHandler http.Handler
	wsHandler   http.Handler
}

// ServeHTTP forwards the request to wsHandler when the Upgrade header is
// exactly "websocket", and to restHandler for every other request.
func (h *WebSocketOrRestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	next := h.restHandler
	if r.Header.Get("Upgrade") == "websocket" {
		next = h.wsHandler
	}
	next.ServeHTTP(w, r)
}
// TestWebSocketProxy starts a backend that echoes WebSocket messages and
// answers plain HTTP requests with the host name it saw, places the handler
// built by NewWebSocketOrRestReverseProxy in front of it, and checks that both
// kinds of traffic pass through.
func TestWebSocketProxy(t *testing.T) {
	handler := WebSocketOrRestHandler{
		restHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(200)
			hostname, _, _ := net.SplitHostPort(r.Host)
			w.Write([]byte(hostname))
		}),
		wsHandler: websocket.Handler(func(ws *websocket.Conn) {
			defer ws.Close()
			// This callback runs on a server goroutine, so failures are
			// reported with Errorf: Fatalf must only be called from the
			// goroutine that runs the test function.
			var data []byte
			if err := websocket.Message.Receive(ws, &data); err != nil {
				t.Errorf("err %s", err)
				return
			}
			if err := websocket.Message.Send(ws, data); err != nil {
				t.Errorf("err %s", err)
			}
		}),
	}
	backend := httptest.NewServer(&handler)
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("err %s", err)
	}

	options := NewOptions()
	var auth hmacauth.HmacAuth
	options.PassHostHeader = true
	proxyHandler := NewWebSocketOrRestReverseProxy(backendURL, options, auth)
	frontend := httptest.NewServer(proxyHandler)
	defer frontend.Close()

	frontendURL, err := url.Parse(frontend.URL)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	frontendWSURL := "ws://" + frontendURL.Host + "/"

	ws, err := websocket.Dial(frontendWSURL, "", "http://localhost/")
	if err != nil {
		t.Fatalf("err %s", err)
	}
	defer ws.Close()

	request := []byte("hello, world!")
	if err = websocket.Message.Send(ws, request); err != nil {
		t.Fatalf("err %s", err)
	}
	// The receive error must be captured here; otherwise the check below
	// would only re-test the error of the preceding Send.
	var response []byte
	if err = websocket.Message.Receive(ws, &response); err != nil {
		t.Fatalf("err %s", err)
	}
	if g, e := string(response), string(request); g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}

	getReq, err := http.NewRequest("GET", frontend.URL, nil)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	res, err := http.DefaultClient.Do(getReq)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	defer res.Body.Close()
	bodyBytes, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	backendHostname, _, _ := net.SplitHostPort(backendURL.Host)
	if g, e := string(bodyBytes), backendHostname; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
}
// TestNewReverseProxy puts the handler returned by NewReverseProxy (after
// setProxyUpstreamHostHeader has been applied) in front of a backend that
// echoes the host name of the request it received, and expects that host name
// to be the backend's own.
func TestNewReverseProxy(t *testing.T) {
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		hostname, _, _ := net.SplitHostPort(r.Host)
		w.Write([]byte(hostname))
	}))
	defer backend.Close()

	backendURL, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	backendHostname, backendPort, err := net.SplitHostPort(backendURL.Host)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	backendHost := net.JoinHostPort(backendHostname, backendPort)
	proxyURL, err := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
	if err != nil {
		t.Fatalf("err %s", err)
	}

	proxyHandler := NewReverseProxy(proxyURL, &Options{FlushInterval: time.Second})
	setProxyUpstreamHostHeader(proxyHandler, proxyURL)
	frontend := httptest.NewServer(proxyHandler)
	defer frontend.Close()

	getReq, err := http.NewRequest("GET", frontend.URL, nil)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	// A failed request leaves res nil, so stop here instead of dereferencing it.
	res, err := http.DefaultClient.Do(getReq)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	defer res.Body.Close()
	bodyBytes, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	if g, e := string(bodyBytes), backendHostname; g != e {
		t.Errorf("got body %q; expected %q", g, e)
	}
}
// TestEncodedSlashes sends a request whose path contains an encoded slash
// ("%2F") through the reverse proxy (with setProxyDirector applied) and checks
// that the backend sees the request URI unchanged.
func TestEncodedSlashes(t *testing.T) {
	var seen string
	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		seen = r.RequestURI
	}))
	defer backend.Close()

	b, err := url.Parse(backend.URL)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	proxyHandler := NewReverseProxy(b, &Options{FlushInterval: time.Second})
	setProxyDirector(proxyHandler)
	frontend := httptest.NewServer(proxyHandler)
	defer frontend.Close()

	f, err := url.Parse(frontend.URL)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	encodedPath := "/a%2Fb/?c=1"
	// Opaque carries the path verbatim so the client does not re-encode it.
	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
	res, err := http.DefaultClient.Do(getReq)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	// The body is not inspected, but it still has to be closed.
	res.Body.Close()
	if seen != encodedPath {
		t.Errorf("got bad request %q expected %q", seen, encodedPath)
	}
}
// TestRobotsTxt requests /robots.txt from the proxy and expects a 200 response
// whose body disallows all user agents.
func TestRobotsTxt(t *testing.T) {
	options := NewOptions()
	options.ClientID = "asdlkjx"
	options.ClientSecret = "alkgks"
	options.CookieSecret = "asdkugkj"
	options.Validate()

	allowAll := func(string) bool { return true }
	proxy := NewOAuthProxy(options, allowAll)

	recorder := httptest.NewRecorder()
	request, _ := http.NewRequest(http.MethodGet, "/robots.txt", nil)
	proxy.ServeHTTP(recorder, request)

	assert.Equal(t, http.StatusOK, recorder.Code)
	assert.Equal(t, "User-agent: *\nDisallow: /", recorder.Body.String())
}
// TestIsValidRedirect checks proxy.IsValidRedirect against relative paths and
// absolute URLs with the whitelist {"foo.bar", ".bar.foo"} configured.
func TestIsValidRedirect(t *testing.T) {
	options := NewOptions()
	options.ClientID = "skdlfj"
	options.ClientSecret = "fgkdsgj"
	options.CookieSecret = "ljgiogbj"
	// Should match domains that are exactly foo.bar and any subdomain of bar.foo
	options.WhitelistDomains = []string{"foo.bar", ".bar.foo"}
	options.Validate()
	proxy := NewOAuthProxy(options, func(string) bool { return true })

	// expect asserts the verdict for a single redirect target.
	expect := func(want bool, redirect string) {
		assert.Equal(t, want, proxy.IsValidRedirect(redirect))
	}

	// Empty and relative targets.
	expect(false, "")
	expect(true, "/redirect")
	expect(false, "//redirect")

	// Exact whitelisted domain.
	expect(true, "http://foo.bar/redirect")
	expect(true, "https://foo.bar/redirect")

	// Subdomain of a domain whitelisted without a leading dot.
	expect(false, "http://baz.foo.bar/redirect")
	expect(false, "https://baz.foo.bar/redirect")

	// Subdomain of a domain whitelisted with a leading dot.
	expect(true, "http://baz.bar.foo/redirect")
	expect(true, "https://baz.bar.foo/redirect")

	// Whitelisted name used as a prefix of, or a query value for, another host.
	expect(false, "http://foo.bar.evil.corp/redirect")
	expect(false, "https://foo.bar.evil.corp/redirect")
	expect(false, "http://evil.corp/redirect?rd=foo.bar")
	expect(false, "https://evil.corp/redirect?rd=foo.bar")
}
// TestProvider is a stub provider used by the tests in this file. It embeds
// *providers.ProviderData and answers the provider calls defined below with
// the canned values stored in its fields.
type TestProvider struct {
*providers.ProviderData
// EmailAddress is the address reported by GetUserDetails.
EmailAddress string
// ValidToken is the result returned by ValidateSessionState.
ValidToken bool
// GroupValidator, when non-nil, decides the result of ValidateGroup.
GroupValidator func(string) bool
}
// NewTestProvider builds a TestProvider whose login, redeem and profile URLs
// all use the "http" scheme and the host of providerURL. The provider reports
// emailAddress as the user's email and accepts every group check.
func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
	// endpoint returns an http URL on the provider host for the given path.
	endpoint := func(path string) *url.URL {
		return &url.URL{Scheme: "http", Host: providerURL.Host, Path: path}
	}

	data := &providers.ProviderData{
		ProviderName: "Test Provider",
		LoginURL:     endpoint("/oauth/authorize"),
		RedeemURL:    endpoint("/oauth/token"),
		ProfileURL:   endpoint("/api/v1/profile"),
		Scope:        "profile.email",
	}
	acceptAll := func(string) bool { return true }

	return &TestProvider{
		ProviderData:   data,
		EmailAddress:   emailAddress,
		GroupValidator: acceptAll,
	}
}
// GetUserDetails ignores the session and always returns user details carrying
// the provider's configured EmailAddress, with a nil error.
func (tp *TestProvider) GetUserDetails(session *sessions.SessionState) (*providers.UserDetails, error) {
return &providers.UserDetails{
Email: tp.EmailAddress,
}, nil
}
// ValidateSessionState ignores the session and reports the provider's
// ValidToken field.
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
return tp.ValidToken
}
// ValidateGroup returns the verdict of GroupValidator for email, or true when
// no GroupValidator has been set.
func (tp *TestProvider) ValidateGroup(email string) bool {
	if tp.GroupValidator == nil {
		return true
	}
	return tp.GroupValidator(email)
}
// TestBasicAuthPassword runs the OAuth callback against a stub provider, then
// replays the resulting session cookie on a request to "/". The stub server is
// also the upstream and echoes the Authorization header it receives; that
// header is expected to be HTTP Basic credentials made of the user's email
// address and opts.BasicAuthPassword.
func TestBasicAuthPassword(t *testing.T) {
	providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		logger.Printf("%#v", r)
		var payload string
		switch r.URL.Path {
		case "/oauth/token":
			payload = `{"access_token": "my_auth_token"}`
		default:
			payload = r.Header.Get("Authorization")
			if payload == "" {
				payload = "No Authorization header found."
			}
		}
		w.WriteHeader(200)
		w.Write([]byte(payload))
	}))
	// Deferred so the server is shut down even when the test aborts early.
	defer providerServer.Close()

	opts := NewOptions()
	opts.Upstreams = append(opts.Upstreams, providerServer.URL)
	// The CookieSecret must be 32 bytes in order to create the AES
	// cipher.
	opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
	opts.ClientID = "dlgkj"
	opts.ClientSecret = "alkgret"
	opts.CookieSecure = false
	opts.PassBasicAuth = true
	opts.PassUserHeaders = true
	opts.BasicAuthPassword = "This is a secure password"
	opts.Validate()

	providerURL, err := url.Parse(providerServer.URL)
	if err != nil {
		t.Fatalf("err %s", err)
	}
	const emailAddress = "john.doe@example.com"
	opts.provider = NewTestProvider(providerURL, emailAddress)
	proxy := NewOAuthProxy(opts, func(email string) bool {
		return email == emailAddress
	})

	rw := httptest.NewRecorder()
	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
		strings.NewReader(""))
	if err != nil {
		t.Fatalf("err %s", err)
	}
	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
	proxy.ServeHTTP(rw, req)
	if rw.Code >= 400 {
		t.Fatalf("expected 3xx got %d", rw.Code)
	}

	// The second Set-Cookie header is the one scanned for the session value;
	// fail with a message instead of panicking when it is missing.
	setCookies := rw.Header()["Set-Cookie"]
	if len(setCookies) < 2 {
		t.Fatalf("expected at least 2 Set-Cookie headers, got %d", len(setCookies))
	}
	cookieName := proxy.CookieName
	keyPrefix := cookieName + "="
	var value string
	for _, field := range strings.Split(setCookies[1], "; ") {
		if strings.HasPrefix(field, keyPrefix) {
			value = strings.TrimPrefix(field, keyPrefix)
			break
		}
	}

	req, err = http.NewRequest("GET", "/", strings.NewReader(""))
	if err != nil {
		t.Fatalf("err %s", err)
	}
	req.AddCookie(&http.Cookie{
		Name:  cookieName,
		Value: value,
		Path:  "/",
		// 24 hours; a bare time.Duration(24) would be 24 nanoseconds.
		Expires:  time.Now().Add(24 * time.Hour),
		HttpOnly: true,
	})
	req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
	rw = httptest.NewRecorder()
	proxy.ServeHTTP(rw, req)

	// The username in the basic auth credentials is expected to be equal to the email address from the
	// auth response, so we use the same variable here.
	expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(emailAddress+":"+opts.BasicAuthPassword))
	assert.Equal(t, expectedHeader, rw.Body.String())
}
// PassAccessTokenTest bundles the pieces needed by the access-token forwarding
// tests: a stub HTTP server, the proxy under test, and the options the proxy
// was built from.
type PassAccessTokenTest struct {
providerServer *httptest.Server
proxy *OAuthProxy
opts *Options
}
// PassAccessTokenTestOptions holds the settings a PassAccessTokenTest can be
// created with.
type PassAccessTokenTestOptions struct {
// PassAccessToken is copied into Options.PassAccessToken by
// NewPassAccessTokenTest.
PassAccessToken bool
}
// NewPassAccessTokenTest starts a stub server that doubles as OAuth provider
// and upstream, and builds a proxy in front of it. The stub answers
// "/oauth/token" with a fixed access token; on every other path it echoes the
// X-Forwarded-Access-Token request header, or a placeholder message when that
// header is empty. Callers must invoke Close when done.
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
	stub := func(w http.ResponseWriter, r *http.Request) {
		logger.Printf("%#v", r)
		payload := `{"access_token": "my_auth_token"}`
		if r.URL.Path != "/oauth/token" {
			payload = r.Header.Get("X-Forwarded-Access-Token")
			if payload == "" {
				payload = "No access token found."
			}
		}
		w.WriteHeader(200)
		w.Write([]byte(payload))
	}

	patTest := &PassAccessTokenTest{
		providerServer: httptest.NewServer(http.HandlerFunc(stub)),
		opts:           NewOptions(),
	}
	options := patTest.opts
	options.Upstreams = append(options.Upstreams, patTest.providerServer.URL)
	// The CookieSecret must be 32 bytes in order to create the AES
	// cipher.
	options.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
	options.ClientID = "slgkj"
	options.ClientSecret = "gfjgojl"
	options.CookieSecure = false
	options.PassAccessToken = opts.PassAccessToken
	options.Validate()

	const emailAddress = "michael.bland@gsa.gov"
	providerURL, _ := url.Parse(patTest.providerServer.URL)
	options.provider = NewTestProvider(providerURL, emailAddress)
	patTest.proxy = NewOAuthProxy(options, func(email string) bool {
		return email == emailAddress
	})
	return patTest
}
// Close shuts down the stub provider server started by NewPassAccessTokenTest.
func (patTest *PassAccessTokenTest) Close() {
patTest.providerServer.Close()
}
// getCallbackEndpoint sends the OAuth callback request (with a matching CSRF
// cookie) to the proxy and returns the response status together with the
// second Set-Cookie header value.
//
// It returns (0, "") when the request cannot be built, and the status with an
// empty cookie when the response carries fewer than two Set-Cookie headers.
// NOTE(review): the second header is presumably the session cookie, set after
// the CSRF cookie is cleared — confirm against the callback handler.
func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) {
	rw := httptest.NewRecorder()
	req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
		strings.NewReader(""))
	if err != nil {
		return 0, ""
	}
	req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
	patTest.proxy.ServeHTTP(rw, req)

	// Guard the index so an unexpected response fails the caller's assertions
	// instead of panicking inside this helper.
	setCookies := rw.Header()["Set-Cookie"]
	if len(setCookies) < 2 {
		return rw.Code, ""
	}
	return rw.Code, setCookies[1]
}
// getRootEndpoint extracts the proxy's session cookie value from a Set-Cookie
// header string, replays it on a GET request to "/", and returns the response
// status and body.
//
// It returns (0, "") when the cookie string holds no non-empty value for the
// proxy's cookie name, or when the request cannot be built.
func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) {
	cookieName := patTest.proxy.CookieName
	keyPrefix := cookieName + "="

	// The first "; "-separated field that starts with "<name>=" holds the value.
	var value string
	for _, field := range strings.Split(cookie, "; ") {
		if strings.HasPrefix(field, keyPrefix) {
			value = strings.TrimPrefix(field, keyPrefix)
			break
		}
	}
	if value == "" {
		return 0, ""
	}

	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
	if err != nil {
		return 0, ""
	}
	req.AddCookie(&http.Cookie{
		Name:  cookieName,
		Value: value,
		Path:  "/",
		// 24 hours; a bare time.Duration(24) would be 24 nanoseconds.
		Expires:  time.Now().Add(24 * time.Hour),
		HttpOnly: true,
	})

	rw := httptest.NewRecorder()
	patTest.proxy.ServeHTTP(rw, req)
	return rw.Code, rw.Body.String()
}
// TestForwardAccessTokenUpstream checks that, with PassAccessToken enabled,
// the access token obtained during the callback reaches the upstream in the
// "X-Forwarded-Access-Token" header.
func TestForwardAccessTokenUpstream(t *testing.T) {
	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: true,
	})
	defer patTest.Close()

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	// cookie is a string, so comparing it with nil could never fail; check
	// that it actually has content.
	assert.NotEmpty(t, cookie)

	// Now we make a regular request; the access_token from the cookie is
	// forwarded as the "X-Forwarded-Access-Token" header. The token is
	// read by the test provider server and written in the response body.
	code, payload := patTest.getRootEndpoint(cookie)
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "my_auth_token", payload)
}
// TestDoNotForwardAccessTokenUpstream checks that, with PassAccessToken
// disabled, the upstream receives no "X-Forwarded-Access-Token" header.
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
	patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
		PassAccessToken: false,
	})
	defer patTest.Close()

	// A successful validation will redirect and set the auth cookie.
	code, cookie := patTest.getCallbackEndpoint()
	if code != 302 {
		t.Fatalf("expected 302; got %d", code)
	}
	// cookie is a string, so comparing it with nil could never fail; check
	// that it actually has content.
	assert.NotEmpty(t, cookie)

	// Now we make a regular request, but the access token header should
	// not be present.
	code, payload := patTest.getRootEndpoint(cookie)
	if code != 200 {
		t.Fatalf("expected 200; got %d", code)
	}
	assert.Equal(t, "No access token found.", payload)
}
// SignInPageTest holds a proxy, the options it was built from, and the
// compiled patterns used to inspect sign-in responses.
type SignInPageTest struct {
opts *Options
proxy *OAuthProxy
// signInRegexp is compiled from signInRedirectPattern.
signInRegexp *regexp.Regexp
// signInProviderRegexp is compiled from signInSkipProvider.
signInProviderRegexp *regexp.Regexp
}
// signInRedirectPattern matches the hidden "rd" form input in a response body
// and captures its value.
const signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
// signInSkipProvider matches the ">Found<" text the skip-provider tests look
// for in a 302 response body.
const signInSkipProvider = `>Found<`
// NewSignInPageTest builds a proxy that accepts every email address, with
// SkipProviderButton set to skipProvider, and compiles the two patterns used
// to inspect its responses.
func NewSignInPageTest(skipProvider bool) *SignInPageTest {
	options := NewOptions()
	options.CookieSecret = "adklsj2"
	options.ClientID = "lkdgj"
	options.ClientSecret = "sgiufgoi"
	options.SkipProviderButton = skipProvider
	options.Validate()

	acceptAll := func(email string) bool { return true }
	return &SignInPageTest{
		opts:                 options,
		proxy:                NewOAuthProxy(options, acceptAll),
		signInRegexp:         regexp.MustCompile(signInRedirectPattern),
		signInProviderRegexp: regexp.MustCompile(signInSkipProvider),
	}
}
// GetEndpoint issues a GET request for endpoint directly against the proxy and
// returns the recorded status code and response body.
func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
	request, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
	recorder := httptest.NewRecorder()
	sipTest.proxy.ServeHTTP(recorder, request)
	return recorder.Code, recorder.Body.String()
}
// TestSignInPageIncludesTargetRedirect requests an arbitrary path without a
// session and expects a 403 response whose hidden "rd" input holds that path.
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
	const endpoint = "/some/random/endpoint"
	sipTest := NewSignInPageTest(false)

	status, page := sipTest.GetEndpoint(endpoint)
	assert.Equal(t, 403, status)

	groups := sipTest.signInRegexp.FindStringSubmatch(page)
	switch {
	case groups == nil:
		t.Fatal("Did not find pattern in body: " +
			signInRedirectPattern + "\nBody:\n" + page)
	case groups[1] != endpoint:
		t.Fatal(`expected redirect to "` + endpoint +
			`", but was "` + groups[1] + `"`)
	}
}
// TestSignInPageDirectAccessRedirectsToRoot requests "/oauth2/sign_in"
// directly and expects a 200 response whose hidden "rd" input is "/".
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
	sipTest := NewSignInPageTest(false)

	status, page := sipTest.GetEndpoint("/oauth2/sign_in")
	assert.Equal(t, 200, status)

	groups := sipTest.signInRegexp.FindStringSubmatch(page)
	switch {
	case groups == nil:
		t.Fatal("Did not find pattern in body: " +
			signInRedirectPattern + "\nBody:\n" + page)
	case groups[1] != "/":
		t.Fatal(`expected redirect to "/", but was "` + groups[1] + `"`)
	}
}
// TestSignInPageSkipProvider enables SkipProviderButton, requests an arbitrary
// path, and expects a 302 response whose body matches signInSkipProvider.
func TestSignInPageSkipProvider(t *testing.T) {
	const endpoint = "/some/random/endpoint"
	sipTest := NewSignInPageTest(true)

	status, page := sipTest.GetEndpoint(endpoint)
	assert.Equal(t, 302, status)

	if !sipTest.signInProviderRegexp.MatchString(page) {
		t.Fatal("Did not find pattern in body: " +
			signInSkipProvider + "\nBody:\n" + page)
	}
}
// TestSignInPageSkipProviderDirect enables SkipProviderButton, requests
// "/sign_in", and expects a 302 response whose body matches
// signInSkipProvider.
func TestSignInPageSkipProviderDirect(t *testing.T) {
	const endpoint = "/sign_in"
	sipTest := NewSignInPageTest(true)

	status, page := sipTest.GetEndpoint(endpoint)
	assert.Equal(t, 302, status)

	if !sipTest.signInProviderRegexp.MatchString(page) {
		t.Fatal("Did not find pattern in body: " +
			signInSkipProvider + "\nBody:\n" + page)
	}
}
// ProcessCookieTest is the shared fixture for the session-cookie and
// auth-only endpoint tests: a proxy, the options it was built from, and one
// request/recorder pair that sessions are saved to and loaded from.
type ProcessCookieTest struct {
opts *Options
proxy *OAuthProxy
rw *httptest.ResponseRecorder
req *http.Request
provider TestProvider
responseCode int
// validateUser is the verdict returned by the email validator that the
// constructors in this file hand to NewOAuthProxy.
validateUser bool
}
// ProcessCookieTestOpts holds the settings a ProcessCookieTest is created
// with.
type ProcessCookieTestOpts struct {
// providerValidateCookieResponse becomes the ValidToken value of the
// TestProvider installed by NewProcessCookieTest.
providerValidateCookieResponse bool
}
// OptionsModifier adjusts an Options value in place; fixture constructors
// accept any number of them.
type OptionsModifier func(*Options)
// NewProcessCookieTest builds the fixture for the cookie tests. The modifiers
// run first, on the default options; the client credentials, cookie secret and
// CookieRefresh are then set unconditionally. The proxy's email validator
// returns the fixture's validateUser field, which starts out true.
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest {
	pcTest := &ProcessCookieTest{opts: NewOptions()}
	for i := range modifiers {
		modifiers[i](pcTest.opts)
	}

	options := pcTest.opts
	options.ClientID = "asdfljk"
	options.ClientSecret = "lkjfdsig"
	options.CookieSecret = "0123456789abcdefabcd"
	// First, set the CookieRefresh option so proxy.AesCipher is created,
	// needed to encrypt the access_token.
	options.CookieRefresh = time.Hour
	options.Validate()

	// The validator reads validateUser on every call, so tests can flip it
	// after construction.
	pcTest.proxy = NewOAuthProxy(options, func(email string) bool {
		return pcTest.validateUser
	})
	pcTest.proxy.provider = &TestProvider{
		ValidToken: opts.providerValidateCookieResponse,
	}
	// Now, zero-out proxy.CookieRefresh for the cases that don't involve
	// access_token validation.
	pcTest.proxy.CookieRefresh = time.Duration(0)

	pcTest.rw = httptest.NewRecorder()
	pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
	pcTest.validateUser = true
	return pcTest
}
// NewProcessCookieTestWithDefaults returns a fixture whose provider reports
// sessions as valid and whose options carry no extra modifications.
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
	defaults := ProcessCookieTestOpts{providerValidateCookieResponse: true}
	return NewProcessCookieTest(defaults)
}
// NewProcessCookieTestWithOptionsModifiers returns a fixture whose provider
// reports sessions as valid, after applying modifiers to the options.
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest {
	defaults := ProcessCookieTestOpts{providerValidateCookieResponse: true}
	return NewProcessCookieTest(defaults, modifiers...)
}
// SaveSession stores s through the proxy using the fixture's recorder and
// request, then copies every cookie the recorder received onto the request so
// that later calls see the saved session. It returns the proxy's error, if any.
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
	if err := p.proxy.SaveSession(p.rw, p.req, s); err != nil {
		return err
	}
	issued := p.rw.Result().Cookies()
	for i := range issued {
		p.req.AddCookie(issued[i])
	}
	return nil
}
// LoadCookiedSession asks the proxy to load the session from the fixture's
// request and returns its result unchanged.
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
return p.proxy.LoadCookiedSession(p.req)
}
// TestLoadCookiedSession saves a session and loads it back, expecting the
// email and access token to round-trip and User to equal the email address.
func TestLoadCookiedSession(t *testing.T) {
	pcTest := NewProcessCookieTestWithDefaults()

	startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
	require.NoError(t, pcTest.SaveSession(startSession))

	// Stop on error: session is dereferenced below and would be unusable.
	session, err := pcTest.LoadCookiedSession()
	require.NoError(t, err)
	require.NotNil(t, session)
	assert.Equal(t, startSession.Email, session.Email)
	assert.Equal(t, "john.doe@example.com", session.User)
	assert.Equal(t, startSession.AccessToken, session.AccessToken)
}
// TestProcessCookieNoCookieError loads a session from a request that carries
// no cookie and expects the "not present" error and a nil session.
func TestProcessCookieNoCookieError(t *testing.T) {
	pcTest := NewProcessCookieTestWithDefaults()

	session, err := pcTest.LoadCookiedSession()
	// Calling err.Error() on a nil error would panic, so require one first.
	require.Error(t, err)
	assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
	if session != nil {
		t.Errorf("expected nil session. got %#v", session)
	}
}
// TestProcessCookieRefreshNotSet saves a session created two hours ago, with a
// 23 hour cookie lifetime, and expects it to load successfully with its email
// intact.
func TestProcessCookieRefreshNotSet(t *testing.T) {
	pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
		opts.CookieExpire = time.Duration(23) * time.Hour
	})
	reference := time.Now().Add(time.Duration(-2) * time.Hour)

	startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
	require.NoError(t, pcTest.SaveSession(startSession))

	// Stop on error: session is dereferenced below and would be unusable.
	session, err := pcTest.LoadCookiedSession()
	require.NoError(t, err)
	require.NotNil(t, session)
	// NOTE(review): Age() is presumably positive for a CreatedAt in the past,
	// in which case this comparison against -2h can never fire — confirm the
	// intended bound.
	if session.Age() < time.Duration(-2)*time.Hour {
		t.Errorf("cookie too young %v", session.Age())
	}
	assert.Equal(t, startSession.Email, session.Email)
}
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
opts.CookieExpire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
pcTest.SaveSession(startSession)
session, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expected nil session %#v", session)
}
}
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
opts.CookieExpire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
pcTest.SaveSession(startSession)
pcTest.proxy.CookieRefresh = time.Hour
session, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expected nil session %#v", session)
}
}
// NewAuthOnlyEndpointTest returns a cookie-test fixture whose request targets
// the "/auth" endpoint under the configured ProxyPrefix.
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {
	fixture := NewProcessCookieTestWithOptionsModifiers(modifiers...)
	authPath := fixture.opts.ProxyPrefix + "/auth"
	fixture.req, _ = http.NewRequest("GET", authPath, nil)
	return fixture
}
// TestAuthOnlyEndpointAccepted calls the auth-only endpoint with a fresh
// session and expects 202 Accepted with an empty body.
func TestAuthOnlyEndpointAccepted(t *testing.T) {
	test := NewAuthOnlyEndpointTest()
	test.SaveSession(&sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   time.Now(),
	})

	test.proxy.ServeHTTP(test.rw, test.req)

	assert.Equal(t, http.StatusAccepted, test.rw.Code)
	assert.Equal(t, "", test.rw.Body.String())
}
// TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError calls the auth-only
// endpoint without any session cookie and expects 401 with the
// "unauthorized request" body.
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
	test := NewAuthOnlyEndpointTest()

	test.proxy.ServeHTTP(test.rw, test.req)

	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
	assert.Equal(t, "unauthorized request\n", test.rw.Body.String())
}
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test := NewAuthOnlyEndpointTest(func(opts *Options) {
opts.CookieExpire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
}
// TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure saves a valid
// session, makes the email validator reject every user, and expects the
// auth-only endpoint to answer 401 with the "unauthorized request" body.
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
	test := NewAuthOnlyEndpointTest()
	test.SaveSession(&sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   time.Now(),
	})
	// Flipped after saving so only the request below is rejected.
	test.validateUser = false

	test.proxy.ServeHTTP(test.rw, test.req)

	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
	assert.Equal(t, "unauthorized request\n", test.rw.Body.String())
}
// TestAuthOnlyEndpointUnauthorizedOnProviderGroupValidationFailure saves a
// valid session, swaps in a provider whose group validator rejects everyone,
// and expects the auth-only endpoint to answer 401 with the
// "unauthorized request" body.
func TestAuthOnlyEndpointUnauthorizedOnProviderGroupValidationFailure(t *testing.T) {
	test := NewAuthOnlyEndpointTest()
	test.SaveSession(&sessions.SessionState{
		Email:       "michael.bland@gsa.gov",
		AccessToken: "my_access_token",
		CreatedAt:   time.Now(),
	})

	rejectAll := func(s string) bool { return false }
	test.proxy.provider = &TestProvider{
		ValidToken:     true,
		GroupValidator: rejectAll,
	}

	test.proxy.ServeHTTP(test.rw, test.req)

	assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
	assert.Equal(t, "unauthorized request\n", test.rw.Body.String())
}
// TestAuthOnlyEndpointSetXAuthRequestHeaders enables SetXAuthRequest, calls
// the auth-only endpoint with a saved session, and expects 202 Accepted with
// the X-Auth-Request-User and X-Auth-Request-Email response headers carrying
// the session's user and email.
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
	var pcTest ProcessCookieTest

	pcTest.opts = NewOptions()
	pcTest.opts.SetXAuthRequest = true
	pcTest.opts.Validate()

	pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool {
		return pcTest.validateUser
	})
	pcTest.proxy.provider = &TestProvider{
		ValidToken: true,
	}

	pcTest.validateUser = true

	pcTest.rw = httptest.NewRecorder()
	pcTest.req, _ = http.NewRequest("GET",
		pcTest.opts.ProxyPrefix+"/auth", nil)

	startSession := &sessions.SessionState{
		User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
	pcTest.SaveSession(startSession)

	pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
	assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
	// Header().Get yields "" for a missing header, so an absent header fails
	// the assertion instead of panicking on an out-of-range index.
	assert.Equal(t, "oauth_user", pcTest.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, "oauth_user@example.com", pcTest.rw.Header().Get("X-Auth-Request-Email"))
}
// TestAuthSkippedForPreflightRequests enables SkipAuthPreflight, rejects every
// email address, and expects an OPTIONS request to reach the upstream anyway
// and return its 200 "response" body.
func TestAuthSkippedForPreflightRequests(t *testing.T) {
	upstreamHandler := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write([]byte("response"))
	}
	upstream := httptest.NewServer(http.HandlerFunc(upstreamHandler))
	defer upstream.Close()

	options := NewOptions()
	options.Upstreams = append(options.Upstreams, upstream.URL)
	options.ClientID = "aljsal"
	options.ClientSecret = "jglkfsdgj"
	options.CookieSecret = "dkfjgdls"
	options.SkipAuthPreflight = true
	options.Validate()

	upstreamURL, _ := url.Parse(upstream.URL)
	options.provider = NewTestProvider(upstreamURL, "")

	rejectAll := func(string) bool { return false }
	proxy := NewOAuthProxy(options, rejectAll)

	recorder := httptest.NewRecorder()
	preflight, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
	proxy.ServeHTTP(recorder, preflight)

	assert.Equal(t, 200, recorder.Code)
	assert.Equal(t, "response", recorder.Body.String())
}
// SignatureAuthenticator is the upstream handler state for the request
// signature tests; auth is the verifier that Authenticate consults.
type SignatureAuthenticator struct {
auth hmacauth.HmacAuth
}
// Authenticate verifies the request signature with v.auth and writes a short
// text verdict to the response: no signature, match, or mismatch (including
// both signatures). It panics on any other result value.
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
	result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
	switch result {
	case hmacauth.ResultNoSignature:
		w.Write([]byte("no signature received"))
	case hmacauth.ResultMatch:
		w.Write([]byte("signatures match"))
	case hmacauth.ResultMismatch:
		w.Write([]byte("signatures do not match:" +
			"\n received: " + headerSig +
			"\n computed: " + computedSig))
	default:
		panic("Unknown result value: " + result.String())
	}
}
// SignatureTest is the fixture for the request signature tests: the proxy
// options, the upstream and provider stub servers, the request headers to
// send, the response recorder, and the upstream's authenticator.
type SignatureTest struct {
opts *Options
upstream *httptest.Server
// upstreamHost is the host portion of the upstream server's URL.
upstreamHost string
provider *httptest.Server
header http.Header
rw *httptest.ResponseRecorder
authenticator *SignatureAuthenticator
}
// NewSignatureTest starts an upstream server that reports the signature
// verdict of each request and a provider server that always returns a fixed
// access token, then returns a fixture wired to both. Callers must invoke
// Close when done.
func NewSignatureTest() *SignatureTest {
	options := NewOptions()
	options.CookieSecret = "cookie secret"
	options.ClientID = "client ID"
	options.ClientSecret = "client secret"
	options.EmailDomains = []string{"acm.org"}

	sigAuth := &SignatureAuthenticator{}
	upstreamServer := httptest.NewServer(http.HandlerFunc(sigAuth.Authenticate))
	upstreamURL, _ := url.Parse(upstreamServer.URL)
	options.Upstreams = append(options.Upstreams, upstreamServer.URL)

	providerServer := httptest.NewServer(http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			w.Write([]byte(`{"access_token": "my_auth_token"}`))
		}))
	providerURL, _ := url.Parse(providerServer.URL)
	options.provider = NewTestProvider(providerURL, "mbland@acm.org")

	return &SignatureTest{
		opts:          options,
		upstream:      upstreamServer,
		upstreamHost:  upstreamURL.Host,
		provider:      providerServer,
		header:        make(http.Header),
		rw:            httptest.NewRecorder(),
		authenticator: sigAuth,
	}
}
// Close shuts down the provider server and then the upstream server.
func (st *SignatureTest) Close() {
st.provider.Close()
st.upstream.Close()
}
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4
type fakeNetConn struct {
	reqBody string
}

// Read copies as much of the remaining body as fits into p and drops the
// copied part, so the data can be read only once. It returns io.EOF together
// with the final bytes, and (0, io.EOF) once nothing is left.
//
// When p is shorter than the remaining body, only len(p) bytes are reported
// and the rest is kept for the next call, as the io.Reader contract requires.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	if len(fnc.reqBody) == 0 {
		return 0, io.EOF
	}
	n = copy(p, fnc.reqBody)
	fnc.reqBody = fnc.reqBody[n:]
	if len(fnc.reqBody) == 0 {
		return n, io.EOF
	}
	return n, nil
}
// MakeRequestWithExpectedKey validates the fixture's options, builds a proxy,
// and sends a request for "/foo/bar" with the given method and body through
// it, using a saved session so the request is authenticated. The upstream
// verifies the signature with a SHA1 HMAC keyed by key; the outcome is left in
// st.rw. It panics when validation or saving the session fails.
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
	if err := st.opts.Validate(); err != nil {
		panic(err)
	}
	proxy := NewOAuthProxy(st.opts, func(email string) bool { return true })

	// An empty body is sent as a nil reader.
	var payload io.ReadCloser
	if body != "" {
		payload = ioutil.NopCloser(&fakeNetConn{reqBody: body})
	}
	request := httptest.NewRequest(method, "/foo/bar", payload)
	request.Header = st.header

	session := &sessions.SessionState{
		Email:       "mbland@acm.org",
		AccessToken: "my_access_token",
	}
	if err := proxy.SaveSession(st.rw, request, session); err != nil {
		panic(err)
	}
	issued := st.rw.Result().Cookies()
	for i := range issued {
		request.AddCookie(issued[i])
	}

	// This is used by the upstream to validate the signature.
	st.authenticator.auth = hmacauth.NewHmacAuth(
		crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
	proxy.ServeHTTP(st.rw, request)
}
// TestNoRequestSignature sends a GET request with no SignatureKey configured
// and expects the upstream to report that it received no signature.
func TestNoRequestSignature(t *testing.T) {
	st := NewSignatureTest()
	defer st.Close()
	st.MakeRequestWithExpectedKey("GET", "", "")
	assert.Equal(t, 200, st.rw.Code)
	// Expected value first, so failure output labels the two sides correctly.
	assert.Equal(t, "no signature received", st.rw.Body.String())
}
// TestRequestSignatureGetRequest configures a sha1 SignatureKey, sends a GET
// request, and expects the upstream to report matching signatures.
func TestRequestSignatureGetRequest(t *testing.T) {
	st := NewSignatureTest()
	defer st.Close()
	st.opts.SignatureKey = "sha1:7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d"
	st.MakeRequestWithExpectedKey("GET", "", "7d9e1aa87a5954e6f9fc59266b3af9d7c35fda2d")
	assert.Equal(t, 200, st.rw.Code)
	// Expected value first, so failure output labels the two sides correctly.
	assert.Equal(t, "signatures match", st.rw.Body.String())
}
// TestRequestSignaturePostRequest configures a sha1 SignatureKey, sends a POST
// request with a JSON body, and expects the upstream to report matching
// signatures.
func TestRequestSignaturePostRequest(t *testing.T) {
	st := NewSignatureTest()
	defer st.Close()
	st.opts.SignatureKey = "sha1:d90df39e2d19282840252612dd7c81421a372f61"
	payload := `{ "hello": "world!" }`
	st.MakeRequestWithExpectedKey("POST", payload, "d90df39e2d19282840252612dd7c81421a372f61")
	assert.Equal(t, 200, st.rw.Code)
	// Expected value first, so failure output labels the two sides correctly.
	assert.Equal(t, "signatures match", st.rw.Body.String())
}
// TestGetRedirect checks proxy.GetRedirect for a path outside ProxyPrefix
// (returned as is) and a path under ProxyPrefix (replaced by "/").
func TestGetRedirect(t *testing.T) {
	options := NewOptions()
	_ = options.Validate()
	require.NotEmpty(t, options.ProxyPrefix)
	proxy := NewOAuthProxy(options, func(s string) bool { return false })

	type redirectCase struct {
		name             string
		url              string
		expectedRedirect string
	}
	cases := []redirectCase{
		{
			name:             "request outside of ProxyPrefix redirects to original URL",
			url:              "/foo/bar",
			expectedRedirect: "/foo/bar",
		},
		{
			name:             "request under ProxyPrefix redirects to root",
			url:              proxy.ProxyPrefix + "/foo/bar",
			expectedRedirect: "/",
		},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			request, _ := http.NewRequest("GET", tc.url, nil)
			got, err := proxy.GetRedirect(request)

			assert.NoError(t, err)
			assert.Equal(t, tc.expectedRedirect, got)
		})
	}
}
// ajaxRequestTest is the fixture for the AJAX response tests: a proxy and the
// options it was built from.
type ajaxRequestTest struct {
opts *Options
proxy *OAuthProxy
}
// newAjaxRequestTest builds a fixture whose proxy accepts every email address.
func newAjaxRequestTest() *ajaxRequestTest {
	options := NewOptions()
	options.CookieSecret = "sdflsw"
	options.ClientID = "gkljfdl"
	options.ClientSecret = "sdflkjs"
	options.Validate()

	acceptAll := func(email string) bool { return true }
	return &ajaxRequestTest{
		opts:  options,
		proxy: NewOAuthProxy(options, acceptAll),
	}
}
// getEndpoint sends a GET request for endpoint, carrying exactly the given
// headers, to the proxy. It returns the recorded status code and response
// headers, or (0, nil, err) when the request cannot be built.
func (test *ajaxRequestTest) getEndpoint(endpoint string, header http.Header) (int, http.Header, error) {
	request, err := http.NewRequest(http.MethodGet, endpoint, strings.NewReader(""))
	if err != nil {
		return 0, nil, err
	}
	request.Header = header

	recorder := httptest.NewRecorder()
	test.proxy.ServeHTTP(recorder, request)
	return recorder.Code, recorder.Header(), nil
}
// testAjaxUnauthorizedRequest requests "/test" with the given headers and
// expects a 401 response whose Content-Type is applicationJSON.
func testAjaxUnauthorizedRequest(t *testing.T, header http.Header) {
	fixture := newAjaxRequestTest()

	status, responseHeader, err := fixture.getEndpoint("/test", header)

	assert.NoError(t, err)
	assert.Equal(t, http.StatusUnauthorized, status)
	assert.Equal(t, applicationJSON, responseHeader.Get("Content-Type"))
}
// TestAjaxUnauthorizedRequest1 runs the unauthorized AJAX check with the
// header added under the lower-case name "accept".
func TestAjaxUnauthorizedRequest1(t *testing.T) {
	requestHeader := http.Header{}
	requestHeader.Add("accept", applicationJSON)
	testAjaxUnauthorizedRequest(t, requestHeader)
}
// TestAjaxUnauthorizedRequest2 runs the unauthorized AJAX check with the
// header added under the name "Accept".
func TestAjaxUnauthorizedRequest2(t *testing.T) {
	requestHeader := http.Header{}
	requestHeader.Add("Accept", applicationJSON)
	testAjaxUnauthorizedRequest(t, requestHeader)
}
// TestAjaxForbiddendRequest requests "/test" with no request headers and
// expects a 403 response whose Content-Type is not applicationJSON.
func TestAjaxForbiddendRequest(t *testing.T) {
	fixture := newAjaxRequestTest()

	status, responseHeader, err := fixture.getEndpoint("/test", http.Header{})

	assert.NoError(t, err)
	assert.Equal(t, http.StatusForbidden, status)
	assert.NotEqual(t, applicationJSON, responseHeader.Get("Content-Type"))
}
// TestClearSplitCookie sends a request carrying an unrelated cookie plus the
// split session cookies "oauth2_0" and "oauth2_1", and expects
// ClearSessionCookie to emit exactly two Set-Cookie headers.
func TestClearSplitCookie(t *testing.T) {
	opts := NewOptions()
	opts.CookieName = "oauth2"
	opts.CookieDomain = "abc"
	store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions)
	// Stop here on error: the proxy below cannot work without a store.
	require.NoError(t, err)
	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
	var rw = httptest.NewRecorder()
	req := httptest.NewRequest("get", "/", nil)

	req.AddCookie(&http.Cookie{
		Name:  "test1",
		Value: "test1",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2_0",
		Value: "oauth2_0",
	})
	req.AddCookie(&http.Cookie{
		Name:  "oauth2_1",
		Value: "oauth2_1",
	})

	p.ClearSessionCookie(rw, req)
	header := rw.Header()

	assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 2 set-cookie header entries")
}
// TestClearSingleCookie asserts that ClearSessionCookie emits exactly one
// Set-Cookie header when the request carries the single session cookie
// "oauth2" alongside an unrelated "test1" cookie.
func TestClearSingleCookie(t *testing.T) {
	opts := NewOptions()
	opts.CookieName = "oauth2"
	opts.CookieDomain = "abc"
	store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions)
	// Stop here on failure: the rest of the test is meaningless without a
	// usable session store.
	require.NoError(t, err)

	p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
	rw := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	// Each cookie uses its own name as its value.
	for _, name := range []string{"test1", "oauth2"} {
		req.AddCookie(&http.Cookie{Name: name, Value: name})
	}

	p.ClearSessionCookie(rw, req)

	header := rw.Header()
	assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
}
// NoOpKeySet is a key set for tests: it hands back a token's payload without
// verifying the signature.
type NoOpKeySet struct {
}

// VerifySignature returns the decoded payload of jwt, i.e. its second
// dot-separated segment decoded as unpadded base64url. The signature is not
// checked and ctx is unused.
//
// An error is returned when jwt has fewer than two segments (instead of
// panicking on the missing index) or when the payload segment is not valid
// base64url.
func (NoOpKeySet) VerifySignature(ctx context.Context, jwt string) (payload []byte, err error) {
	segments := strings.Split(jwt, ".")
	if len(segments) < 2 {
		return nil, fmt.Errorf("malformed jwt: expected at least 2 segments, got %d", len(segments))
	}
	return base64.RawURLEncoding.DecodeString(segments[1])
}
// TestGetJwtSession feeds a bearer JWT through a proxy configured to accept
// JWT bearer tokens and checks the session extracted from it as well as the
// headers set on the request and the response after serving it.
func TestGetJwtSession(t *testing.T) {
	/* token payload:
	{
	  "sub": "1234567890",
	  "aud": "https://test.myapp.com",
	  "name": "John Doe",
	  "email": "john@example.com",
	  "iss": "https://issuer.example.com",
	  "iat": 1553691215,
	  "exp": 1912151821
	}
	*/
	goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
		"E1LCJleHAiOjE5MTIxNTE4MjF9." +
		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"

	// The no-op key set skips signature verification, so only the claims
	// (issuer, audience) are checked; expiry checking is disabled as well.
	keyset := NoOpKeySet{}
	verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})

	test := NewAuthOnlyEndpointTest(func(opts *Options) {
		opts.PassAuthorization = true
		opts.SetAuthorization = true
		opts.SetXAuthRequest = true
		opts.SkipJwtBearerTokens = true
		opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier)
	})

	// Fail with a clear message instead of a nil dereference if the provider
	// is not the expected test double.
	tp, ok := test.proxy.provider.(*TestProvider)
	require.True(t, ok, "provider should be a *TestProvider")
	tp.GroupValidator = func(s string) bool {
		return true
	}

	authHeader := fmt.Sprintf("Bearer %s", goodJwt)
	test.req.Header = map[string][]string{
		"Authorization": {authHeader},
	}

	// Bearer
	session, err := test.proxy.GetJwtSession(test.req)
	require.NoError(t, err)
	require.NotNil(t, session)
	assert.Equal(t, "john@example.com", session.User)
	assert.Equal(t, "john@example.com", session.Email)
	assert.Equal(t, time.Unix(1912151821, 0), session.ExpiresOn)
	assert.Equal(t, goodJwt, session.IDToken)

	test.proxy.ServeHTTP(test.rw, test.req)
	if test.rw.Code >= 400 {
		t.Fatalf("expected status below 400, got %d", test.rw.Code)
	}

	// Check PassAuthorization, should overwrite Basic header
	assert.Equal(t, authHeader, test.req.Header.Get("Authorization"))
	assert.Equal(t, "john@example.com", test.req.Header.Get("X-Forwarded-User"))
	assert.Equal(t, "john@example.com", test.req.Header.Get("X-Forwarded-Email"))

	// SetAuthorization and SetXAuthRequest
	assert.Equal(t, authHeader, test.rw.Header().Get("Authorization"))
	assert.Equal(t, "john@example.com", test.rw.Header().Get("X-Auth-Request-User"))
	assert.Equal(t, "john@example.com", test.rw.Header().Get("X-Auth-Request-Email"))
}
// TestJwtUnauthorizedOnGroupValidationFailure checks that a request carrying
// an otherwise acceptable bearer JWT is answered with 401 when the provider's
// group validator rejects it.
func TestJwtUnauthorizedOnGroupValidationFailure(t *testing.T) {
	goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
		"eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" +
		"WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" +
		"E1LCJleHAiOjE5MTIxNTE4MjF9." +
		"rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" +
		"OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8"

	// The no-op key set skips signature verification; expiry checking is
	// disabled too, so only group validation can reject the token.
	keyset := NoOpKeySet{}
	verifier := oidc.NewVerifier("https://issuer.example.com", keyset,
		&oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true})

	test := NewAuthOnlyEndpointTest(func(opts *Options) {
		opts.PassAuthorization = true
		opts.SetAuthorization = true
		opts.SetXAuthRequest = true
		opts.SkipJwtBearerTokens = true
		opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier)
	})

	// Fail with a clear message instead of a nil dereference if the provider
	// is not the expected test double.
	tp, ok := test.proxy.provider.(*TestProvider)
	require.True(t, ok, "provider should be a *TestProvider")

	// Verify ValidateGroup fails JWT authorization
	tp.GroupValidator = func(s string) bool {
		return false
	}

	authHeader := fmt.Sprintf("Bearer %s", goodJwt)
	test.req.Header = map[string][]string{
		"Authorization": {authHeader},
	}

	test.proxy.ServeHTTP(test.rw, test.req)
	if test.rw.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401 got %d", test.rw.Code)
	}
}
// TestFindJwtBearerToken checks that findBearerToken extracts a JWT-shaped
// token from a Bearer Authorization header and from the supported Basic
// variants, and that it returns an error for a list of malformed tokens.
func TestFindJwtBearerToken(t *testing.T) {
	p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
	getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}}

	validToken := "eyJfoobar.eyJfoobar.12345asdf"

	// Every positive case is seeded from validToken rather than from the
	// previous case's result, so one failure does not cascade into the next.

	// Bearer
	getReq.Header = map[string][]string{
		"Authorization": {fmt.Sprintf("Bearer %s", validToken)},
	}
	token, err := p.findBearerToken(getReq)
	assert.NoError(t, err)
	assert.Equal(t, validToken, token)

	// Basic - no password
	getReq.SetBasicAuth(validToken, "")
	token, err = p.findBearerToken(getReq)
	assert.NoError(t, err)
	assert.Equal(t, validToken, token)

	// Basic - sentinel password
	getReq.SetBasicAuth(validToken, "x-oauth-basic")
	token, err = p.findBearerToken(getReq)
	assert.NoError(t, err)
	assert.Equal(t, validToken, token)

	// Basic - any username, password matching jwt pattern
	getReq.SetBasicAuth("any-username-you-could-wish-for", validToken)
	token, err = p.findBearerToken(getReq)
	assert.NoError(t, err)
	assert.Equal(t, validToken, token)

	failures := []string{
		// Too many parts
		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
		// Not enough parts
		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA",
		// Invalid encrypted key
		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA",
		// Invalid IV
		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA",
		// Invalid ciphertext
		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA",
		// Invalid tag
		"eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////",
		// Invalid header
		"W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
		// Invalid header
		"######.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
		// Missing alg/enc params
		"e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA",
	}

	for _, failure := range failures {
		getReq.Header = map[string][]string{
			"Authorization": {fmt.Sprintf("Bearer %s", failure)},
		}
		_, err := p.findBearerToken(getReq)
		assert.Error(t, err, "token %q should be rejected", failure)
	}
}