diff --git a/.gitignore b/.gitignore index 50d93ea..90dbc51 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,6 @@ _testmain.go *.exe dist .godeps + +# Editor swap/temp files +.*.swp diff --git a/cookies.go b/cookies.go index 0ae6a92..9605398 100644 --- a/cookies.go +++ b/cookies.go @@ -97,3 +97,34 @@ func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (st return string(encrypted_access_token), nil } + +func buildCookieValue(email string, aes_cipher cipher.Block, + access_token string) (string, error) { + if aes_cipher == nil { + return email, nil + } + + encoded_token, err := encodeAccessToken(aes_cipher, access_token) + if err != nil { + return email, fmt.Errorf( + "error encoding access token for %s: %s", email, err) + } + return email + "|" + encoded_token, nil +} + +func parseCookieValue(value string, aes_cipher cipher.Block) (email, user, + access_token string, err error) { + components := strings.Split(value, "|") + email = components[0] + user = strings.Split(email, "@")[0] + + if aes_cipher != nil && len(components) == 2 { + access_token, err = decodeAccessToken(aes_cipher, components[1]) + if err != nil { + err = fmt.Errorf( + "error decoding access token for %s: %s", + email, err) + } + } + return email, user, access_token, err +} diff --git a/cookies_test.go b/cookies_test.go index d5470d0..44696e8 100644 --- a/cookies_test.go +++ b/cookies_test.go @@ -3,6 +3,7 @@ package main import ( "crypto/aes" "github.com/bmizerany/assert" + "strings" "testing" ) @@ -21,3 +22,54 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { assert.NotEqual(t, access_token, encoded_token) assert.Equal(t, access_token, decoded_token) } + +func TestBuildCookieValueWithoutAccessToken(t *testing.T) { + value, err := buildCookieValue("michael.bland@gsa.gov", nil, "") + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", value) +} + +func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) { + value, err := buildCookieValue("michael.bland@gsa.gov", nil, + "access token") + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", value) +} + +func TestParseCookieValueWithoutAccessToken(t *testing.T) { + email, user, access_token, err := parseCookieValue( + "michael.bland@gsa.gov", nil) + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland", user) + assert.Equal(t, "", access_token) +} + +func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) { + email, user, access_token, err := parseCookieValue( + "michael.bland@gsa.gov|access_token", nil) + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland", user) + assert.Equal(t, "", access_token) +} + +func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) { + aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef")) + assert.Equal(t, nil, err) + value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher, + "access_token") + assert.Equal(t, nil, err) + + prefix := "michael.bland@gsa.gov|" + if !strings.HasPrefix(value, prefix) { + t.Fatal("cookie value does not start with \"%s\": %s", + prefix, value) + } + + email, user, access_token, err := parseCookieValue(value, aes_cipher) + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland", user) + assert.Equal(t, "access_token", access_token) +} diff --git a/oauthproxy.go b/oauthproxy.go index 0b5141e..33f4698 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -47,7 +47,6 @@ type OauthProxy struct { DisplayHtpasswdForm bool serveMux http.Handler PassBasicAuth bool - PassAccessToken bool AesCipher cipher.Block skipAuthRegex []string compiledRegex []*regexp.Regexp @@ -121,20 +120,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) var aes_cipher cipher.Block - - if opts.PassAccessToken == true { - valid_cookie_secret_size := false - for _, i := range []int{16, 24, 32} { - if len(opts.CookieSecret) == i { - valid_cookie_secret_size = true - } - } - if valid_cookie_secret_size == false { - log.Fatal("cookie_secret must be 16, 24, or 32 bytes " + - "to create an AES cipher when " + - "pass_access_token == true") - } - + if opts.PassAccessToken { var err error aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) if err != nil { @@ -163,7 +149,6 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { skipAuthRegex: opts.SkipAuthRegex, compiledRegex: opts.CompiledRegex, PassBasicAuth: opts.PassBasicAuth, - PassAccessToken: opts.PassAccessToken, AesCipher: aes_cipher, templates: loadTemplates(opts.CustomTemplatesDir), } @@ -440,20 +425,12 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // set cookie, or deny if p.Validator(email) { log.Printf("%s authenticating %s completed", remoteAddr, email) - encoded_token := "" - if p.PassAccessToken { - encoded_token, err = encodeAccessToken(p.AesCipher, access_token) - if err != nil { - log.Printf("error encoding access token: %s", err) - } - } - access_token = "" - - if encoded_token != "" { - p.SetCookie(rw, req, email+"|"+encoded_token) - } else { - p.SetCookie(rw, req, email) + value, err := buildCookieValue( + email, p.AesCipher, access_token) + if err != nil { + log.Printf(err.Error()) } + p.SetCookie(rw, req, value) http.Redirect(rw, req, redirect, 302) return } else { @@ -467,15 +444,13 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if err == nil { var value string value, ok = validateCookie(cookie, p.CookieSeed) - components := strings.Split(value, "|") - email = components[0] - if len(components) == 2 { - access_token, err = decodeAccessToken(p.AesCipher, components[1]) + if ok { + email, user, access_token, err = parseCookieValue( + value, p.AesCipher) if err != nil { - log.Printf("error decoding access token: %s", err) + log.Printf(err.Error()) } } - user = strings.Split(email, "@")[0] } } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 3712551..d3fe400 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -152,24 +152,25 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes return t } -func Close(t *PassAccessTokenTest) { - t.provider_server.Close() +func (pat_test *PassAccessTokenTest) Close() { + pat_test.provider_server.Close() } -func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) { +func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, + cookie string) { rw := httptest.NewRecorder() req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code", strings.NewReader("")) if err != nil { return 0, "" } - pac_test.proxy.ServeHTTP(rw, req) + pat_test.proxy.ServeHTTP(rw, req) return rw.Code, rw.HeaderMap["Set-Cookie"][0] } -func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int, - access_token string) { - cookie_key := pac_test.proxy.CookieKey +func (pat_test *PassAccessTokenTest) getRootEndpoint( + cookie string) (http_code int, access_token string) { + cookie_key := pat_test.proxy.CookieKey var value string key_prefix := cookie_key + "=" @@ -198,43 +199,43 @@ func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code in }) rw := httptest.NewRecorder() - pac_test.proxy.ServeHTTP(rw, req) + pat_test.proxy.ServeHTTP(rw, req) return rw.Code, rw.Body.String() } func TestForwardAccessTokenUpstream(t *testing.T) { - pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, }) - defer Close(pac_test) + defer pat_test.Close() // A successful validation will redirect and set the auth cookie. - code, cookie := getCallbackEndpoint(pac_test) + code, cookie := pat_test.getCallbackEndpoint() assert.Equal(t, 302, code) assert.NotEqual(t, nil, cookie) // Now we make a regular request; the access_token from the cookie is // forwarded as the "X-Forwarded-Access-Token" header. The token is // read by the test provider server and written in the response body. - code, payload := getRootEndpoint(pac_test, cookie) + code, payload := pat_test.getRootEndpoint(cookie) assert.Equal(t, 200, code) assert.Equal(t, "my_auth_token", payload) } func TestDoNotForwardAccessTokenUpstream(t *testing.T) { - pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: false, }) - defer Close(pac_test) + defer pat_test.Close() // A successful validation will redirect and set the auth cookie. - code, cookie := getCallbackEndpoint(pac_test) + code, cookie := pat_test.getCallbackEndpoint() assert.Equal(t, 302, code) assert.NotEqual(t, nil, cookie) // Now we make a regular request, but the access token header should // not be present. - code, payload := getRootEndpoint(pac_test, cookie) + code, payload := pat_test.getRootEndpoint(cookie) assert.Equal(t, 200, code) assert.Equal(t, "No access token found.", payload) } diff --git a/options.go b/options.go index e02cfb8..bbfd466 100644 --- a/options.go +++ b/options.go @@ -117,6 +117,23 @@ func (o *Options) Validate() error { } msgs = parseProviderInfo(o, msgs) + if o.PassAccessToken { + valid_cookie_secret_size := false + for _, i := range []int{16, 24, 32} { + if len(o.CookieSecret) == i { + valid_cookie_secret_size = true + } + } + if valid_cookie_secret_size == false { + msgs = append(msgs, fmt.Sprintf( + "cookie_secret must be 16, 24, or 32 bytes "+ + "to create an AES cipher when "+ + "pass_access_token == true, "+ + "but is %d bytes", + len(o.CookieSecret))) + } + } + if len(msgs) != 0 { return fmt.Errorf("Invalid configuration:\n %s", strings.Join(msgs, "\n ")) diff --git a/options_test.go b/options_test.go index 515c1c8..dcb5421 100644 --- a/options_test.go +++ b/options_test.go @@ -102,3 +102,22 @@ func TestDefaultProviderApiSettings(t *testing.T) { assert.Equal(t, "", p.ProfileUrl.String()) assert.Equal(t, "profile email", p.Scope) } + +func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) { + o := testOptions() + assert.Equal(t, nil, o.Validate()) + + assert.Equal(t, false, o.PassAccessToken) + o.PassAccessToken = true + o.CookieSecret = "cookie of invalid length-" + assert.NotEqual(t, nil, o.Validate()) + + o.CookieSecret = "16 bytes AES-128" + assert.Equal(t, nil, o.Validate()) + + o.CookieSecret = "24 byte secret AES-192--" + assert.Equal(t, nil, o.Validate()) + + o.CookieSecret = "32 byte secret for AES-256------" + assert.Equal(t, nil, o.Validate()) +}