e241fe86d3
Since I'm no longer with 18F, I've re-released hmacauth under the ISC license as opposed to the previous CC0 license. There have been no changes to the hmacauth code itself, and all tests still pass.
839 lines
24 KiB
Go
839 lines
24 KiB
Go
package main
|
|
|
|
import (
|
|
"crypto"
|
|
"encoding/base64"
|
|
"io"
|
|
"io/ioutil"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/bitly/oauth2_proxy/providers"
|
|
"github.com/mbland/hmacauth"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// init configures the standard logger used by these tests so that every
// message is prefixed with the date, the time, and the short file:line of the
// call site.
func init() {
	const testLogFlags = log.Ldate | log.Ltime | log.Lshortfile
	log.SetFlags(testLogFlags)
}
|
|
|
|
func TestNewReverseProxy(t *testing.T) {
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
hostname, _, _ := net.SplitHostPort(r.Host)
|
|
w.Write([]byte(hostname))
|
|
}))
|
|
defer backend.Close()
|
|
|
|
backendURL, _ := url.Parse(backend.URL)
|
|
backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
|
|
backendHost := net.JoinHostPort(backendHostname, backendPort)
|
|
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
|
|
|
|
proxyHandler := NewReverseProxy(proxyURL)
|
|
setProxyUpstreamHostHeader(proxyHandler, proxyURL)
|
|
frontend := httptest.NewServer(proxyHandler)
|
|
defer frontend.Close()
|
|
|
|
getReq, _ := http.NewRequest("GET", frontend.URL, nil)
|
|
res, _ := http.DefaultClient.Do(getReq)
|
|
bodyBytes, _ := ioutil.ReadAll(res.Body)
|
|
if g, e := string(bodyBytes), backendHostname; g != e {
|
|
t.Errorf("got body %q; expected %q", g, e)
|
|
}
|
|
}
|
|
|
|
func TestEncodedSlashes(t *testing.T) {
|
|
var seen string
|
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
seen = r.RequestURI
|
|
}))
|
|
defer backend.Close()
|
|
|
|
b, _ := url.Parse(backend.URL)
|
|
proxyHandler := NewReverseProxy(b)
|
|
setProxyDirector(proxyHandler)
|
|
frontend := httptest.NewServer(proxyHandler)
|
|
defer frontend.Close()
|
|
|
|
f, _ := url.Parse(frontend.URL)
|
|
encodedPath := "/a%2Fb/?c=1"
|
|
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
|
|
_, err := http.DefaultClient.Do(getReq)
|
|
if err != nil {
|
|
t.Fatalf("err %s", err)
|
|
}
|
|
if seen != encodedPath {
|
|
t.Errorf("got bad request %q expected %q", seen, encodedPath)
|
|
}
|
|
}
|
|
|
|
func TestRobotsTxt(t *testing.T) {
|
|
opts := NewOptions()
|
|
opts.ClientID = "bazquux"
|
|
opts.ClientSecret = "foobar"
|
|
opts.CookieSecret = "xyzzyplugh"
|
|
opts.Validate()
|
|
|
|
proxy := NewOAuthProxy(opts, func(string) bool { return true })
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/robots.txt", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
|
|
}
|
|
|
|
// TestProvider is a stub OAuth provider for these tests. It embeds
// *providers.ProviderData, so it has that type's fields and promoted methods,
// and it defines its own GetEmailAddress and ValidateSessionState that return
// the canned values below.
type TestProvider struct {
	*providers.ProviderData
	// EmailAddress is returned verbatim by GetEmailAddress.
	EmailAddress string
	// ValidToken is returned verbatim by ValidateSessionState.
	ValidToken bool
}
|
|
|
|
func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider {
|
|
return &TestProvider{
|
|
ProviderData: &providers.ProviderData{
|
|
ProviderName: "Test Provider",
|
|
LoginURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: provider_url.Host,
|
|
Path: "/oauth/authorize",
|
|
},
|
|
RedeemURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: provider_url.Host,
|
|
Path: "/oauth/token",
|
|
},
|
|
ProfileURL: &url.URL{
|
|
Scheme: "http",
|
|
Host: provider_url.Host,
|
|
Path: "/api/v1/profile",
|
|
},
|
|
Scope: "profile.email",
|
|
},
|
|
EmailAddress: email_address,
|
|
}
|
|
}
|
|
|
|
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
|
|
return tp.EmailAddress, nil
|
|
}
|
|
|
|
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
|
|
return tp.ValidToken
|
|
}
|
|
|
|
func TestBasicAuthPassword(t *testing.T) {
|
|
provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
log.Printf("%#v", r)
|
|
url := r.URL
|
|
payload := ""
|
|
switch url.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("Authorization")
|
|
if payload == "" {
|
|
payload = "No Authorization header found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
w.Write([]byte(payload))
|
|
}))
|
|
opts := NewOptions()
|
|
opts.Upstreams = append(opts.Upstreams, provider_server.URL)
|
|
// The CookieSecret must be 32 bytes in order to create the AES
|
|
// cipher.
|
|
opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
|
|
opts.ClientID = "bazquux"
|
|
opts.ClientSecret = "foobar"
|
|
opts.CookieSecure = false
|
|
opts.PassBasicAuth = true
|
|
opts.PassUserHeaders = true
|
|
opts.BasicAuthPassword = "This is a secure password"
|
|
opts.Validate()
|
|
|
|
provider_url, _ := url.Parse(provider_server.URL)
|
|
const email_address = "michael.bland@gsa.gov"
|
|
const user_name = "michael.bland"
|
|
|
|
opts.provider = NewTestProvider(provider_url, email_address)
|
|
proxy := NewOAuthProxy(opts, func(email string) bool {
|
|
return email == email_address
|
|
})
|
|
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
|
strings.NewReader(""))
|
|
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
|
proxy.ServeHTTP(rw, req)
|
|
if rw.Code >= 400 {
|
|
t.Fatalf("expected 3xx got %d", rw.Code)
|
|
}
|
|
cookie := rw.HeaderMap["Set-Cookie"][1]
|
|
|
|
cookieName := proxy.CookieName
|
|
var value string
|
|
key_prefix := cookieName + "="
|
|
|
|
for _, field := range strings.Split(cookie, "; ") {
|
|
value = strings.TrimPrefix(field, key_prefix)
|
|
if value != field {
|
|
break
|
|
} else {
|
|
value = ""
|
|
}
|
|
}
|
|
|
|
req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
|
req.AddCookie(&http.Cookie{
|
|
Name: cookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(time.Duration(24)),
|
|
HttpOnly: true,
|
|
})
|
|
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
|
|
|
rw = httptest.NewRecorder()
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
|
|
assert.Equal(t, expectedHeader, rw.Body.String())
|
|
provider_server.Close()
|
|
}
|
|
|
|
// PassAccessTokenTest bundles the fixtures for the PassAccessToken tests.
type PassAccessTokenTest struct {
	// provider_server is a single httptest server that acts as both the
	// OAuth provider and the proxy's upstream.
	provider_server *httptest.Server
	// proxy is the OAuthProxy under test.
	proxy *OAuthProxy
	// opts are the options proxy was built from.
	opts *Options
}
|
|
|
|
// PassAccessTokenTestOptions configures NewPassAccessTokenTest.
type PassAccessTokenTestOptions struct {
	// PassAccessToken is copied into the proxy's Options.PassAccessToken.
	PassAccessToken bool
}
|
|
|
|
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
|
|
t := &PassAccessTokenTest{}
|
|
|
|
t.provider_server = httptest.NewServer(
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
log.Printf("%#v", r)
|
|
url := r.URL
|
|
payload := ""
|
|
switch url.Path {
|
|
case "/oauth/token":
|
|
payload = `{"access_token": "my_auth_token"}`
|
|
default:
|
|
payload = r.Header.Get("X-Forwarded-Access-Token")
|
|
if payload == "" {
|
|
payload = "No access token found."
|
|
}
|
|
}
|
|
w.WriteHeader(200)
|
|
w.Write([]byte(payload))
|
|
}))
|
|
|
|
t.opts = NewOptions()
|
|
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL)
|
|
// The CookieSecret must be 32 bytes in order to create the AES
|
|
// cipher.
|
|
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
|
|
t.opts.ClientID = "bazquux"
|
|
t.opts.ClientSecret = "foobar"
|
|
t.opts.CookieSecure = false
|
|
t.opts.PassAccessToken = opts.PassAccessToken
|
|
t.opts.Validate()
|
|
|
|
provider_url, _ := url.Parse(t.provider_server.URL)
|
|
const email_address = "michael.bland@gsa.gov"
|
|
|
|
t.opts.provider = NewTestProvider(provider_url, email_address)
|
|
t.proxy = NewOAuthProxy(t.opts, func(email string) bool {
|
|
return email == email_address
|
|
})
|
|
return t
|
|
}
|
|
|
|
func (pat_test *PassAccessTokenTest) Close() {
|
|
pat_test.provider_server.Close()
|
|
}
|
|
|
|
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
|
cookie string) {
|
|
rw := httptest.NewRecorder()
|
|
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
|
strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
|
|
pat_test.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.HeaderMap["Set-Cookie"][1]
|
|
}
|
|
|
|
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
|
|
cookieName := pat_test.proxy.CookieName
|
|
var value string
|
|
key_prefix := cookieName + "="
|
|
|
|
for _, field := range strings.Split(cookie, "; ") {
|
|
value = strings.TrimPrefix(field, key_prefix)
|
|
if value != field {
|
|
break
|
|
} else {
|
|
value = ""
|
|
}
|
|
}
|
|
if value == "" {
|
|
return 0, ""
|
|
}
|
|
|
|
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
|
|
if err != nil {
|
|
return 0, ""
|
|
}
|
|
req.AddCookie(&http.Cookie{
|
|
Name: cookieName,
|
|
Value: value,
|
|
Path: "/",
|
|
Expires: time.Now().Add(time.Duration(24)),
|
|
HttpOnly: true,
|
|
})
|
|
|
|
rw := httptest.NewRecorder()
|
|
pat_test.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestForwardAccessTokenUpstream(t *testing.T) {
|
|
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: true,
|
|
})
|
|
defer pat_test.Close()
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := pat_test.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request; the access_token from the cookie is
|
|
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
|
// read by the test provider server and written in the response body.
|
|
code, payload := pat_test.getRootEndpoint(cookie)
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "my_auth_token", payload)
|
|
}
|
|
|
|
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
|
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
|
PassAccessToken: false,
|
|
})
|
|
defer pat_test.Close()
|
|
|
|
// A successful validation will redirect and set the auth cookie.
|
|
code, cookie := pat_test.getCallbackEndpoint()
|
|
if code != 302 {
|
|
t.Fatalf("expected 302; got %d", code)
|
|
}
|
|
assert.NotEqual(t, nil, cookie)
|
|
|
|
// Now we make a regular request, but the access token header should
|
|
// not be present.
|
|
code, payload := pat_test.getRootEndpoint(cookie)
|
|
if code != 200 {
|
|
t.Fatalf("expected 200; got %d", code)
|
|
}
|
|
assert.Equal(t, "No access token found.", payload)
|
|
}
|
|
|
|
// SignInPageTest holds a proxy configured for exercising the sign-in page,
// plus precompiled regexps for inspecting its responses.
type SignInPageTest struct {
	opts  *Options
	proxy *OAuthProxy
	// sign_in_regexp is compiled from signInRedirectPattern; submatch 1 is
	// the value of the hidden "rd" form field.
	sign_in_regexp *regexp.Regexp
	// sign_in_provider_regexp is compiled from signInSkipProvider.
	sign_in_provider_regexp *regexp.Regexp
}
|
|
|
|
// Patterns searched for in sign-in responses.
const (
	// signInRedirectPattern captures the value of the sign-in form's hidden
	// "rd" field, i.e. where the user is sent after signing in.
	signInRedirectPattern = `<input type="hidden" name="rd" value="(.*)">`
	// signInSkipProvider matches the ">Found<" text the tests expect in the
	// body of a 302 response when the provider button is skipped.
	signInSkipProvider = `>Found<`
)
|
|
|
|
func NewSignInPageTest(skipProvider bool) *SignInPageTest {
|
|
var sip_test SignInPageTest
|
|
|
|
sip_test.opts = NewOptions()
|
|
sip_test.opts.CookieSecret = "foobar"
|
|
sip_test.opts.ClientID = "bazquux"
|
|
sip_test.opts.ClientSecret = "xyzzyplugh"
|
|
sip_test.opts.SkipProviderButton = skipProvider
|
|
sip_test.opts.Validate()
|
|
|
|
sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool {
|
|
return true
|
|
})
|
|
sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern)
|
|
sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider)
|
|
|
|
return &sip_test
|
|
}
|
|
|
|
func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) {
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("GET", endpoint, strings.NewReader(""))
|
|
sip_test.proxy.ServeHTTP(rw, req)
|
|
return rw.Code, rw.Body.String()
|
|
}
|
|
|
|
func TestSignInPageIncludesTargetRedirect(t *testing.T) {
|
|
sip_test := NewSignInPageTest(false)
|
|
const endpoint = "/some/random/endpoint"
|
|
|
|
code, body := sip_test.GetEndpoint(endpoint)
|
|
assert.Equal(t, 403, code)
|
|
|
|
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != endpoint {
|
|
t.Fatal(`expected redirect to "` + endpoint +
|
|
`", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
|
sip_test := NewSignInPageTest(false)
|
|
code, body := sip_test.GetEndpoint("/oauth2/sign_in")
|
|
assert.Equal(t, 200, code)
|
|
|
|
match := sip_test.sign_in_regexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInRedirectPattern + "\nBody:\n" + body)
|
|
}
|
|
if match[1] != "/" {
|
|
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProvider(t *testing.T) {
|
|
sip_test := NewSignInPageTest(true)
|
|
const endpoint = "/some/random/endpoint"
|
|
|
|
code, body := sip_test.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
func TestSignInPageSkipProviderDirect(t *testing.T) {
|
|
sip_test := NewSignInPageTest(true)
|
|
const endpoint = "/sign_in"
|
|
|
|
code, body := sip_test.GetEndpoint(endpoint)
|
|
assert.Equal(t, 302, code)
|
|
|
|
match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body)
|
|
if match == nil {
|
|
t.Fatal("Did not find pattern in body: " +
|
|
signInSkipProvider + "\nBody:\n" + body)
|
|
}
|
|
}
|
|
|
|
// ProcessCookieTest is the fixture for the session-cookie and /auth endpoint
// tests: a proxy, a request that session cookies are attached to, and a
// recorder for the proxy's response.
type ProcessCookieTest struct {
	opts  *Options
	proxy *OAuthProxy
	rw    *httptest.ResponseRecorder
	// req is the request SaveSession attaches session cookies to.
	req *http.Request
	// NOTE(review): provider and response_code are not read or written by
	// any code visible in this file; the active provider is proxy.provider.
	provider      TestProvider
	response_code int
	// validate_user is what the proxy's email validator returns, so tests can
	// make email validation fail after setup.
	validate_user bool
}
|
|
|
|
// ProcessCookieTestOpts configures NewProcessCookieTest.
type ProcessCookieTestOpts struct {
	// provider_validate_cookie_response becomes the TestProvider's
	// ValidToken, i.e. the value ValidateSessionState returns.
	provider_validate_cookie_response bool
}
|
|
|
|
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
|
|
var pc_test ProcessCookieTest
|
|
|
|
pc_test.opts = NewOptions()
|
|
pc_test.opts.ClientID = "bazquux"
|
|
pc_test.opts.ClientSecret = "xyzzyplugh"
|
|
pc_test.opts.CookieSecret = "0123456789abcdefabcd"
|
|
// First, set the CookieRefresh option so proxy.AesCipher is created,
|
|
// needed to encrypt the access_token.
|
|
pc_test.opts.CookieRefresh = time.Hour
|
|
pc_test.opts.Validate()
|
|
|
|
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
|
|
return pc_test.validate_user
|
|
})
|
|
pc_test.proxy.provider = &TestProvider{
|
|
ValidToken: opts.provider_validate_cookie_response,
|
|
}
|
|
|
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
|
// access_token validation.
|
|
pc_test.proxy.CookieRefresh = time.Duration(0)
|
|
pc_test.rw = httptest.NewRecorder()
|
|
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
|
|
pc_test.validate_user = true
|
|
return &pc_test
|
|
}
|
|
|
|
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
|
provider_validate_cookie_response: true,
|
|
})
|
|
}
|
|
|
|
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
|
|
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
|
}
|
|
|
|
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
|
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
|
|
return nil
|
|
}
|
|
|
|
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
|
|
return p.proxy.LoadCookiedSession(p.req)
|
|
}
|
|
|
|
func TestLoadCookiedSession(t *testing.T) {
|
|
pc_test := NewProcessCookieTestWithDefaults()
|
|
|
|
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
pc_test.SaveSession(startSession, time.Now())
|
|
|
|
session, _, err := pc_test.LoadCookiedSession()
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
assert.Equal(t, "michael.bland", session.User)
|
|
assert.Equal(t, startSession.AccessToken, session.AccessToken)
|
|
}
|
|
|
|
func TestProcessCookieNoCookieError(t *testing.T) {
|
|
pc_test := NewProcessCookieTestWithDefaults()
|
|
|
|
session, _, err := pc_test.LoadCookiedSession()
|
|
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
|
|
if session != nil {
|
|
t.Errorf("expected nil session. got %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|
pc_test := NewProcessCookieTestWithDefaults()
|
|
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
|
|
|
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
pc_test.SaveSession(startSession, reference)
|
|
|
|
session, age, err := pc_test.LoadCookiedSession()
|
|
assert.Equal(t, nil, err)
|
|
if age < time.Duration(-2)*time.Hour {
|
|
t.Errorf("cookie too young %v", age)
|
|
}
|
|
assert.Equal(t, startSession.Email, session.Email)
|
|
}
|
|
|
|
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|
pc_test := NewProcessCookieTestWithDefaults()
|
|
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
pc_test.SaveSession(startSession, reference)
|
|
|
|
session, _, err := pc_test.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
|
pc_test := NewProcessCookieTestWithDefaults()
|
|
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
pc_test.SaveSession(startSession, reference)
|
|
|
|
pc_test.proxy.CookieRefresh = time.Hour
|
|
session, _, err := pc_test.LoadCookiedSession()
|
|
assert.NotEqual(t, nil, err)
|
|
if session != nil {
|
|
t.Errorf("expected nil session %#v", session)
|
|
}
|
|
}
|
|
|
|
func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
|
pc_test := NewProcessCookieTestWithDefaults()
|
|
pc_test.req, _ = http.NewRequest("GET",
|
|
pc_test.opts.ProxyPrefix+"/auth", nil)
|
|
return pc_test
|
|
}
|
|
|
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
|
test := NewAuthOnlyEndpointTest()
|
|
startSession := &providers.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
test.SaveSession(startSession, time.Now())
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|
test := NewAuthOnlyEndpointTest()
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|
test := NewAuthOnlyEndpointTest()
|
|
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
|
startSession := &providers.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
test.SaveSession(startSession, reference)
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
|
test := NewAuthOnlyEndpointTest()
|
|
startSession := &providers.SessionState{
|
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
|
test.SaveSession(startSession, time.Now())
|
|
test.validate_user = false
|
|
|
|
test.proxy.ServeHTTP(test.rw, test.req)
|
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
|
bodyBytes, _ := ioutil.ReadAll(test.rw.Body)
|
|
assert.Equal(t, "unauthorized request\n", string(bodyBytes))
|
|
}
|
|
|
|
func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|
var pc_test ProcessCookieTest
|
|
|
|
pc_test.opts = NewOptions()
|
|
pc_test.opts.SetXAuthRequest = true
|
|
pc_test.opts.Validate()
|
|
|
|
pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool {
|
|
return pc_test.validate_user
|
|
})
|
|
pc_test.proxy.provider = &TestProvider{
|
|
ValidToken: true,
|
|
}
|
|
|
|
pc_test.validate_user = true
|
|
|
|
pc_test.rw = httptest.NewRecorder()
|
|
pc_test.req, _ = http.NewRequest("GET",
|
|
pc_test.opts.ProxyPrefix+"/auth", nil)
|
|
|
|
startSession := &providers.SessionState{
|
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
|
pc_test.SaveSession(startSession, time.Now())
|
|
|
|
pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req)
|
|
assert.Equal(t, http.StatusAccepted, pc_test.rw.Code)
|
|
assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0])
|
|
assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0])
|
|
}
|
|
|
|
func TestAuthSkippedForPreflightRequests(t *testing.T) {
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(200)
|
|
w.Write([]byte("response"))
|
|
}))
|
|
defer upstream.Close()
|
|
|
|
opts := NewOptions()
|
|
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
|
opts.ClientID = "bazquux"
|
|
opts.ClientSecret = "foobar"
|
|
opts.CookieSecret = "xyzzyplugh"
|
|
opts.SkipAuthPreflight = true
|
|
opts.Validate()
|
|
|
|
upstream_url, _ := url.Parse(upstream.URL)
|
|
opts.provider = NewTestProvider(upstream_url, "")
|
|
|
|
proxy := NewOAuthProxy(opts, func(string) bool { return false })
|
|
rw := httptest.NewRecorder()
|
|
req, _ := http.NewRequest("OPTIONS", "/preflight-request", nil)
|
|
proxy.ServeHTTP(rw, req)
|
|
|
|
assert.Equal(t, 200, rw.Code)
|
|
assert.Equal(t, "response", rw.Body.String())
|
|
}
|
|
|
|
// SignatureAuthenticator backs the fake upstream in the request-signing tests.
// MakeRequestWithExpectedKey assigns its auth field before each request.
type SignatureAuthenticator struct {
	// auth verifies incoming request signatures in Authenticate.
	auth hmacauth.HmacAuth
}
|
|
|
|
func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Request) {
|
|
result, headerSig, computedSig := v.auth.AuthenticateRequest(r)
|
|
if result == hmacauth.ResultNoSignature {
|
|
w.Write([]byte("no signature received"))
|
|
} else if result == hmacauth.ResultMatch {
|
|
w.Write([]byte("signatures match"))
|
|
} else if result == hmacauth.ResultMismatch {
|
|
w.Write([]byte("signatures do not match:" +
|
|
"\n received: " + headerSig +
|
|
"\n computed: " + computedSig))
|
|
} else {
|
|
panic("Unknown result value: " + result.String())
|
|
}
|
|
}
|
|
|
|
// SignatureTest is the fixture for the request-signing tests.
type SignatureTest struct {
	// opts are the proxy options; each test may adjust them (e.g.
	// SignatureKey) before making a request.
	opts *Options
	// upstream is the server the proxy forwards to; its handler is
	// authenticator.Authenticate.
	upstream *httptest.Server
	// upstream_host is the host (and port) of upstream's URL.
	upstream_host string
	// provider is the fake OAuth provider server.
	provider *httptest.Server
	// header is installed as the Header of the request built by
	// MakeRequestWithExpectedKey.
	header http.Header
	// rw records the proxy's response.
	rw            *httptest.ResponseRecorder
	authenticator *SignatureAuthenticator
}
|
|
|
|
func NewSignatureTest() *SignatureTest {
|
|
opts := NewOptions()
|
|
opts.CookieSecret = "cookie secret"
|
|
opts.ClientID = "client ID"
|
|
opts.ClientSecret = "client secret"
|
|
opts.EmailDomains = []string{"acm.org"}
|
|
|
|
authenticator := &SignatureAuthenticator{}
|
|
upstream := httptest.NewServer(
|
|
http.HandlerFunc(authenticator.Authenticate))
|
|
upstream_url, _ := url.Parse(upstream.URL)
|
|
opts.Upstreams = append(opts.Upstreams, upstream.URL)
|
|
|
|
providerHandler := func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write([]byte(`{"access_token": "my_auth_token"}`))
|
|
}
|
|
provider := httptest.NewServer(http.HandlerFunc(providerHandler))
|
|
provider_url, _ := url.Parse(provider.URL)
|
|
opts.provider = NewTestProvider(provider_url, "mbland@acm.org")
|
|
|
|
return &SignatureTest{
|
|
opts,
|
|
upstream,
|
|
upstream_url.Host,
|
|
provider,
|
|
make(http.Header),
|
|
httptest.NewRecorder(),
|
|
authenticator,
|
|
}
|
|
}
|
|
|
|
func (st *SignatureTest) Close() {
|
|
st.provider.Close()
|
|
st.upstream.Close()
|
|
}
|
|
|
|
// fakeNetConn simulates an http.Request.Body buffer that will be consumed
// when it is read by the hmacauth.HmacAuth if not handled properly. See:
// https://github.com/18F/hmacauth/pull/4
//
// Like a network connection, its contents can be read only once. After the
// body has been fully returned, every further Read reports (0, io.EOF).
type fakeNetConn struct {
	// reqBody holds the bytes not yet returned by Read.
	reqBody string
}

// Read copies as much of the remaining body as fits into p and removes those
// bytes from reqBody; n never exceeds len(p), as io.Reader requires. When this
// call returns the last of the body, it returns io.EOF together with those
// bytes, which io.Reader permits. When p is too small, it returns a partial
// read with a nil error so the rest can be read by later calls.
func (fnc *fakeNetConn) Read(p []byte) (n int, err error) {
	if len(fnc.reqBody) == 0 {
		return 0, io.EOF
	}
	n = copy(p, fnc.reqBody)
	fnc.reqBody = fnc.reqBody[n:]
	if len(fnc.reqBody) == 0 {
		return n, io.EOF
	}
	return n, nil
}
|
|
|
|
func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
|
err := st.opts.Validate()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
proxy := NewOAuthProxy(st.opts, func(email string) bool { return true })
|
|
|
|
var bodyBuf io.ReadCloser
|
|
if body != "" {
|
|
bodyBuf = ioutil.NopCloser(&fakeNetConn{reqBody: body})
|
|
}
|
|
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
|
req.Header = st.header
|
|
|
|
state := &providers.SessionState{
|
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
|
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
|
|
req.AddCookie(cookie)
|
|
// This is used by the upstream to validate the signature.
|
|
st.authenticator.auth = hmacauth.NewHmacAuth(
|
|
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
|
|
proxy.ServeHTTP(st.rw, req)
|
|
}
|
|
|
|
func TestNoRequestSignature(t *testing.T) {
|
|
st := NewSignatureTest()
|
|
defer st.Close()
|
|
st.MakeRequestWithExpectedKey("GET", "", "")
|
|
assert.Equal(t, 200, st.rw.Code)
|
|
assert.Equal(t, st.rw.Body.String(), "no signature received")
|
|
}
|
|
|
|
func TestRequestSignatureGetRequest(t *testing.T) {
|
|
st := NewSignatureTest()
|
|
defer st.Close()
|
|
st.opts.SignatureKey = "sha1:foobar"
|
|
st.MakeRequestWithExpectedKey("GET", "", "foobar")
|
|
assert.Equal(t, 200, st.rw.Code)
|
|
assert.Equal(t, st.rw.Body.String(), "signatures match")
|
|
}
|
|
|
|
func TestRequestSignaturePostRequest(t *testing.T) {
|
|
st := NewSignatureTest()
|
|
defer st.Close()
|
|
st.opts.SignatureKey = "sha1:foobar"
|
|
payload := `{ "hello": "world!" }`
|
|
st.MakeRequestWithExpectedKey("POST", payload, "foobar")
|
|
assert.Equal(t, 200, st.rw.Code)
|
|
assert.Equal(t, st.rw.Body.String(), "signatures match")
|
|
}
|