From d49c3e167f3dced5d8630be77dff5b6e25b4645c Mon Sep 17 00:00:00 2001 From: Jehiah Czebotar Date: Tue, 23 Jun 2015 07:23:39 -0400 Subject: [PATCH] SessionState refactoring; improve token renewal and cookie refresh * New SessionState to consolidate email, access token and refresh token * split ServeHttp into individual methods * log on session renewal * log on access token refresh * refactor cookie encription/decription and session state serialization --- api/api.go | 3 + cookie/cookies.go | 128 ++++++++++ cookie/cookies_test.go | 23 ++ cookies.go | 140 ----------- cookies_test.go | 75 ------ oauthproxy.go | 363 +++++++++++++++++------------ oauthproxy_test.go | 149 ++++-------- providers/github.go | 21 +- providers/google.go | 68 ++++-- providers/google_test.go | 118 ++++++---- providers/internal_util.go | 4 + providers/internal_util_test.go | 50 ++-- providers/linkedin.go | 14 +- providers/linkedin_test.go | 10 +- providers/myusa.go | 8 +- providers/myusa_test.go | 12 +- providers/provider_default.go | 61 ++++- providers/provider_default_test.go | 17 ++ providers/providers.go | 13 +- providers/session_state.go | 115 +++++++++ providers/session_state_test.go | 88 +++++++ 21 files changed, 883 insertions(+), 597 deletions(-) create mode 100644 cookie/cookies.go create mode 100644 cookie/cookies_test.go delete mode 100644 cookies.go delete mode 100644 cookies_test.go create mode 100644 providers/provider_default_test.go create mode 100644 providers/session_state.go create mode 100644 providers/session_state_test.go diff --git a/api/api.go b/api/api.go index 0dac604..245bfb1 100644 --- a/api/api.go +++ b/api/api.go @@ -3,6 +3,7 @@ package api import ( "fmt" "io/ioutil" + "log" "net/http" "github.com/bitly/go-simplejson" @@ -11,10 +12,12 @@ import ( func Request(req *http.Request) (*simplejson.Json, error) { resp, err := http.DefaultClient.Do(req) if err != nil { + log.Printf("%s %s %s", req.Method, req.URL, err) return nil, err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() + log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body) if err != nil { return nil, err } diff --git a/cookie/cookies.go b/cookie/cookies.go new file mode 100644 index 0000000..b9df87f --- /dev/null +++ b/cookie/cookies.go @@ -0,0 +1,128 @@ +package cookie + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/rand" + "crypto/sha1" + "encoding/base64" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set. +// additionally, the 'value' is encrypted so it's opaque to the browser + +// Validate ensures a cookie is properly signed +func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { + // value, timestamp, sig + parts := strings.Split(cookie.Value, "|") + if len(parts) != 3 { + return + } + sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) + if checkHmac(parts[2], sig) { + ts, err := strconv.Atoi(parts[1]) + if err != nil { + return + } + // The expiration timestamp set when the cookie was created + // isn't sent back by the browser. Hence, we check whether the + // creation timestamp stored in the cookie falls within the + // window defined by (Now()-expiration, Now()]. + t = time.Unix(int64(ts), 0) + if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { + // it's a valid cookie. now get the contents + rawValue, err := base64.URLEncoding.DecodeString(parts[0]) + if err == nil { + value = string(rawValue) + ok = true + return + } + } + } + return +} + +// SignedValue returns a cookie that is signed and can later be checked with Validate +func SignedValue(seed string, key string, value string, now time.Time) string { + encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) + timeStr := fmt.Sprintf("%d", now.Unix()) + sig := cookieSignature(seed, key, encodedValue, timeStr) + cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) + return cookieVal +} + +func cookieSignature(args ...string) string { + h := hmac.New(sha1.New, []byte(args[0])) + for _, arg := range args[1:] { + h.Write([]byte(arg)) + } + var b []byte + b = h.Sum(b) + return base64.URLEncoding.EncodeToString(b) +} + +func checkHmac(input, expected string) bool { + inputMAC, err1 := base64.URLEncoding.DecodeString(input) + if err1 == nil { + expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) + if err2 == nil { + return hmac.Equal(inputMAC, expectedMAC) + } + } + return false +} + +// Cipher provides methods to encrypt and decrypt cookie values +type Cipher struct { + cipher.Block +} + +// NewCipher returns a new aes Cipher for encrypting cookie values +func NewCipher(secret string) (*Cipher, error) { + c, err := aes.NewCipher([]byte(secret)) + if err != nil { + return nil, err + } + return &Cipher{Block: c}, err +} + +// Encrypt a value for use in a cookie +func (c *Cipher) Encrypt(value string) (string, error) { + ciphertext := make([]byte, aes.BlockSize+len(value)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", fmt.Errorf("failed to create initialization vector %s", err) + } + + stream := cipher.NewCFBEncrypter(c.Block, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value)) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt a value from a cookie to it's original string +func (c *Cipher) Decrypt(s string) (string, error) { + encrypted, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return "", fmt.Errorf("failed to decrypt cookie value %s", err) + } + + if len(encrypted) < aes.BlockSize { + return "", fmt.Errorf("encrypted cookie value should be "+ + "at least %d bytes, but is only %d bytes", + aes.BlockSize, len(encrypted)) + } + + iv := encrypted[:aes.BlockSize] + encrypted = encrypted[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(c.Block, iv) + stream.XORKeyStream(encrypted, encrypted) + + return string(encrypted), nil +} diff --git a/cookie/cookies_test.go b/cookie/cookies_test.go new file mode 100644 index 0000000..d527cb8 --- /dev/null +++ b/cookie/cookies_test.go @@ -0,0 +1,23 @@ +package cookie + +import ( + "testing" + + "github.com/bmizerany/assert" +) + +func TestEncodeAndDecodeAccessToken(t *testing.T) { + const secret = "0123456789abcdefghijklmnopqrstuv" + const token = "my access token" + c, err := NewCipher(secret) + assert.Equal(t, nil, err) + + encoded, err := c.Encrypt(token) + assert.Equal(t, nil, err) + + decoded, err := c.Decrypt(encoded) + assert.Equal(t, nil, err) + + assert.NotEqual(t, token, encoded) + assert.Equal(t, token, decoded) +} diff --git a/cookies.go b/cookies.go deleted file mode 100644 index 511be37..0000000 --- a/cookies.go +++ /dev/null @@ -1,140 +0,0 @@ -package main - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "time" -) - -func validateCookie(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) { - // value, timestamp, sig - parts := strings.Split(cookie.Value, "|") - if len(parts) != 3 { - return - } - sig := cookieSignature(seed, cookie.Name, parts[0], parts[1]) - if checkHmac(parts[2], sig) { - ts, err := strconv.Atoi(parts[1]) - if err != nil { - return - } - // The expiration timestamp set when the cookie was created - // isn't sent back by the browser. Hence, we check whether the - // creation timestamp stored in the cookie falls within the - // window defined by (Now()-expiration, Now()]. - t = time.Unix(int64(ts), 0) - if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) { - // it's a valid cookie. now get the contents - rawValue, err := base64.URLEncoding.DecodeString(parts[0]) - if err == nil { - value = string(rawValue) - ok = true - return - } - } - } - return -} - -func signedCookieValue(seed string, key string, value string, now time.Time) string { - encodedValue := base64.URLEncoding.EncodeToString([]byte(value)) - timeStr := fmt.Sprintf("%d", now.Unix()) - sig := cookieSignature(seed, key, encodedValue, timeStr) - cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig) - return cookieVal -} - -func cookieSignature(args ...string) string { - h := hmac.New(sha1.New, []byte(args[0])) - for _, arg := range args[1:] { - h.Write([]byte(arg)) - } - var b []byte - b = h.Sum(b) - return base64.URLEncoding.EncodeToString(b) -} - -func checkHmac(input, expected string) bool { - inputMAC, err1 := base64.URLEncoding.DecodeString(input) - if err1 == nil { - expectedMAC, err2 := base64.URLEncoding.DecodeString(expected) - if err2 == nil { - return hmac.Equal(inputMAC, expectedMAC) - } - } - return false -} - -func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) { - ciphertext := make([]byte, aes.BlockSize+len(access_token)) - iv := ciphertext[:aes.BlockSize] - if _, err := io.ReadFull(rand.Reader, iv); err != nil { - return "", fmt.Errorf("failed to create access code initialization vector") - } - - stream := cipher.NewCFBEncrypter(aes_cipher, iv) - stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token)) - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) { - encrypted_access_token, err := base64.StdEncoding.DecodeString( - encoded_access_token) - - if err != nil { - return "", fmt.Errorf("failed to decode access token") - } - - if len(encrypted_access_token) < aes.BlockSize { - return "", fmt.Errorf("encrypted access token should be "+ - "at least %d bytes, but is only %d bytes", - aes.BlockSize, len(encrypted_access_token)) - } - - iv := encrypted_access_token[:aes.BlockSize] - encrypted_access_token = encrypted_access_token[aes.BlockSize:] - stream := cipher.NewCFBDecrypter(aes_cipher, iv) - stream.XORKeyStream(encrypted_access_token, encrypted_access_token) - - return string(encrypted_access_token), nil -} - -func buildCookieValue(email string, aes_cipher cipher.Block, - access_token string) (string, error) { - if aes_cipher == nil { - return email, nil - } - - encoded_token, err := encodeAccessToken(aes_cipher, access_token) - if err != nil { - return email, fmt.Errorf( - "error encoding access token for %s: %s", email, err) - } - return email + "|" + encoded_token, nil -} - -func parseCookieValue(value string, aes_cipher cipher.Block) (email, user, - access_token string, err error) { - components := strings.Split(value, "|") - email = components[0] - user = strings.Split(email, "@")[0] - - if aes_cipher != nil && len(components) == 2 { - access_token, err = decodeAccessToken(aes_cipher, components[1]) - if err != nil { - err = fmt.Errorf( - "error decoding access token for %s: %s", - email, err) - } - } - return email, user, access_token, err -} diff --git a/cookies_test.go b/cookies_test.go deleted file mode 100644 index 44696e8..0000000 --- a/cookies_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package main - -import ( - "crypto/aes" - "github.com/bmizerany/assert" - "strings" - "testing" -) - -func TestEncodeAndDecodeAccessToken(t *testing.T) { - const key = "0123456789abcdefghijklmnopqrstuv" - const access_token = "my access token" - c, err := aes.NewCipher([]byte(key)) - assert.Equal(t, nil, err) - - encoded_token, err := encodeAccessToken(c, access_token) - assert.Equal(t, nil, err) - - decoded_token, err := decodeAccessToken(c, encoded_token) - assert.Equal(t, nil, err) - - assert.NotEqual(t, access_token, encoded_token) - assert.Equal(t, access_token, decoded_token) -} - -func TestBuildCookieValueWithoutAccessToken(t *testing.T) { - value, err := buildCookieValue("michael.bland@gsa.gov", nil, "") - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", value) -} - -func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) { - value, err := buildCookieValue("michael.bland@gsa.gov", nil, - "access token") - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", value) -} - -func TestParseCookieValueWithoutAccessToken(t *testing.T) { - email, user, access_token, err := parseCookieValue( - "michael.bland@gsa.gov", nil) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) - assert.Equal(t, "michael.bland", user) - assert.Equal(t, "", access_token) -} - -func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) { - email, user, access_token, err := parseCookieValue( - "michael.bland@gsa.gov|access_token", nil) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) - assert.Equal(t, "michael.bland", user) - assert.Equal(t, "", access_token) -} - -func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) { - aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef")) - assert.Equal(t, nil, err) - value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher, - "access_token") - assert.Equal(t, nil, err) - - prefix := "michael.bland@gsa.gov|" - if !strings.HasPrefix(value, prefix) { - t.Fatal("cookie value does not start with \"%s\": %s", - prefix, value) - } - - email, user, access_token, err := parseCookieValue(value, aes_cipher) - assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) - assert.Equal(t, "michael.bland", user) - assert.Equal(t, "access_token", access_token) -} diff --git a/oauthproxy.go b/oauthproxy.go index 62084cb..43d3e52 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -1,8 +1,6 @@ package main import ( - "crypto/aes" - "crypto/cipher" "encoding/base64" "errors" "fmt" @@ -16,6 +14,7 @@ import ( "strings" "time" + "github.com/bitly/oauth2_proxy/cookie" "github.com/bitly/oauth2_proxy/providers" ) @@ -44,7 +43,7 @@ type OauthProxy struct { serveMux http.Handler PassBasicAuth bool PassAccessToken bool - AesCipher cipher.Block + CookieCipher *cookie.Cipher skipAuthRegex []string compiledRegex []*regexp.Regexp templates *template.Template @@ -116,10 +115,10 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh) - var aes_cipher cipher.Block + var cipher *cookie.Cipher if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { var err error - aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) + cipher, err = cookie.NewCipher(opts.CookieSecret) if err != nil { log.Fatal("error creating AES cipher with "+ "cookie-secret ", opts.CookieSecret, ": ", err) @@ -150,7 +149,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { compiledRegex: opts.CompiledRegex, PassBasicAuth: opts.PassBasicAuth, PassAccessToken: opts.PassAccessToken, - AesCipher: aes_cipher, + CookieCipher: cipher, templates: loadTemplates(opts.CustomTemplatesDir), } } @@ -177,22 +176,20 @@ func (p *OauthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OauthProxy) redeemCode(host, code string) (string, string, error) { +func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { if code == "" { - return "", "", errors.New("missing code") + return nil, errors.New("missing code") } redirectUri := p.GetRedirectURI(host) - body, access_token, err := p.provider.Redeem(redirectUri, code) + s, err = p.provider.Redeem(redirectUri, code) if err != nil { - return "", "", err + return } - email, err := p.provider.GetEmailAddress(body, access_token) - if err != nil { - return "", "", err + if s.Email == "" { + s.Email, err = p.provider.GetEmailAddress(s) } - - return access_token, email, nil + return } func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { @@ -208,9 +205,8 @@ func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time } if value != "" { - value = signedCookieValue(p.CookieSeed, p.CookieName, value, now) + value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) } - return &http.Cookie{ Name: p.CookieName, Value: value, @@ -230,35 +226,34 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now())) } -func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) { - var value string - var timestamp time.Time - cookie, err := req.Cookie(p.CookieName) - if err == nil { - value, timestamp, ok = validateCookie(cookie, p.CookieSeed, p.CookieExpire) - if ok { - email, user, access_token, err = parseCookieValue(value, p.AesCipher) - } - } +func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { + var age time.Duration + c, err := req.Cookie(p.CookieName) if err != nil { - log.Printf(err.Error()) - ok = false - } else if ok && p.CookieRefresh != time.Duration(0) { - refresh := timestamp.Add(p.CookieRefresh) - if refresh.Before(time.Now()) { - log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh) - ok = p.Validator(email) - log.Printf("re-validating %s valid:%v", email, ok) - if ok { - ok = p.provider.ValidateToken(access_token) - log.Printf("re-validating access token. valid:%v", ok) - } - if ok { - p.SetCookie(rw, req, value) - } - } + // always http.ErrNoCookie + return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) } - return + val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) + if !ok { + return nil, age, errors.New("Cookie Signature not valid") + } + + session, err := p.provider.SessionFromCookie(val, p.CookieCipher) + if err != nil { + return nil, age, err + } + + age = time.Now().Truncate(time.Second).Sub(timestamp) + return session, age, nil +} + +func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { + value, err := p.provider.CookieForSession(s, p.CookieCipher) + if err != nil { + return err + } + p.SetCookie(rw, req, value) + return nil } func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) { @@ -344,156 +339,226 @@ func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) { return redirect, err } -func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - // check if this is a redirect back at the end of oauth - remoteAddr := req.RemoteAddr - if req.Header.Get("X-Real-IP") != "" { - remoteAddr += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) - } - - var ok bool - var user string - var email string - var access_token string - - if req.URL.Path == p.RobotsPath { - p.RobotsTxt(rw) - return - } - - if req.URL.Path == p.PingPath { - p.PingPage(rw) - return - } - +func (p *OauthProxy) IsWhitelistedPath(path string) (ok bool) { for _, u := range p.compiledRegex { - match := u.MatchString(req.URL.Path) - if match { - p.serveMux.ServeHTTP(rw, req) - return - } - - } - - if req.URL.Path == p.SignInPath { - redirect, err := p.GetRedirect(req) - if err != nil { - p.ErrorPage(rw, 500, "Internal Error", err.Error()) - return - } - - user, ok = p.ManualSignIn(rw, req) + ok = u.MatchString(path) if ok { - p.SetCookie(rw, req, user) - http.Redirect(rw, req, redirect, 302) - } else { - p.SignInPage(rw, req, 200) + return } + } + return +} + +func getRemoteAddr(req *http.Request) (s string) { + s = req.RemoteAddr + if req.Header.Get("X-Real-IP") != "" { + s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP")) + } + return +} + +func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + switch path := req.URL.Path; { + case path == p.RobotsPath: + p.RobotsTxt(rw) + case path == p.PingPath: + p.PingPage(rw) + case p.IsWhitelistedPath(path): + p.serveMux.ServeHTTP(rw, req) + case path == p.SignInPath: + p.SignIn(rw, req) + case path == p.OauthStartPath: + p.OauthStart(rw, req) + case path == p.OauthCallbackPath: + p.OauthCallback(rw, req) + default: + p.Proxy(rw, req) + } +} + +func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { + redirect, err := p.GetRedirect(req) + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } - if req.URL.Path == p.OauthStartPath { - redirect, err := p.GetRedirect(req) - if err != nil { - p.ErrorPage(rw, 500, "Internal Error", err.Error()) - return - } - redirectURI := p.GetRedirectURI(req.Host) - http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) + + user, ok := p.ManualSignIn(rw, req) + if ok { + session := &providers.SessionState{User: user} + p.SaveSession(rw, req, session) + http.Redirect(rw, req, redirect, 302) + } else { + p.SignInPage(rw, req, 200) + } +} + +func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) { + redirect, err := p.GetRedirect(req) + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } - if req.URL.Path == p.OauthCallbackPath { - // finish the oauth cycle - err := req.ParseForm() + redirectURI := p.GetRedirectURI(req.Host) + http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) +} + +func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) { + remoteAddr := getRemoteAddr(req) + + // finish the oauth cycle + err := req.ParseForm() + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) + return + } + errorString := req.Form.Get("error") + if errorString != "" { + p.ErrorPage(rw, 403, "Permission Denied", errorString) + return + } + + session, err := p.redeemCode(req.Host, req.Form.Get("code")) + if err != nil { + log.Printf("%s error redeeming code %s", remoteAddr, err) + p.ErrorPage(rw, 500, "Internal Error", "Internal Error") + return + } + + redirect := req.Form.Get("state") + if redirect == "" { + redirect = "/" + } + + // set cookie, or deny + if p.Validator(session.Email) { + log.Printf("%s authentication complete %s", remoteAddr, session) + err := p.SaveSession(rw, req, session) if err != nil { - p.ErrorPage(rw, 500, "Internal Error", err.Error()) - return - } - errorString := req.Form.Get("error") - if errorString != "" { - p.ErrorPage(rw, 403, "Permission Denied", errorString) + log.Printf("%s %s", remoteAddr, err) + p.ErrorPage(rw, 500, "Internal Error", "Internal Error") return } + http.Redirect(rw, req, redirect, 302) + } else { + log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email) + p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") + } +} - access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code")) +func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) { + var saveSession, clearSession, revalidated bool + remoteAddr := getRemoteAddr(req) + + session, sessionAge, err := p.LoadCookiedSession(req) + if err != nil { + log.Printf("%s %s", remoteAddr, err) + } + if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { + log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh) + saveSession = true + } + + if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil { + log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session) + clearSession = true + session = nil + } else if ok { + saveSession = true + revalidated = true + } + + if session != nil && session.IsExpired() { + log.Printf("%s removing session. token expired %s", remoteAddr, session) + session = nil + saveSession = false + clearSession = true + } + + if saveSession && !revalidated && session.AccessToken != "" { + if !p.provider.ValidateSessionState(session) { + log.Printf("%s removing session. error validating %s", remoteAddr, session) + saveSession = false + session = nil + clearSession = true + } + } + + if saveSession && session.Email != "" && !p.Validator(session.Email) { + log.Printf("%s Permission Denied: removing session %s", remoteAddr, session) + session = nil + saveSession = false + clearSession = true + } + + if saveSession { + err := p.SaveSession(rw, req, session) if err != nil { - log.Printf("%s error redeeming code %s", remoteAddr, err) - p.ErrorPage(rw, 500, "Internal Error", err.Error()) - return - } - - redirect := req.Form.Get("state") - if redirect == "" { - redirect = "/" - } - - // set cookie, or deny - if p.Validator(email) { - log.Printf("%s authenticating %s completed", remoteAddr, email) - value, err := buildCookieValue( - email, p.AesCipher, access_token) - if err != nil { - log.Printf("%s", err) - } - p.SetCookie(rw, req, value) - http.Redirect(rw, req, redirect, 302) - return - } else { - log.Printf("validating: %s is unauthorized") - p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") + log.Printf("%s %s", remoteAddr, err) + p.ErrorPage(rw, 500, "Internal Error", "Internal Error") return } } - if !ok { - email, user, access_token, ok = p.ProcessCookie(rw, req) + if clearSession { + p.ClearCookie(rw, req) } - if !ok { - user, ok = p.CheckBasicAuth(req) + if session == nil { + session, err = p.CheckBasicAuth(req) + if err != nil { + log.Printf("%s %s", remoteAddr, err) + } } - if !ok { + if session == nil { p.SignInPage(rw, req, 403) return } // At this point, the user is authenticated. proxy normally if p.PassBasicAuth { - req.SetBasicAuth(user, "") - req.Header["X-Forwarded-User"] = []string{user} - req.Header["X-Forwarded-Email"] = []string{email} + req.SetBasicAuth(session.User, "") + req.Header["X-Forwarded-User"] = []string{session.User} + if session.Email != "" { + req.Header["X-Forwarded-Email"] = []string{session.Email} + } } - if p.PassAccessToken { - req.Header["X-Forwarded-Access-Token"] = []string{access_token} + if p.PassAccessToken && session.AccessToken != "" { + req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} } - if email == "" { - rw.Header().Set("GAP-Auth", user) + if session.Email == "" { + rw.Header().Set("GAP-Auth", session.User) } else { - rw.Header().Set("GAP-Auth", email) + rw.Header().Set("GAP-Auth", session.Email) } p.serveMux.ServeHTTP(rw, req) } -func (p *OauthProxy) CheckBasicAuth(req *http.Request) (string, bool) { +func (p *OauthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { if p.HtpasswdFile == nil { - return "", false + return nil, nil } - s := strings.SplitN(req.Header.Get("Authorization"), " ", 2) + auth := req.Header.Get("Authorization") + if auth == "" { + return nil, nil + } + s := strings.SplitN(auth, " ", 2) if len(s) != 2 || s[0] != "Basic" { - return "", false + return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization")) } b, err := base64.StdEncoding.DecodeString(s[1]) if err != nil { - return "", false + return nil, err } pair := strings.SplitN(string(b), ":", 2) if len(pair) != 2 { - return "", false + return nil, fmt.Errorf("invalid format %s", b) } if p.HtpasswdFile.Validate(pair[0], pair[1]) { log.Printf("authenticated %q via basic auth", pair[0]) - return pair[0], true + return &providers.SessionState{User: pair[0]}, nil } - return "", false + return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index ed02b88..5b5a935 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -94,11 +94,11 @@ type TestProvider struct { ValidToken bool } -func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { +func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { return tp.EmailAddress, nil } -func (tp *TestProvider) ValidateToken(access_token string) bool { +func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { return tp.ValidToken } @@ -378,97 +378,73 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { }) } -func (p *ProcessCookieTest) MakeCookie(value, access_token string, ref time.Time) *http.Cookie { - cookie_value, _ := buildCookieValue(value, p.proxy.AesCipher, access_token) - return p.proxy.MakeCookie(p.req, cookie_value, p.opts.CookieExpire, ref) +func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie { + return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref) } -func (p *ProcessCookieTest) AddCookie(value, access_token string) { - p.req.AddCookie(p.MakeCookie(value, access_token, time.Now())) +func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { + value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) + if err != nil { + return err + } + p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref)) + return nil } -func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, ok bool) { - return p.proxy.ProcessCookie(p.rw, p.req) +func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { + return p.proxy.LoadCookiedSession(p.req) } -func TestProcessCookie(t *testing.T) { +func TestLoadCookiedSession(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() - pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token") - email, user, access_token, ok := pc_test.ProcessCookie() - assert.Equal(t, true, ok) - assert.Equal(t, "michael.bland@gsa.gov", email) - assert.Equal(t, "michael.bland", user) - assert.Equal(t, "my_access_token", access_token) + startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + pc_test.SaveSession(startSession, time.Now()) + + session, _, err := pc_test.LoadCookiedSession() + assert.Equal(t, nil, err) + assert.Equal(t, startSession.Email, session.Email) + assert.Equal(t, "michael.bland", session.User) + assert.Equal(t, startSession.AccessToken, session.AccessToken) } func TestProcessCookieNoCookieError(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, false, ok) -} -func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - value, _ := buildCookieValue("michael.bland@gsa.gov", - pc_test.proxy.AesCipher, "my_access_token") - pc_test.req.AddCookie(pc_test.proxy.MakeCookie( - pc_test.req, value+"some bogus bytes", - pc_test.opts.CookieExpire, time.Now())) - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, false, ok) + session, _, err := pc_test.LoadCookiedSession() + assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) + if session != nil { + t.Errorf("expected nil session. got %#v", session) + } } func TestProcessCookieRefreshNotSet(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour reference := time.Now().Add(time.Duration(-2) * time.Hour) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "", reference) - pc_test.req.AddCookie(cookie) - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, true, ok) - assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) -} + startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + pc_test.SaveSession(startSession, reference) -func TestProcessCookieRefresh(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour - reference := time.Now().Add(time.Duration(-2) * time.Hour) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) - pc_test.req.AddCookie(cookie) - - pc_test.proxy.CookieRefresh = time.Hour - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, true, ok) - assert.NotEqual(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) -} - -func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour - reference := time.Now().Add(time.Duration(-30) * time.Minute) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) - pc_test.req.AddCookie(cookie) - - pc_test.proxy.CookieRefresh = time.Hour - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, true, ok) - assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) + session, age, err := pc_test.LoadCookiedSession() + assert.Equal(t, nil, err) + if age < time.Duration(-2)*time.Hour { + t.Errorf("cookie too young %v", age) + } + assert.Equal(t, startSession.Email, session.Email) } func TestProcessCookieFailIfCookieExpired(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) - pc_test.req.AddCookie(cookie) + startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + pc_test.SaveSession(startSession, reference) - if _, _, _, ok := pc_test.ProcessCookie(); ok { - t.Error("ProcessCookie() should have failed") - } - if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil { - t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie) + session, _, err := pc_test.LoadCookiedSession() + assert.NotEqual(t, nil, err) + if session != nil { + t.Errorf("expected nil session %#v", session) } } @@ -476,44 +452,13 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) - pc_test.req.AddCookie(cookie) + startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + pc_test.SaveSession(startSession, reference) pc_test.proxy.CookieRefresh = time.Hour - if _, _, _, ok := pc_test.ProcessCookie(); ok { - t.Error("ProcessCookie() should have failed") - } - if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil { - t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie) + session, _, err := pc_test.LoadCookiedSession() + assert.NotEqual(t, nil, err) + if session != nil { + t.Errorf("expected nil session %#v", session) } } - -func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) { - pc_test := NewProcessCookieTest(ProcessCookieTestOpts{ - provider_validate_cookie_response: false, - }) - pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour - reference := time.Now().Add(time.Duration(-24) * time.Hour) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) - pc_test.req.AddCookie(cookie) - - pc_test.proxy.CookieRefresh = time.Hour - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, false, ok) - assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) -} - -func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.validate_user = false - - pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour - reference := time.Now().Add(time.Duration(-2) * time.Hour) - cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) - pc_test.req.AddCookie(cookie) - - pc_test.proxy.CookieRefresh = time.Hour - _, _, _, ok := pc_test.ProcessCookie() - assert.Equal(t, false, ok) - assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"]) -} diff --git a/providers/github.go b/providers/github.go index b138af6..4f2a988 100644 --- a/providers/github.go +++ b/providers/github.go @@ -2,8 +2,10 @@ package providers import ( "encoding/json" + "errors" "fmt" "io/ioutil" + "log" "net/http" "net/url" ) @@ -138,7 +140,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { return false, nil } -func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) { +func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { var emails []struct { Email string `json:"email"` @@ -148,31 +150,34 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri // if we require an Org or Team, check that first if p.Org != "" { if p.Team != "" { - if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok { + if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { return "", err } } else { - if ok, err := p.hasOrg(access_token); err != nil || !ok { + if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { return "", err } } } params := url.Values{ - "access_token": {access_token}, + "access_token": {s.AccessToken}, } endpoint := "https://api.github.com/user/emails?" + params.Encode() resp, err := http.DefaultClient.Get(endpoint) if err != nil { return "", err } - body, err = ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { return "", err } + if resp.StatusCode != 200 { return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body) + } else { + log.Printf("got %d from %q %s", resp.StatusCode, endpoint, body) } if err := json.Unmarshal(body, &emails); err != nil { @@ -185,9 +190,5 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri } } - return "", nil -} - -func (p *GitHubProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, nil) + return "", errors.New("no email address found") } diff --git a/providers/google.go b/providers/google.go index 40a6228..8c0a0cc 100644 --- a/providers/google.go +++ b/providers/google.go @@ -7,9 +7,11 @@ import ( "errors" "fmt" "io/ioutil" + "log" "net/http" "net/url" "strings" + "time" ) type GoogleProvider struct { @@ -43,25 +45,19 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { return &GoogleProvider{ProviderData: p} } -func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) { - var response struct { - IdToken string `json:"id_token"` - } - - if err := json.Unmarshal(body, &response); err != nil { - return "", err - } +func emailFromIdToken(idToken string) (string, error) { // id_token is a base64 encode ID token payload // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo - jwt := strings.Split(response.IdToken, ".") + jwt := strings.Split(idToken, ".") b, err := jwtDecodeSegment(jwt[1]) if err != nil { return "", err } var email struct { - Email string `json:"email"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` } err = json.Unmarshal(b, &email) if err != nil { @@ -70,6 +66,9 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri if email.Email == "" { return "", errors.New("missing email") } + if !email.EmailVerified { + return "", fmt.Errorf("email %s not listed as verified", email.Email) + } return email.Email, nil } @@ -81,11 +80,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) { return base64.URLEncoding.DecodeString(seg) } -func (p *GoogleProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, nil) -} - -func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) { +func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -108,6 +103,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st if err != nil { return } + var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { @@ -122,17 +118,44 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st var jsonResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + IdToken string `json:"id_token"` } err = json.Unmarshal(body, &jsonResponse) if err != nil { return } - - token, err = p.redeemRefreshToken(jsonResponse.RefreshToken) + var email string + email, err = emailFromIdToken(jsonResponse.IdToken) + if err != nil { + return + } + s = &SessionState{ + AccessToken: jsonResponse.AccessToken, + ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), + RefreshToken: jsonResponse.RefreshToken, + Email: email, + } return } -func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, err error) { +func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { + if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + + newToken, duration, err := p.redeemRefreshToken(s.RefreshToken) + if err != nil { + return false, err + } + origExpiration := s.ExpiresOn + s.AccessToken = newToken + s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) + log.Printf("refreshed access token %s (expired on %s)", s, origExpiration) + return true, nil +} + +func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) { // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh params := url.Values{} params.Add("client_id", p.ClientID) @@ -162,12 +185,15 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, return } - var jsonResponse struct { + var data struct { AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` } - err = json.Unmarshal(body, &jsonResponse) + err = json.Unmarshal(body, &data) if err != nil { return } - return jsonResponse.AccessToken, nil + token = data.AccessToken + expires = time.Duration(data.ExpiresIn) * time.Second + return } diff --git a/providers/google_test.go b/providers/google_test.go index 7551640..0da80f4 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -3,11 +3,22 @@ package providers import ( "encoding/base64" "encoding/json" - "github.com/bmizerany/assert" + "net/http" + "net/http/httptest" "net/url" "testing" + + "github.com/bmizerany/assert" ) +func newRedeemServer(body []byte) (*url.URL, *httptest.Server) { + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write(body) + })) + u, _ := url.Parse(s.URL) + return u, s +} + func newGoogleProvider() *GoogleProvider { return NewGoogleProvider( &ProviderData{ @@ -66,63 +77,88 @@ func TestGoogleProviderOverrides(t *testing.T) { assert.Equal(t, "profile", p.Data().Scope) } -func TestGoogleProviderGetEmailAddress(t *testing.T) { - p := newGoogleProvider() - body, err := json.Marshal( - struct { - IdToken string `json:"id_token"` - }{ - IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)), - }, - ) - assert.Equal(t, nil, err) - email, err := p.GetEmailAddress(body, "ignored access_token") - assert.Equal(t, "michael.bland@gsa.gov", email) - assert.Equal(t, nil, err) +type redeemResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + IdToken string `json:"id_token"` } +func TestGoogleProviderGetEmailAddress(t *testing.T) { + p := newGoogleProvider() + body, err := json.Marshal(redeemResponse{ + AccessToken: "a1234", + ExpiresIn: 10, + RefreshToken: "refresh12345", + IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), + }) + assert.Equal(t, nil, err) + var server *httptest.Server + p.RedeemUrl, server = newRedeemServer(body) + defer server.Close() + + session, err := p.Redeem("http://redirect/", "code1234") + assert.Equal(t, nil, err) + assert.NotEqual(t, session, nil) + assert.Equal(t, "michael.bland@gsa.gov", session.Email) + assert.Equal(t, "a1234", session.AccessToken) + assert.Equal(t, "refresh12345", session.RefreshToken) +} + +// func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p := newGoogleProvider() - body, err := json.Marshal( - struct { - IdToken string `json:"id_token"` - }{ - IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, - }, - ) + body, err := json.Marshal(redeemResponse{ + AccessToken: "a1234", + IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, + }) assert.Equal(t, nil, err) - email, err := p.GetEmailAddress(body, "ignored access_token") - assert.Equal(t, "", email) + var server *httptest.Server + p.RedeemUrl, server = newRedeemServer(body) + defer server.Close() + + session, err := p.Redeem("http://redirect/", "code1234") assert.NotEqual(t, nil, err) + if session != nil { + t.Errorf("expect nill session %#v", session) + } } func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { p := newGoogleProvider() - body, err := json.Marshal( - struct { - IdToken string `json:"id_token"` - }{ - IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), - }, - ) + body, err := json.Marshal(redeemResponse{ + AccessToken: "a1234", + IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), + }) assert.Equal(t, nil, err) - email, err := p.GetEmailAddress(body, "ignored access_token") - assert.Equal(t, "", email) + var server *httptest.Server + p.RedeemUrl, server = newRedeemServer(body) + defer server.Close() + + session, err := p.Redeem("http://redirect/", "code1234") assert.NotEqual(t, nil, err) + if session != nil { + t.Errorf("expect nill session %#v", session) + } + } func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { p := newGoogleProvider() - body, err := json.Marshal( - struct { - IdToken string `json:"id_token"` - }{ - IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), - }, - ) + body, err := json.Marshal(redeemResponse{ + AccessToken: "a1234", + IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), + }) assert.Equal(t, nil, err) - email, err := p.GetEmailAddress(body, "ignored access_token") - assert.Equal(t, "", email) + var server *httptest.Server + p.RedeemUrl, server = newRedeemServer(body) + defer server.Close() + + session, err := p.Redeem("http://redirect/", "code1234") assert.NotEqual(t, nil, err) + if session != nil { + t.Errorf("expect nill session %#v", session) + } + } diff --git a/providers/internal_util.go b/providers/internal_util.go index 4ccd037..ff0cafa 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -9,6 +9,7 @@ import ( "github.com/bitly/oauth2_proxy/api" ) +// validateToken returns true if token is valid func validateToken(p Provider, access_token string, header http.Header) bool { if access_token == "" || p.Data().ValidateUrl == nil { return false @@ -20,12 +21,15 @@ func validateToken(p Provider, access_token string, header http.Header) bool { } resp, err := api.RequestUnparsedResponse(endpoint, header) if err != nil { + log.Printf("GET %s", endpoint) log.Printf("token validation request failed: %s", err) return false } body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() + log.Printf("%d GET %s %s", resp.StatusCode, endpoint, body) + if resp.StatusCode == 200 { return true } diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 36a1d37..bace76d 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -1,36 +1,38 @@ package providers import ( - "github.com/bmizerany/assert" + "errors" "net/http" "net/http/httptest" "net/url" "testing" + + "github.com/bmizerany/assert" ) -type ValidateTokenTestProvider struct { +type ValidateSessionStateTestProvider struct { *ProviderData } -func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { - return "", nil +func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { + return "", errors.New("not implemented") } // Note that we're testing the internal validateToken() used to implement -// several Provider's ValidateToken() implementations -func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool { +// several Provider's ValidateSessionState() implementations +func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { return false } -type ValidateTokenTest struct { +type ValidateSessionStateTest struct { backend *httptest.Server response_code int - provider *ValidateTokenTestProvider + provider *ValidateSessionStateTestProvider header http.Header } -func NewValidateTokenTest() *ValidateTokenTest { - var vt_test ValidateTokenTest +func NewValidateSessionStateTest() *ValidateSessionStateTest { + var vt_test ValidateSessionStateTest vt_test.backend = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -59,7 +61,7 @@ func NewValidateTokenTest() *ValidateTokenTest { })) backend_url, _ := url.Parse(vt_test.backend.URL) - vt_test.provider = &ValidateTokenTestProvider{ + vt_test.provider = &ValidateSessionStateTestProvider{ ProviderData: &ProviderData{ ValidateUrl: &url.URL{ Scheme: "http", @@ -72,18 +74,18 @@ func NewValidateTokenTest() *ValidateTokenTest { return &vt_test } -func (vt_test *ValidateTokenTest) Close() { +func (vt_test *ValidateSessionStateTest) Close() { vt_test.backend.Close() } -func TestValidateTokenValidToken(t *testing.T) { - vt_test := NewValidateTokenTest() +func TestValidateSessionStateValidToken(t *testing.T) { + vt_test := NewValidateSessionStateTest() defer vt_test.Close() assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) } -func TestValidateTokenValidTokenWithHeaders(t *testing.T) { - vt_test := NewValidateTokenTest() +func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { + vt_test := NewValidateSessionStateTest() defer vt_test.Close() vt_test.header = make(http.Header) vt_test.header.Set("Authorization", "Bearer foobar") @@ -91,28 +93,28 @@ func TestValidateTokenValidTokenWithHeaders(t *testing.T) { validateToken(vt_test.provider, "foobar", vt_test.header)) } -func TestValidateTokenEmptyToken(t *testing.T) { - vt_test := NewValidateTokenTest() +func TestValidateSessionStateEmptyToken(t *testing.T) { + vt_test := NewValidateSessionStateTest() defer vt_test.Close() assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) } -func TestValidateTokenEmptyValidateUrl(t *testing.T) { - vt_test := NewValidateTokenTest() +func TestValidateSessionStateEmptyValidateUrl(t *testing.T) { + vt_test := NewValidateSessionStateTest() defer vt_test.Close() vt_test.provider.Data().ValidateUrl = nil assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) } -func TestValidateTokenRequestNetworkFailure(t *testing.T) { - vt_test := NewValidateTokenTest() +func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { + vt_test := NewValidateSessionStateTest() // Close immediately to simulate a network failure vt_test.Close() assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) } -func TestValidateTokenExpiredToken(t *testing.T) { - vt_test := NewValidateTokenTest() +func TestValidateSessionStateExpiredToken(t *testing.T) { + vt_test := NewValidateSessionStateTest() defer vt_test.Close() vt_test.response_code = 401 assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) diff --git a/providers/linkedin.go b/providers/linkedin.go index 6249ec4..78ad3c9 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -1,7 +1,6 @@ package providers import ( - "bytes" "errors" "fmt" "log" @@ -49,16 +48,15 @@ func getLinkedInHeader(access_token string) http.Header { return header } -func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) { - if access_token == "" { +func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { + if s.AccessToken == "" { return "", errors.New("missing access token") } - params := url.Values{} - req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", bytes.NewBufferString(params.Encode())) + req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", nil) if err != nil { return "", err } - req.Header = getLinkedInHeader(access_token) + req.Header = getLinkedInHeader(s.AccessToken) json, err := api.Request(req) if err != nil { @@ -74,6 +72,6 @@ func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (st return email, nil } -func (p *LinkedInProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, getLinkedInHeader(access_token)) +func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { + return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) } diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index 08b3e47..c75a4a8 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -97,8 +97,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) - email, err := p.GetEmailAddress([]byte{}, - "imaginary_access_token") + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@linkedin.com", email) } @@ -113,7 +113,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") + session := &SessionState{AccessToken: "unexpected_access_token"} + email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -125,7 +126,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testLinkedInProvider(b_url.Host) - email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/myusa.go b/providers/myusa.go index 7072639..c244ed0 100644 --- a/providers/myusa.go +++ b/providers/myusa.go @@ -42,9 +42,9 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider { return &MyUsaProvider{ProviderData: p} } -func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) { +func (p *MyUsaProvider) GetEmailAddress(s *SessionState) (string, error) { req, err := http.NewRequest("GET", - p.ProfileUrl.String()+"?access_token="+access_token, nil) + p.ProfileUrl.String()+"?access_token="+s.AccessToken, nil) if err != nil { log.Printf("failed building request %s", err) return "", err @@ -56,7 +56,3 @@ func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (strin } return json.Get("email").String() } - -func (p *MyUsaProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, nil) -} diff --git a/providers/myusa_test.go b/providers/myusa_test.go index 32e8520..b4bdb30 100644 --- a/providers/myusa_test.go +++ b/providers/myusa_test.go @@ -1,11 +1,12 @@ package providers import ( - "github.com/bmizerany/assert" "net/http" "net/http/httptest" "net/url" "testing" + + "github.com/bmizerany/assert" ) func updateUrl(url *url.URL, hostname string) { @@ -102,7 +103,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testMyUsaProvider(b_url.Host) - email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) } @@ -119,7 +121,8 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") + session := &SessionState{AccessToken: "unexpected_access_token"} + email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } @@ -131,7 +134,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b_url, _ := url.Parse(b.URL) p := testMyUsaProvider(b_url.Host) - email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) } diff --git a/providers/provider_default.go b/providers/provider_default.go index edf7338..b18212f 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -9,9 +9,11 @@ import ( "net/http" "net/url" "strings" + + "github.com/bitly/oauth2_proxy/cookie" ) -func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) { +func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -23,24 +25,28 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri params.Add("client_secret", p.ClientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") - req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) if err != nil { - return nil, "", err + return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp, err := http.DefaultClient.Do(req) + var resp *http.Response + resp, err = http.DefaultClient.Do(req) if err != nil { - return nil, "", err + return nil, err } + var body []byte body, err = ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return nil, "", err + return } if resp.StatusCode != 200 { - return body, "", fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + return } // blindly try json and x-www-form-urlencoded @@ -49,11 +55,23 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri } err = json.Unmarshal(body, &jsonResponse) if err == nil { - return body, jsonResponse.AccessToken, nil + s = &SessionState{ + AccessToken: jsonResponse.AccessToken, + } + return } - v, err := url.ParseQuery(string(body)) - return body, v.Get("access_token"), err + var v url.Values + v, err = url.ParseQuery(string(body)) + if err != nil { + return + } + if a := v.Get("access_token"); a != "" { + s = &SessionState{AccessToken: a} + } else { + err = fmt.Errorf("no access token found %s", body) + } + return } // GetLoginURL with typical oauth parameters @@ -72,3 +90,26 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { a.RawQuery = params.Encode() return a.String() } + +// CookieForSession serializes a session state for storage in a cookie +func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { + return s.EncodeSessionState(c) +} + +// SessionFromCookie deserializes a session from a cookie value +func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { + return DecodeSessionState(v, c) +} + +func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { + return "", errors.New("not implemented") +} + +func (p *ProviderData) ValidateSessionState(s *SessionState) bool { + return validateToken(p, s.AccessToken, nil) +} + +// RefreshSessionIfNeeded +func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { + return false, nil +} diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go new file mode 100644 index 0000000..e60aa54 --- /dev/null +++ b/providers/provider_default_test.go @@ -0,0 +1,17 @@ +package providers + +import ( + "testing" + "time" + + "github.com/bmizerany/assert" +) + +func TestRefresh(t *testing.T) { + p := &ProviderData{} + refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ + ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), + }) + assert.Equal(t, false, refreshed) + assert.Equal(t, nil, err) +} diff --git a/providers/providers.go b/providers/providers.go index b7e84eb..3192011 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -1,11 +1,18 @@ package providers +import ( + "github.com/bitly/oauth2_proxy/cookie" +) + type Provider interface { Data() *ProviderData - GetEmailAddress(body []byte, access_token string) (string, error) - Redeem(string, string) ([]byte, string, error) - ValidateToken(access_token string) bool + GetEmailAddress(*SessionState) (string, error) + Redeem(string, string) (*SessionState, error) + ValidateSessionState(*SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string + RefreshSessionIfNeeded(*SessionState) (bool, error) + SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) + CookieForSession(*SessionState, *cookie.Cipher) (string, error) } func New(provider string, p *ProviderData) Provider { diff --git a/providers/session_state.go b/providers/session_state.go new file mode 100644 index 0000000..214b5a4 --- /dev/null +++ b/providers/session_state.go @@ -0,0 +1,115 @@ +package providers + +import ( + "fmt" + "strconv" + "strings" + "time" + + "github.com/bitly/oauth2_proxy/cookie" +) + +type SessionState struct { + AccessToken string + ExpiresOn time.Time + RefreshToken string + Email string + User string +} + +func (s *SessionState) IsExpired() bool { + if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) { + return true + } + return false +} + +func (s *SessionState) String() string { + o := fmt.Sprintf("Session{%s", s.userOrEmail()) + if s.AccessToken != "" { + o += " token:true" + } + if !s.ExpiresOn.IsZero() { + o += fmt.Sprintf(" expires:%s", s.ExpiresOn) + } + if s.RefreshToken != "" { + o += " refresh_token:true" + } + return o + "}" +} + +func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { + if c == nil || s.AccessToken == "" { + return s.userOrEmail(), nil + } + return s.EncryptedString(c) +} + +func (s *SessionState) userOrEmail() string { + u := s.User + if s.Email != "" { + u = s.Email + } + return u +} + +func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { + var err error + if c == nil { + panic("error. missing cipher") + } + a := s.AccessToken + if a != "" { + a, err = c.Encrypt(a) + if err != nil { + return "", err + } + } + r := s.RefreshToken + if r != "" { + r, err = c.Encrypt(r) + if err != nil { + return "", err + } + } + return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil +} + +func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { + chunks := strings.Split(v, "|") + if len(chunks) == 1 { + if strings.Contains(chunks[0], "@") { + u := strings.Split(v, "@")[0] + return &SessionState{Email: v, User: u}, nil + } + return &SessionState{User: v}, nil + } + + if len(chunks) != 4 { + err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) + return + } + + s = &SessionState{} + if c != nil && chunks[1] != "" { + s.AccessToken, err = c.Decrypt(chunks[1]) + if err != nil { + return nil, err + } + } + if c != nil && chunks[3] != "" { + s.RefreshToken, err = c.Decrypt(chunks[3]) + if err != nil { + return nil, err + } + } + if u := chunks[0]; strings.Contains(u, "@") { + s.Email = u + s.User = strings.Split(u, "@")[0] + } else { + s.User = u + } + ts, _ := strconv.Atoi(chunks[2]) + s.ExpiresOn = time.Unix(int64(ts), 0) + return +} diff --git a/providers/session_state_test.go b/providers/session_state_test.go new file mode 100644 index 0000000..ba8de5d --- /dev/null +++ b/providers/session_state_test.go @@ -0,0 +1,88 @@ +package providers + +import ( + "strings" + "testing" + "time" + + "github.com/bitly/oauth2_proxy/cookie" + "github.com/bmizerany/assert" +) + +const secret = "0123456789abcdefghijklmnopqrstuv" +const altSecret = "0000000000abcdefghijklmnopqrstuv" + +func TestSessionStateSerialization(t *testing.T) { + c, err := cookie.NewCipher(secret) + assert.Equal(t, nil, err) + c2, err := cookie.NewCipher(altSecret) + assert.Equal(t, nil, err) + s := &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh4321", + } + encoded, err := s.EncodeSessionState(c) + assert.Equal(t, nil, err) + assert.Equal(t, 3, strings.Count(encoded, "|")) + + ss, err := DecodeSessionState(encoded, c) + t.Logf("%#v", ss) + assert.Equal(t, nil, err) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + assert.Equal(t, s.RefreshToken, ss.RefreshToken) + + // ensure a different cipher can't decode properly (ie: it gets gibberish) + ss, err = DecodeSessionState(encoded, c2) + t.Logf("%#v", ss) + assert.Equal(t, nil, err) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + assert.NotEqual(t, s.AccessToken, ss.AccessToken) + assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) +} + +func TestSessionStateSerializationNoCipher(t *testing.T) { + + s := &SessionState{ + Email: "user@domain.com", + AccessToken: "token1234", + ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh4321", + } + encoded, err := s.EncodeSessionState(nil) + assert.Equal(t, nil, err) + assert.Equal(t, s.Email, encoded) + + // only email should have been serialized + ss, err := DecodeSessionState(encoded, nil) + assert.Equal(t, nil, err) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, "", ss.AccessToken) + assert.Equal(t, "", ss.RefreshToken) +} + +func TestSessionStateUserOrEmail(t *testing.T) { + + s := &SessionState{ + Email: "user@domain.com", + User: "just-user", + } + assert.Equal(t, "user@domain.com", s.userOrEmail()) + s.Email = "" + assert.Equal(t, "just-user", s.userOrEmail()) +} + +func TestExpired(t *testing.T) { + s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} + assert.Equal(t, true, s.IsExpired()) + + s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} + assert.Equal(t, false, s.IsExpired()) + + s = &SessionState{} + assert.Equal(t, false, s.IsExpired()) +}