Add CreatedAt to SessionState

This commit is contained in:
Joel Speed 2019-05-07 15:32:46 +01:00
parent fbee5eae16
commit 34cbe0497c
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
7 changed files with 57 additions and 8 deletions

View File

@ -14,6 +14,7 @@ import (
type SessionState struct { type SessionState struct {
AccessToken string `json:",omitempty"` AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"` IDToken string `json:",omitempty"`
CreatedAt time.Time `json:"-"`
ExpiresOn time.Time `json:"-"` ExpiresOn time.Time `json:"-"`
RefreshToken string `json:",omitempty"` RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"` Email string `json:",omitempty"`
@ -23,6 +24,7 @@ type SessionState struct {
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct { type SessionStateJSON struct {
*SessionState *SessionState
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"` ExpiresOn *time.Time `json:",omitempty"`
} }
@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool {
return false return false
} }
// Age returns the age of a session
func (s *SessionState) Age() time.Duration {
if !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
}
return 0
}
// String constructs a summary of the session state // String constructs a summary of the session state
func (s *SessionState) String() string { func (s *SessionState) String() string {
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
@ -43,6 +53,9 @@ func (s *SessionState) String() string {
if s.IDToken != "" { if s.IDToken != "" {
o += " id_token:true" o += " id_token:true"
} }
if !s.CreatedAt.IsZero() {
o += fmt.Sprintf(" created:%s", s.CreatedAt)
}
if !s.ExpiresOn.IsZero() { if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn) o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
} }
@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
} }
// Embed SessionState and ExpiresOn pointer into SessionStateJSON // Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss} ssj := &SessionStateJSON{SessionState: &ss}
if !ss.CreatedAt.IsZero() {
ssj.CreatedAt = &ss.CreatedAt
}
if !ss.ExpiresOn.IsZero() { if !ss.ExpiresOn.IsZero() {
ssj.ExpiresOn = &ss.ExpiresOn ssj.ExpiresOn = &ss.ExpiresOn
} }
@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
var ss *SessionState var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj) err := json.Unmarshal([]byte(v), &ssj)
if err == nil && ssj.SessionState != nil { if err == nil && ssj.SessionState != nil {
// Extract SessionState and ExpiresOn value from SessionStateJSON // Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
ss = ssj.SessionState ss = ssj.SessionState
if ssj.CreatedAt != nil {
ss.CreatedAt = *ssj.CreatedAt
}
if ssj.ExpiresOn != nil { if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn ss.ExpiresOn = *ssj.ExpiresOn
} }

View File

@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) {
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
IDToken: "rawtoken1234", IDToken: "rawtoken1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
} }
@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IDToken, ss.IDToken) assert.Equal(t, s.IDToken, ss.IDToken)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, "user@domain.com", ss.User) assert.NotEqual(t, "user@domain.com", ss.User)
assert.NotEqual(t, s.Email, ss.Email) assert.NotEqual(t, s.Email, ss.Email)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.IDToken, ss.IDToken) assert.NotEqual(t, s.IDToken, ss.IDToken)
@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
} }
@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.User, ss.User) assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, s.User, ss.User) assert.NotEqual(t, s.User, ss.User)
assert.NotEqual(t, s.Email, ss.Email) assert.NotEqual(t, s.Email, ss.Email)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &sessions.SessionState{ s := &sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
} }
@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
} }
@ -147,6 +155,7 @@ type testCase struct {
// Currently only tests without cipher here because we have no way to mock // Currently only tests without cipher here because we have no way to mock
// the random generator used in EncodeSessionState. // the random generator used in EncodeSessionState.
func TestEncodeSessionState(t *testing.T) { func TestEncodeSessionState(t *testing.T) {
c := time.Now()
e := time.Now().Add(time.Duration(1) * time.Hour) e := time.Now().Add(time.Duration(1) * time.Hour)
testCases := []testCase{ testCases := []testCase{
@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
IDToken: "rawtoken1234", IDToken: "rawtoken1234",
CreatedAt: c,
ExpiresOn: e, ExpiresOn: e,
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
}, },
@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) {
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) { func TestDecodeSessionState(t *testing.T) {
created := time.Now()
createdJSON, _ := created.MarshalJSON()
createdString := string(createdJSON)
e := time.Now().Add(time.Duration(1) * time.Hour) e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON() eJSON, _ := e.MarshalJSON()
eString := string(eJSON) eString := string(eJSON)
@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) {
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
}, },
{ {
SessionState: sessions.SessionState{ SessionState: sessions.SessionState{
@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) {
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
IDToken: "rawtoken1234", IDToken: "rawtoken1234",
CreatedAt: created,
ExpiresOn: e, ExpiresOn: e,
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
}, },
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
Cipher: c, Cipher: c,
}, },
{ {
@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) {
} }
} }
} }
func TestSessionStateAge(t *testing.T) {
ss := &sessions.SessionState{}
// Created at unset so should be 0
assert.Equal(t, time.Duration(0), ss.Age())
// Set CreatedAt to 1 hour ago
ss.CreatedAt = time.Now().Add(-1 * time.Hour)
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
}

View File

@ -78,7 +78,7 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
for _, c := range req.Cookies() { for _, c := range req.Cookies() {
if cookieNameRegex.MatchString(c.Name) { if cookieNameRegex.MatchString(c.Name) {
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clearCookie) http.SetCookie(rw, clearCookie)
cookies = append(cookies, clearCookie) cookies = append(cookies, clearCookie)
@ -101,14 +101,14 @@ func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expira
if value != "" { if value != "" {
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
} }
c := s.makeCookie(req, s.CookieName, value, expiration) c := s.makeCookie(req, s.CookieName, value, expiration, now)
if len(c.Value) > 4096-len(s.CookieName) { if len(c.Value) > 4096-len(s.CookieName) {
return splitCookie(c) return splitCookie(c)
} }
return []*http.Cookie{c} return []*http.Cookie{c}
} }
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cookies.MakeCookie( return cookies.MakeCookie(
req, req,
name, name,
@ -118,7 +118,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string,
s.CookieHTTPOnly, s.CookieHTTPOnly,
s.CookieSecure, s.CookieSecure,
expiration, expiration,
time.Now(), now,
) )
} }

View File

@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
s = &sessions.SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
RefreshToken: jsonResponse.RefreshToken, RefreshToken: jsonResponse.RefreshToken,
Email: c.Email, Email: c.Email,

View File

@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
s = &sessions.SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
Email: email, Email: email,
} }

View File

@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
s.AccessToken = newSession.AccessToken s.AccessToken = newSession.AccessToken
s.IDToken = newSession.IDToken s.IDToken = newSession.IDToken
s.RefreshToken = newSession.RefreshToken s.RefreshToken = newSession.RefreshToken
s.CreatedAt = newSession.CreatedAt
s.ExpiresOn = newSession.ExpiresOn s.ExpiresOn = newSession.ExpiresOn
s.Email = newSession.Email s.Email = newSession.Email
return return
@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
IDToken: rawIDToken, IDToken: rawIDToken,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
CreatedAt: time.Now(),
ExpiresOn: token.Expiry, ExpiresOn: token.Expiry,
Email: claims.Email, Email: claims.Email,
User: claims.Subject, User: claims.Subject,

View File

@ -8,6 +8,7 @@ import (
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"time"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
return return
} }
if a := v.Get("access_token"); a != "" { if a := v.Get("access_token"); a != "" {
s = &sessions.SessionState{AccessToken: a} s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
} else { } else {
err = fmt.Errorf("no access token found %s", body) err = fmt.Errorf("no access token found %s", body)
} }