Merge pull request #88 from 18F/auto-refresh

Auto refresh auth token
This commit is contained in:
Jehiah Czebotar 2015-05-11 22:24:50 -04:00
commit 9047920e90
13 changed files with 406 additions and 50 deletions

View File

@ -77,6 +77,7 @@ Usage of google_auth_proxy:
-cookie-expire=168h0m0s: expire timeframe for cookie -cookie-expire=168h0m0s: expire timeframe for cookie
-cookie-httponly=true: set HttpOnly cookie flag -cookie-httponly=true: set HttpOnly cookie flag
-cookie-https-only=true: set secure (HTTPS) cookies (deprecated. use --cookie-secure setting) -cookie-https-only=true: set secure (HTTPS) cookies (deprecated. use --cookie-secure setting)
-cookie-refresh=0: refresh the cookie when less than this much time remains before expiration; 0 to disable
-cookie-secret="": the seed string for secure cookies -cookie-secret="": the seed string for secure cookies
-cookie-secure=true: set secure (HTTPS) cookie flag -cookie-secure=true: set secure (HTTPS) cookie flag
-custom-templates-dir="": path to custom html templates -custom-templates-dir="": path to custom html templates
@ -96,6 +97,7 @@ Usage of google_auth_proxy:
-scope="": Oauth scope specification -scope="": Oauth scope specification
-skip-auth-regex=: bypass authentication for requests path's that match (may be given multiple times) -skip-auth-regex=: bypass authentication for requests path's that match (may be given multiple times)
-upstream=: the http url(s) of the upstream endpoint. If multiple, routing is based on path -upstream=: the http url(s) of the upstream endpoint. If multiple, routing is based on path
-validate-url="": Access token validation endpoint
-version=false: print version string -version=false: print version string
``` ```

View File

@ -46,12 +46,17 @@
## Cookie Settings ## Cookie Settings
## Secret - the seed string for secure cookies ## Secret - the seed string for secure cookies; should be 16, 24, or 32 bytes
## for use with an AES cipher when cookie_refresh or pass_access_code
## is set
## Domain - optional cookie domain to force cookies to (ie: .yourcompany.com) ## Domain - optional cookie domain to force cookies to (ie: .yourcompany.com)
## Expire - expire timeframe for cookie ## Expire - expire timeframe for cookie
## Refresh - refresh the cookie when less than this much time remains before
## expiration; should be less than cookie_expire; set to 0 to disable
# cookie_secret = "" # cookie_secret = ""
# cookie_domain = "" # cookie_domain = ""
# cookie_expire = "168h" # cookie_expire = "168h"
# cookie_refresh = "144h"
# cookie_secure = true # cookie_secure = true
# cookie_httponly = true # cookie_httponly = true
# pass_access_code = true

View File

