Merge pull request #81 from 18F/access-token-refactor
Refactor pass_access_token changes from #80
This commit is contained in:
commit
9534808a0d
3
.gitignore
vendored
3
.gitignore
vendored
@ -24,3 +24,6 @@ _testmain.go
|
|||||||
*.exe
|
*.exe
|
||||||
dist
|
dist
|
||||||
.godeps
|
.godeps
|
||||||
|
|
||||||
|
# Editor swap/temp files
|
||||||
|
.*.swp
|
||||||
|
31
cookies.go
31
cookies.go
@ -97,3 +97,34 @@ func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (st
|
|||||||
|
|
||||||
return string(encrypted_access_token), nil
|
return string(encrypted_access_token), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func buildCookieValue(email string, aes_cipher cipher.Block,
|
||||||
|
access_token string) (string, error) {
|
||||||
|
if aes_cipher == nil {
|
||||||
|
return email, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded_token, err := encodeAccessToken(aes_cipher, access_token)
|
||||||
|
if err != nil {
|
||||||
|
return email, fmt.Errorf(
|
||||||
|
"error encoding access token for %s: %s", email, err)
|
||||||
|
}
|
||||||
|
return email + "|" + encoded_token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCookieValue(value string, aes_cipher cipher.Block) (email, user,
|
||||||
|
access_token string, err error) {
|
||||||
|
components := strings.Split(value, "|")
|
||||||
|
email = components[0]
|
||||||
|
user = strings.Split(email, "@")[0]
|
||||||
|
|
||||||
|
if aes_cipher != nil && len(components) == 2 {
|
||||||
|
access_token, err = decodeAccessToken(aes_cipher, components[1])
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf(
|
||||||
|
"error decoding access token for %s: %s",
|
||||||
|
email, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return email, user, access_token, err
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package main
|
|||||||
import (
|
import (
|
||||||
"crypto/aes"
|
"crypto/aes"
|
||||||
"github.com/bmizerany/assert"
|
"github.com/bmizerany/assert"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -21,3 +22,54 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) {
|
|||||||
assert.NotEqual(t, access_token, encoded_token)
|
assert.NotEqual(t, access_token, encoded_token)
|
||||||
assert.Equal(t, access_token, decoded_token)
|
assert.Equal(t, access_token, decoded_token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildCookieValueWithoutAccessToken(t *testing.T) {
|
||||||
|
value, err := buildCookieValue("michael.bland@gsa.gov", nil, "")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "michael.bland@gsa.gov", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
|
||||||
|
value, err := buildCookieValue("michael.bland@gsa.gov", nil,
|
||||||
|
"access token")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "michael.bland@gsa.gov", value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCookieValueWithoutAccessToken(t *testing.T) {
|
||||||
|
email, user, access_token, err := parseCookieValue(
|
||||||
|
"michael.bland@gsa.gov", nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
|
assert.Equal(t, "michael.bland", user)
|
||||||
|
assert.Equal(t, "", access_token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
|
||||||
|
email, user, access_token, err := parseCookieValue(
|
||||||
|
"michael.bland@gsa.gov|access_token", nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
|
assert.Equal(t, "michael.bland", user)
|
||||||
|
assert.Equal(t, "", access_token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) {
|
||||||
|
aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef"))
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher,
|
||||||
|
"access_token")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
|
prefix := "michael.bland@gsa.gov|"
|
||||||
|
if !strings.HasPrefix(value, prefix) {
|
||||||
|
t.Fatal("cookie value does not start with \"%s\": %s",
|
||||||
|
prefix, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
email, user, access_token, err := parseCookieValue(value, aes_cipher)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
|
assert.Equal(t, "michael.bland", user)
|
||||||
|
assert.Equal(t, "access_token", access_token)
|
||||||
|
}
|
||||||
|
@ -47,7 +47,6 @@ type OauthProxy struct {
|
|||||||
DisplayHtpasswdForm bool
|
DisplayHtpasswdForm bool
|
||||||
serveMux http.Handler
|
serveMux http.Handler
|
||||||
PassBasicAuth bool
|
PassBasicAuth bool
|
||||||
PassAccessToken bool
|
|
||||||
AesCipher cipher.Block
|
AesCipher cipher.Block
|
||||||
skipAuthRegex []string
|
skipAuthRegex []string
|
||||||
compiledRegex []*regexp.Regexp
|
compiledRegex []*regexp.Regexp
|
||||||
@ -121,20 +120,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
|
|||||||
log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain)
|
log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain)
|
||||||
|
|
||||||
var aes_cipher cipher.Block
|
var aes_cipher cipher.Block
|
||||||
|
if opts.PassAccessToken {
|
||||||
if opts.PassAccessToken == true {
|
|
||||||
valid_cookie_secret_size := false
|
|
||||||
for _, i := range []int{16, 24, 32} {
|
|
||||||
if len(opts.CookieSecret) == i {
|
|
||||||
valid_cookie_secret_size = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if valid_cookie_secret_size == false {
|
|
||||||
log.Fatal("cookie_secret must be 16, 24, or 32 bytes " +
|
|
||||||
"to create an AES cipher when " +
|
|
||||||
"pass_access_token == true")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
|
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -163,7 +149,6 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
|
|||||||
skipAuthRegex: opts.SkipAuthRegex,
|
skipAuthRegex: opts.SkipAuthRegex,
|
||||||
compiledRegex: opts.CompiledRegex,
|
compiledRegex: opts.CompiledRegex,
|
||||||
PassBasicAuth: opts.PassBasicAuth,
|
PassBasicAuth: opts.PassBasicAuth,
|
||||||
PassAccessToken: opts.PassAccessToken,
|
|
||||||
AesCipher: aes_cipher,
|
AesCipher: aes_cipher,
|
||||||
templates: loadTemplates(opts.CustomTemplatesDir),
|
templates: loadTemplates(opts.CustomTemplatesDir),
|
||||||
}
|
}
|
||||||
@ -440,20 +425,12 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
// set cookie, or deny
|
// set cookie, or deny
|
||||||
if p.Validator(email) {
|
if p.Validator(email) {
|
||||||
log.Printf("%s authenticating %s completed", remoteAddr, email)
|
log.Printf("%s authenticating %s completed", remoteAddr, email)
|
||||||
encoded_token := ""
|
value, err := buildCookieValue(
|
||||||
if p.PassAccessToken {
|
email, p.AesCipher, access_token)
|
||||||
encoded_token, err = encodeAccessToken(p.AesCipher, access_token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error encoding access token: %s", err)
|
log.Printf(err.Error())
|
||||||
}
|
|
||||||
}
|
|
||||||
access_token = ""
|
|
||||||
|
|
||||||
if encoded_token != "" {
|
|
||||||
p.SetCookie(rw, req, email+"|"+encoded_token)
|
|
||||||
} else {
|
|
||||||
p.SetCookie(rw, req, email)
|
|
||||||
}
|
}
|
||||||
|
p.SetCookie(rw, req, value)
|
||||||
http.Redirect(rw, req, redirect, 302)
|
http.Redirect(rw, req, redirect, 302)
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
@ -467,15 +444,13 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
var value string
|
var value string
|
||||||
value, ok = validateCookie(cookie, p.CookieSeed)
|
value, ok = validateCookie(cookie, p.CookieSeed)
|
||||||
components := strings.Split(value, "|")
|
if ok {
|
||||||
email = components[0]
|
email, user, access_token, err = parseCookieValue(
|
||||||
if len(components) == 2 {
|
value, p.AesCipher)
|
||||||
access_token, err = decodeAccessToken(p.AesCipher, components[1])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error decoding access token: %s", err)
|
log.Printf(err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
user = strings.Split(email, "@")[0]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,24 +152,25 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
|
|||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
func Close(t *PassAccessTokenTest) {
|
func (pat_test *PassAccessTokenTest) Close() {
|
||||||
t.provider_server.Close()
|
pat_test.provider_server.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) {
|
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
||||||
|
cookie string) {
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
|
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
|
||||||
strings.NewReader(""))
|
strings.NewReader(""))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ""
|
return 0, ""
|
||||||
}
|
}
|
||||||
pac_test.proxy.ServeHTTP(rw, req)
|
pat_test.proxy.ServeHTTP(rw, req)
|
||||||
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
|
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int,
|
func (pat_test *PassAccessTokenTest) getRootEndpoint(
|
||||||
access_token string) {
|
cookie string) (http_code int, access_token string) {
|
||||||
cookie_key := pac_test.proxy.CookieKey
|
cookie_key := pat_test.proxy.CookieKey
|
||||||
var value string
|
var value string
|
||||||
key_prefix := cookie_key + "="
|
key_prefix := cookie_key + "="
|
||||||
|
|
||||||
@ -198,43 +199,43 @@ func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code in
|
|||||||
})
|
})
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
pac_test.proxy.ServeHTTP(rw, req)
|
pat_test.proxy.ServeHTTP(rw, req)
|
||||||
return rw.Code, rw.Body.String()
|
return rw.Code, rw.Body.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForwardAccessTokenUpstream(t *testing.T) {
|
func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||||
pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||||
PassAccessToken: true,
|
PassAccessToken: true,
|
||||||
})
|
})
|
||||||
defer Close(pac_test)
|
defer pat_test.Close()
|
||||||
|
|
||||||
// A successful validation will redirect and set the auth cookie.
|
// A successful validation will redirect and set the auth cookie.
|
||||||
code, cookie := getCallbackEndpoint(pac_test)
|
code, cookie := pat_test.getCallbackEndpoint()
|
||||||
assert.Equal(t, 302, code)
|
assert.Equal(t, 302, code)
|
||||||
assert.NotEqual(t, nil, cookie)
|
assert.NotEqual(t, nil, cookie)
|
||||||
|
|
||||||
// Now we make a regular request; the access_token from the cookie is
|
// Now we make a regular request; the access_token from the cookie is
|
||||||
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
||||||
// read by the test provider server and written in the response body.
|
// read by the test provider server and written in the response body.
|
||||||
code, payload := getRootEndpoint(pac_test, cookie)
|
code, payload := pat_test.getRootEndpoint(cookie)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, "my_auth_token", payload)
|
assert.Equal(t, "my_auth_token", payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||||
pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||||
PassAccessToken: false,
|
PassAccessToken: false,
|
||||||
})
|
})
|
||||||
defer Close(pac_test)
|
defer pat_test.Close()
|
||||||
|
|
||||||
// A successful validation will redirect and set the auth cookie.
|
// A successful validation will redirect and set the auth cookie.
|
||||||
code, cookie := getCallbackEndpoint(pac_test)
|
code, cookie := pat_test.getCallbackEndpoint()
|
||||||
assert.Equal(t, 302, code)
|
assert.Equal(t, 302, code)
|
||||||
assert.NotEqual(t, nil, cookie)
|
assert.NotEqual(t, nil, cookie)
|
||||||
|
|
||||||
// Now we make a regular request, but the access token header should
|
// Now we make a regular request, but the access token header should
|
||||||
// not be present.
|
// not be present.
|
||||||
code, payload := getRootEndpoint(pac_test, cookie)
|
code, payload := pat_test.getRootEndpoint(cookie)
|
||||||
assert.Equal(t, 200, code)
|
assert.Equal(t, 200, code)
|
||||||
assert.Equal(t, "No access token found.", payload)
|
assert.Equal(t, "No access token found.", payload)
|
||||||
}
|
}
|
||||||
|
17
options.go
17
options.go
@ -117,6 +117,23 @@ func (o *Options) Validate() error {
|
|||||||
}
|
}
|
||||||
msgs = parseProviderInfo(o, msgs)
|
msgs = parseProviderInfo(o, msgs)
|
||||||
|
|
||||||
|
if o.PassAccessToken {
|
||||||
|
valid_cookie_secret_size := false
|
||||||
|
for _, i := range []int{16, 24, 32} {
|
||||||
|
if len(o.CookieSecret) == i {
|
||||||
|
valid_cookie_secret_size = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if valid_cookie_secret_size == false {
|
||||||
|
msgs = append(msgs, fmt.Sprintf(
|
||||||
|
"cookie_secret must be 16, 24, or 32 bytes "+
|
||||||
|
"to create an AES cipher when "+
|
||||||
|
"pass_access_token == true, "+
|
||||||
|
"but is %d bytes",
|
||||||
|
len(o.CookieSecret)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if len(msgs) != 0 {
|
if len(msgs) != 0 {
|
||||||
return fmt.Errorf("Invalid configuration:\n %s",
|
return fmt.Errorf("Invalid configuration:\n %s",
|
||||||
strings.Join(msgs, "\n "))
|
strings.Join(msgs, "\n "))
|
||||||
|
@ -102,3 +102,22 @@ func TestDefaultProviderApiSettings(t *testing.T) {
|
|||||||
assert.Equal(t, "", p.ProfileUrl.String())
|
assert.Equal(t, "", p.ProfileUrl.String())
|
||||||
assert.Equal(t, "profile email", p.Scope)
|
assert.Equal(t, "profile email", p.Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) {
|
||||||
|
o := testOptions()
|
||||||
|
assert.Equal(t, nil, o.Validate())
|
||||||
|
|
||||||
|
assert.Equal(t, false, o.PassAccessToken)
|
||||||
|
o.PassAccessToken = true
|
||||||
|
o.CookieSecret = "cookie of invalid length-"
|
||||||
|
assert.NotEqual(t, nil, o.Validate())
|
||||||
|
|
||||||
|
o.CookieSecret = "16 bytes AES-128"
|
||||||
|
assert.Equal(t, nil, o.Validate())
|
||||||
|
|
||||||
|
o.CookieSecret = "24 byte secret AES-192--"
|
||||||
|
assert.Equal(t, nil, o.Validate())
|
||||||
|
|
||||||
|
o.CookieSecret = "32 byte secret for AES-256------"
|
||||||
|
assert.Equal(t, nil, o.Validate())
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user