@ -15,11 +15,11 @@ import (
"time" "time"
) )
func validateCookie(cookie *http.Cookie, seed string) (string, bool) { func validateCookie(cookie *http.Cookie, seed string) (string, time.Time, bool) {
// value, timestamp, sig // value, timestamp, sig
parts := strings.Split(cookie.Value, "|") parts := strings.Split(cookie.Value, "|")
if len(parts) != 3 { if len(parts) != 3 {
return "", false return "", time.Unix(0, 0), false
} }
sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) sig := cookieSignature(seed, cookie.Name, parts[0], parts[1])
if checkHmac(parts[2], sig) { if checkHmac(parts[2], sig) {
@ -28,11 +28,11 @@ func validateCookie(cookie *http.Cookie, seed string) (string, bool) {
// it's a valid cookie. now get the contents // it's a valid cookie. now get the contents
rawValue, err := base64.URLEncoding.DecodeString(parts[0]) rawValue, err := base64.URLEncoding.DecodeString(parts[0])
if err == nil { if err == nil {
return string(rawValue), true return string(rawValue), time.Unix(int64(ts), 0), true
} }
} }
} }
return "", false return "", time.Unix(0, 0), false
} }
func signedCookieValue(seed string, key string, value string) string { func signedCookieValue(seed string, key string, value string) string {

View File

@ -45,6 +45,7 @@ func main() {
flagSet.String("cookie-secret", "", "the seed string for secure cookies") flagSet.String("cookie-secret", "", "the seed string for secure cookies")
flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*")
flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie")
flagSet.Duration("cookie-refresh", time.Duration(0)*time.Hour, "refresh the cookie when less than this much time remains before expiration; 0 to disable")
flagSet.Bool("cookie-https-only", true, "set secure (HTTPS) cookies (deprecated. use --cookie-secure setting)") flagSet.Bool("cookie-https-only", true, "set secure (HTTPS) cookies (deprecated. use --cookie-secure setting)")
flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag")
flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag")
@ -55,6 +56,7 @@ func main() {
flagSet.String("login-url", "", "Authentication endpoint") flagSet.String("login-url", "", "Authentication endpoint")
flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("redeem-url", "", "Token redemption endpoint")
flagSet.String("profile-url", "", "Profile access endpoint") flagSet.String("profile-url", "", "Profile access endpoint")
flagSet.String("validate-url", "", "Access token validation endpoint")
flagSet.String("scope", "", "Oauth scope specification") flagSet.String("scope", "", "Oauth scope specification")
flagSet.Parse(os.Args[1:]) flagSet.Parse(os.Args[1:])

View File

@ -34,12 +34,14 @@ type OauthProxy struct {
CookieSecure bool CookieSecure bool
CookieHttpOnly bool CookieHttpOnly bool
CookieExpire time.Duration CookieExpire time.Duration
CookieRefresh time.Duration
Validator func(string) bool Validator func(string) bool
redirectUrl *url.URL // the url to receive requests at redirectUrl *url.URL // the url to receive requests at
provider providers.Provider provider providers.Provider
oauthRedemptionUrl *url.URL // endpoint to redeem the code oauthRedemptionUrl *url.URL // endpoint to redeem the code
oauthLoginUrl *url.URL // to redirect the user to oauthLoginUrl *url.URL // to redirect the user to
oauthValidateUrl *url.URL // to validate the access token
oauthScope string oauthScope string
clientID string clientID string
clientSecret string clientSecret string
@ -48,6 +50,7 @@ type OauthProxy struct {
DisplayHtpasswdForm bool DisplayHtpasswdForm bool
serveMux http.Handler serveMux http.Handler
PassBasicAuth bool PassBasicAuth bool
PassAccessToken bool
AesCipher cipher.Block AesCipher cipher.Block
skipAuthRegex []string skipAuthRegex []string
compiledRegex []*regexp.Regexp compiledRegex []*regexp.Regexp
@ -121,12 +124,12 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain)
var aes_cipher cipher.Block var aes_cipher cipher.Block
if opts.PassAccessToken { if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
var err error var err error
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
if err != nil { if err != nil {
log.Fatal("error creating AES cipher with "+ log.Fatal("error creating AES cipher with "+
"pass_access_token == true: %s", err) "cookie-secret ", opts.CookieSecret, ": ", err)
} }
} }
@ -137,6 +140,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
CookieSecure: opts.CookieSecure, CookieSecure: opts.CookieSecure,
CookieHttpOnly: opts.CookieHttpOnly, CookieHttpOnly: opts.CookieHttpOnly,
CookieExpire: opts.CookieExpire, CookieExpire: opts.CookieExpire,
CookieRefresh: opts.CookieRefresh,
Validator: validator, Validator: validator,
clientID: opts.ClientID, clientID: opts.ClientID,
@ -145,11 +149,13 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
provider: opts.provider, provider: opts.provider,
oauthRedemptionUrl: opts.provider.Data().RedeemUrl, oauthRedemptionUrl: opts.provider.Data().RedeemUrl,
oauthLoginUrl: opts.provider.Data().LoginUrl, oauthLoginUrl: opts.provider.Data().LoginUrl,
oauthValidateUrl: opts.provider.Data().ValidateUrl,
serveMux: serveMux, serveMux: serveMux,
redirectUrl: redirectUrl, redirectUrl: redirectUrl,
skipAuthRegex: opts.SkipAuthRegex, skipAuthRegex: opts.SkipAuthRegex,
compiledRegex: opts.CompiledRegex, compiledRegex: opts.CompiledRegex,
PassBasicAuth: opts.PassBasicAuth, PassBasicAuth: opts.PassBasicAuth,
PassAccessToken: opts.PassAccessToken,
AesCipher: aes_cipher, AesCipher: aes_cipher,
templates: loadTemplates(opts.CustomTemplatesDir), templates: loadTemplates(opts.CustomTemplatesDir),
} }
@ -224,7 +230,7 @@ func (p *OauthProxy) redeemCode(host, code string) (string, string, error) {
return access_token, email, nil return access_token, email, nil
} }
func (p *OauthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) { func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration) *http.Cookie {
domain := req.Host domain := req.Host
if h, _, err := net.SplitHostPort(domain); err == nil { if h, _, err := net.SplitHostPort(domain); err == nil {
domain = h domain = h
@ -235,40 +241,76 @@ func (p *OauthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
} }
domain = p.CookieDomain domain = p.CookieDomain
} }
cookie := &http.Cookie{
if value != "" {
value = signedCookieValue(p.CookieSeed, p.CookieKey, value)
}
return &http.Cookie{
Name: p.CookieKey, Name: p.CookieKey,
Value: "", Value: value,
Path: "/", Path: "/",
Domain: domain, Domain: domain,
HttpOnly: p.CookieHttpOnly, HttpOnly: p.CookieHttpOnly,
Secure: p.CookieSecure, Secure: p.CookieSecure,
Expires: time.Now().Add(time.Duration(1) * time.Hour * -1), Expires: time.Now().Add(expiration),
} }
http.SetCookie(rw, cookie) }
func (p *OauthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
http.SetCookie(rw, p.MakeCookie(req, "", time.Duration(1)*time.Hour*-1))
} }
func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) { func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire))
}
domain := req.Host func (p *OauthProxy) ValidateToken(access_token string) bool {
if h, _, err := net.SplitHostPort(domain); err == nil { if access_token == "" || p.oauthValidateUrl == nil {
domain = h return false
} }
if p.CookieDomain != "" {
if !strings.HasSuffix(domain, p.CookieDomain) { req, err := http.NewRequest("GET",
log.Printf("Warning: request host is %q but using configured cookie domain of %q", domain, p.CookieDomain) p.oauthValidateUrl.String()+"?access_token="+access_token, nil)
if err != nil {
log.Printf("failed building token validation request: %s", err)
return false
}
httpclient := &http.Client{}
resp, err := httpclient.Do(req)
if err != nil {
log.Printf("token validation request failed: %s", err)
return false
}
return resp.StatusCode == 200
}
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) {
var value string
var timestamp time.Time
cookie, err := req.Cookie(p.CookieKey)
if err == nil {
value, timestamp, ok = validateCookie(cookie, p.CookieSeed)
if ok {
email, user, access_token, err = parseCookieValue(
value, p.AesCipher)
} }
domain = p.CookieDomain
} }
cookie := &http.Cookie{ if err != nil {
Name: p.CookieKey, log.Printf(err.Error())
Value: signedCookieValue(p.CookieSeed, p.CookieKey, val), ok = false
Path: "/", } else if p.CookieRefresh != time.Duration(0) {
Domain: domain, expires := timestamp.Add(p.CookieExpire)
HttpOnly: p.CookieHttpOnly, refresh_threshold := time.Now().Add(p.CookieRefresh)
Secure: p.CookieSecure, if refresh_threshold.Unix() > expires.Unix() {
Expires: time.Now().Add(p.CookieExpire), ok = p.Validator(email) && p.ValidateToken(access_token)
if ok {
p.SetCookie(rw, req, value)
}
}
} }
http.SetCookie(rw, cookie) return
} }
func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) { func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) {
@ -451,18 +493,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
if !ok { if !ok {
cookie, err := req.Cookie(p.CookieKey) email, user, access_token, ok = p.ProcessCookie(rw, req)
if err == nil {
var value string
value, ok = validateCookie(cookie, p.CookieSeed)
if ok {
email, user, access_token, err = parseCookieValue(
value, p.AesCipher)
if err != nil {
log.Printf(err.Error())
}
}
}
} }
if !ok { if !ok {
@ -480,7 +511,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req.Header["X-Forwarded-User"] = []string{user} req.Header["X-Forwarded-User"] = []string{user}
req.Header["X-Forwarded-Email"] = []string{email} req.Header["X-Forwarded-Email"] = []string{email}
} }
if access_token != "" { if p.PassAccessToken {
req.Header["X-Forwarded-Access-Token"] = []string{access_token} req.Header["X-Forwarded-Access-Token"] = []string{access_token}
} }
if email == "" { if email == "" {

View File

@ -321,3 +321,260 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`) t.Fatal(`expected redirect to "/", but was "` + match[1] + `"`)
} }
} }
type ValidateTokenTest struct {
opts *Options
proxy *OauthProxy
backend *httptest.Server
response_code int
}
func NewValidateTokenTest() *ValidateTokenTest {
var vt_test ValidateTokenTest
vt_test.opts = NewOptions()
vt_test.opts.Upstreams = append(vt_test.opts.Upstreams, "unused")
vt_test.opts.CookieSecret = "foobar"
vt_test.opts.ClientID = "bazquux"
vt_test.opts.ClientSecret = "xyzzyplugh"
vt_test.opts.Validate()
vt_test.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oauth/tokeninfo":
w.WriteHeader(vt_test.response_code)
w.Write([]byte("only code matters; contents disregarded"))
default:
w.WriteHeader(500)
w.Write([]byte("unknown URL"))
}
}))
backend_url, _ := url.Parse(vt_test.backend.URL)
vt_test.opts.provider.Data().ValidateUrl = &url.URL{
Scheme: "http",
Host: backend_url.Host,
Path: "/oauth/tokeninfo",
}
vt_test.response_code = 200
vt_test.proxy = NewOauthProxy(vt_test.opts, func(email string) bool {
return true
})
return &vt_test
}
func (vt_test *ValidateTokenTest) Close() {
vt_test.backend.Close()
}
func TestValidateTokenEmptyToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
assert.Equal(t, false, vt_test.proxy.ValidateToken(""))
}
func TestValidateTokenEmptyValidateUrl(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.proxy.oauthValidateUrl = nil
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
}
func TestValidateTokenRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateTokenTest()
// Close immediately to simulate a network failure
vt_test.Close()
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
}
func TestValidateTokenExpiredToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.response_code = 401
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
}
func TestValidateTokenValidToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
assert.Equal(t, true, vt_test.proxy.ValidateToken("foobar"))
}
type ProcessCookieTest struct {
opts *Options
proxy *OauthProxy
rw *httptest.ResponseRecorder
req *http.Request
backend *httptest.Server
response_code int
validate_user bool
}
func NewProcessCookieTest() *ProcessCookieTest {
var pc_test ProcessCookieTest
pc_test.opts = NewOptions()
pc_test.opts.Upstreams = append(pc_test.opts.Upstreams, "unused")
pc_test.opts.CookieSecret = "foobar"
pc_test.opts.ClientID = "bazquux"
pc_test.opts.ClientSecret = "xyzzyplugh"
pc_test.opts.CookieSecret = "0123456789abcdef"
// First, set the CookieRefresh option so proxy.AesCipher is created,
// needed to encrypt the access_token.
pc_test.opts.CookieRefresh = time.Duration(24) * time.Hour
pc_test.opts.Validate()
pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool {
return pc_test.validate_user
})
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
pc_test.proxy.CookieRefresh = time.Duration(0)
pc_test.rw = httptest.NewRecorder()
pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader(""))
pc_test.validate_user = true
return &pc_test
}
func (p *ProcessCookieTest) InstantiateBackend() {
p.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(p.response_code)
}))
backend_url, _ := url.Parse(p.backend.URL)
p.proxy.oauthValidateUrl = &url.URL{
Scheme: "http",
Host: backend_url.Host,
Path: "/oauth/tokeninfo",
}
p.response_code = 200
}
func (p *ProcessCookieTest) Close() {
p.backend.Close()
}
func (p *ProcessCookieTest) MakeCookie(value, access_token string) *http.Cookie {
cookie_value, _ := buildCookieValue(
value, p.proxy.AesCipher, access_token)
return p.proxy.MakeCookie(p.req, cookie_value, p.opts.CookieExpire)
}
func (p *ProcessCookieTest) AddCookie(value, access_token string) {
p.req.AddCookie(p.MakeCookie(value, access_token))
}
func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, ok bool) {
return p.proxy.ProcessCookie(p.rw, p.req)
}
func TestProcessCookie(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token")
email, user, access_token, ok := pc_test.ProcessCookie()
assert.Equal(t, true, ok)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "my_access_token", access_token)
}
func TestProcessCookieNoCookieError(t *testing.T) {
pc_test := NewProcessCookieTest()
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
}
func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
pc_test := NewProcessCookieTest()
value, _ := buildCookieValue("michael.bland@gsa.gov",
pc_test.proxy.AesCipher, "my_access_token")
pc_test.req.AddCookie(pc_test.proxy.MakeCookie(
pc_test.req, value+"some bogus bytes",
pc_test.opts.CookieExpire))
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
}
func TestProcessCookieRefreshNotSet(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "")
pc_test.req.AddCookie(cookie)
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, true, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}
func TestProcessCookieRefresh(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Duration(24) * time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, true, ok)
assert.NotEqual(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}
func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test.proxy.CookieExpire = time.Duration(25) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Duration(24) * time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, true, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}
func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test.response_code = 401
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Duration(24) * time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}
func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test.validate_user = false
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Duration(24) * time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}

View File

@ -26,6 +26,7 @@ type Options struct {
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"GOOGLE_AUTH_PROXY_COOKIE_SECRET"` CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"GOOGLE_AUTH_PROXY_COOKIE_SECRET"`
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"GOOGLE_AUTH_PROXY_COOKIE_DOMAIN"` CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"GOOGLE_AUTH_PROXY_COOKIE_DOMAIN"`
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"GOOGLE_AUTH_PROXY_COOKIE_EXPIRE"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"GOOGLE_AUTH_PROXY_COOKIE_EXPIRE"`
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"GOOGLE_AUTH_PROXY_COOKIE_REFRESH"`
CookieHttpsOnly bool `flag:"cookie-https-only" cfg:"cookie_https_only"` // deprecated use cookie-secure CookieHttpsOnly bool `flag:"cookie-https-only" cfg:"cookie_https_only"` // deprecated use cookie-secure
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"`
CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"`
@ -38,11 +39,12 @@ type Options struct {
// These options allow for other providers besides Google, with // These options allow for other providers besides Google, with
// potential overrides. // potential overrides.
Provider string `flag:"provider" cfg:"provider"` Provider string `flag:"provider" cfg:"provider"`
LoginUrl string `flag:"login-url" cfg:"login_url"` LoginUrl string `flag:"login-url" cfg:"login_url"`
RedeemUrl string `flag:"redeem-url" cfg:"redeem_url"` RedeemUrl string `flag:"redeem-url" cfg:"redeem_url"`
ProfileUrl string `flag:"profile-url" cfg:"profile_url"` ProfileUrl string `flag:"profile-url" cfg:"profile_url"`
Scope string `flag:"scope" cfg:"scope"` ValidateUrl string `flag:"validate-url" cfg:"validate_url"`
Scope string `flag:"scope" cfg:"scope"`
RequestLogging bool `flag:"request-logging" cfg:"request_logging"` RequestLogging bool `flag:"request-logging" cfg:"request_logging"`
@ -61,6 +63,7 @@ func NewOptions() *Options {
CookieSecure: true, CookieSecure: true,
CookieHttpOnly: true, CookieHttpOnly: true,
CookieExpire: time.Duration(168) * time.Hour, CookieExpire: time.Duration(168) * time.Hour,
CookieRefresh: time.Duration(0),
PassBasicAuth: true, PassBasicAuth: true,
PassAccessToken: false, PassAccessToken: false,
PassHostHeader: true, PassHostHeader: true,
@ -117,7 +120,7 @@ func (o *Options) Validate() error {
} }
msgs = parseProviderInfo(o, msgs) msgs = parseProviderInfo(o, msgs)
if o.PassAccessToken { if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
valid_cookie_secret_size := false valid_cookie_secret_size := false
for _, i := range []int{16, 24, 32} { for _, i := range []int{16, 24, 32} {
if len(o.CookieSecret) == i { if len(o.CookieSecret) == i {
@ -128,12 +131,20 @@ func (o *Options) Validate() error {
msgs = append(msgs, fmt.Sprintf( msgs = append(msgs, fmt.Sprintf(
"cookie_secret must be 16, 24, or 32 bytes "+ "cookie_secret must be 16, 24, or 32 bytes "+
"to create an AES cipher when "+ "to create an AES cipher when "+
"pass_access_token == true, "+ "pass_access_token == true or "+
"but is %d bytes", "cookie_refresh != 0, but is %d bytes",
len(o.CookieSecret))) len(o.CookieSecret)))
} }
} }
if o.CookieRefresh >= o.CookieExpire {
msgs = append(msgs, fmt.Sprintf(
"cookie_refresh (%s) must be less than "+
"cookie_expire (%s)",
o.CookieRefresh.String(),
o.CookieExpire.String()))
}
if len(msgs) != 0 { if len(msgs) != 0 {
return fmt.Errorf("Invalid configuration:\n %s", return fmt.Errorf("Invalid configuration:\n %s",
strings.Join(msgs, "\n ")) strings.Join(msgs, "\n "))
@ -146,6 +157,7 @@ func parseProviderInfo(o *Options, msgs []string) []string {
p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs) p.LoginUrl, msgs = parseUrl(o.LoginUrl, "login", msgs)
p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs) p.RedeemUrl, msgs = parseUrl(o.RedeemUrl, "redeem", msgs)
p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs) p.ProfileUrl, msgs = parseUrl(o.ProfileUrl, "profile", msgs)
p.ValidateUrl, msgs = parseUrl(o.ValidateUrl, "validate", msgs)
o.provider = providers.New(o.Provider, p) o.provider = providers.New(o.Provider, p)
return msgs return msgs
} }

View File

@ -4,6 +4,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/bmizerany/assert" "github.com/bmizerany/assert"
) )
@ -112,6 +113,10 @@ func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) {
o.CookieSecret = "cookie of invalid length-" o.CookieSecret = "cookie of invalid length-"
assert.NotEqual(t, nil, o.Validate()) assert.NotEqual(t, nil, o.Validate())
o.PassAccessToken = false
o.CookieRefresh = time.Duration(24) * time.Hour
assert.NotEqual(t, nil, o.Validate())
o.CookieSecret = "16 bytes AES-128" o.CookieSecret = "16 bytes AES-128"
assert.Equal(t, nil, o.Validate()) assert.Equal(t, nil, o.Validate())
@ -121,3 +126,15 @@ func TestPassAccessTokenRequiresSpecificCookieSecretLengths(t *testing.T) {
o.CookieSecret = "32 byte secret for AES-256------" o.CookieSecret = "32 byte secret for AES-256------"
assert.Equal(t, nil, o.Validate()) assert.Equal(t, nil, o.Validate())
} }
func TestCookieRefreshMustBeLessThanCookieExpire(t *testing.T) {
o := testOptions()
assert.Equal(t, nil, o.Validate())
o.CookieSecret = "0123456789abcdef"
o.CookieRefresh = o.CookieExpire
assert.NotEqual(t, nil, o.Validate())
o.CookieRefresh -= time.Duration(1)
assert.Equal(t, nil, o.Validate())
}

View File

@ -24,6 +24,11 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
Host: "accounts.google.com", Host: "accounts.google.com",
Path: "/o/oauth2/token"} Path: "/o/oauth2/token"}
} }
if p.ValidateUrl.String() == "" {
p.ValidateUrl = &url.URL{Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v1/tokeninfo"}
}
if p.Scope == "" { if p.Scope == "" {
p.Scope = "profile email" p.Scope = "profile email"
} }

View File

@ -15,6 +15,7 @@ func newGoogleProvider() *GoogleProvider {
LoginUrl: &url.URL{}, LoginUrl: &url.URL{},
RedeemUrl: &url.URL{}, RedeemUrl: &url.URL{},
ProfileUrl: &url.URL{}, ProfileUrl: &url.URL{},
ValidateUrl: &url.URL{},
Scope: ""}) Scope: ""})
} }
@ -26,6 +27,8 @@ func TestGoogleProviderDefaults(t *testing.T) {
p.Data().LoginUrl.String()) p.Data().LoginUrl.String())
assert.Equal(t, "https://accounts.google.com/o/oauth2/token", assert.Equal(t, "https://accounts.google.com/o/oauth2/token",
p.Data().RedeemUrl.String()) p.Data().RedeemUrl.String())
assert.Equal(t, "https://www.googleapis.com/oauth2/v1/tokeninfo",
p.Data().ValidateUrl.String())
assert.Equal(t, "", p.Data().ProfileUrl.String()) assert.Equal(t, "", p.Data().ProfileUrl.String())
assert.Equal(t, "profile email", p.Data().Scope) assert.Equal(t, "profile email", p.Data().Scope)
} }
@ -45,6 +48,10 @@ func TestGoogleProviderOverrides(t *testing.T) {
Scheme: "https", Scheme: "https",
Host: "example.com", Host: "example.com",
Path: "/oauth/profile"}, Path: "/oauth/profile"},
ValidateUrl: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/tokeninfo"},
Scope: "profile"}) Scope: "profile"})
assert.NotEqual(t, nil, p) assert.NotEqual(t, nil, p)
assert.Equal(t, "Google", p.Data().ProviderName) assert.Equal(t, "Google", p.Data().ProviderName)
@ -54,6 +61,8 @@ func TestGoogleProviderOverrides(t *testing.T) {
p.Data().RedeemUrl.String()) p.Data().RedeemUrl.String())
assert.Equal(t, "https://example.com/oauth/profile", assert.Equal(t, "https://example.com/oauth/profile",
p.Data().ProfileUrl.String()) p.Data().ProfileUrl.String())
assert.Equal(t, "https://example.com/oauth/tokeninfo",
p.Data().ValidateUrl.String())
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }

View File

@ -32,6 +32,11 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider {
Host: myUsaHost, Host: myUsaHost,
Path: "/api/v1/profile"} Path: "/api/v1/profile"}
} }
if p.ValidateUrl.String() == "" {
p.ValidateUrl = &url.URL{Scheme: "https",
Host: myUsaHost,
Path: "/api/v1/tokeninfo"}
}
if p.Scope == "" { if p.Scope == "" {
p.Scope = "profile.email" p.Scope = "profile.email"
} }

View File

@ -21,11 +21,13 @@ func testMyUsaProvider(hostname string) *MyUsaProvider {
LoginUrl: &url.URL{}, LoginUrl: &url.URL{},
RedeemUrl: &url.URL{}, RedeemUrl: &url.URL{},
ProfileUrl: &url.URL{}, ProfileUrl: &url.URL{},
ValidateUrl: &url.URL{},
Scope: ""}) Scope: ""})
if hostname != "" { if hostname != "" {
updateUrl(p.Data().LoginUrl, hostname) updateUrl(p.Data().LoginUrl, hostname)
updateUrl(p.Data().RedeemUrl, hostname) updateUrl(p.Data().RedeemUrl, hostname)
updateUrl(p.Data().ProfileUrl, hostname) updateUrl(p.Data().ProfileUrl, hostname)
updateUrl(p.Data().ValidateUrl, hostname)
} }
return p return p
} }
@ -56,6 +58,8 @@ func TestMyUsaProviderDefaults(t *testing.T) {
p.Data().RedeemUrl.String()) p.Data().RedeemUrl.String())
assert.Equal(t, "https://alpha.my.usa.gov/api/v1/profile", assert.Equal(t, "https://alpha.my.usa.gov/api/v1/profile",
p.Data().ProfileUrl.String()) p.Data().ProfileUrl.String())
assert.Equal(t, "https://alpha.my.usa.gov/api/v1/tokeninfo",
p.Data().ValidateUrl.String())
assert.Equal(t, "profile.email", p.Data().Scope) assert.Equal(t, "profile.email", p.Data().Scope)
} }
@ -74,6 +78,10 @@ func TestMyUsaProviderOverrides(t *testing.T) {
Scheme: "https", Scheme: "https",
Host: "example.com", Host: "example.com",
Path: "/oauth/profile"}, Path: "/oauth/profile"},
ValidateUrl: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/tokeninfo"},
Scope: "profile"}) Scope: "profile"})
assert.NotEqual(t, nil, p) assert.NotEqual(t, nil, p)
assert.Equal(t, "MyUSA", p.Data().ProviderName) assert.Equal(t, "MyUSA", p.Data().ProviderName)
@ -83,6 +91,8 @@ func TestMyUsaProviderOverrides(t *testing.T) {
p.Data().RedeemUrl.String()) p.Data().RedeemUrl.String())
assert.Equal(t, "https://example.com/oauth/profile", assert.Equal(t, "https://example.com/oauth/profile",
p.Data().ProfileUrl.String()) p.Data().ProfileUrl.String())
assert.Equal(t, "https://example.com/oauth/tokeninfo",
p.Data().ValidateUrl.String())
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }

View File

@ -9,6 +9,7 @@ type ProviderData struct {
LoginUrl *url.URL LoginUrl *url.URL
RedeemUrl *url.URL RedeemUrl *url.URL
ProfileUrl *url.URL ProfileUrl *url.URL
ValidateUrl *url.URL
Scope string Scope string
} }