Add CreatedAt to SessionState

This commit is contained in:
Joel Speed 2019-05-07 15:32:46 +01:00
parent fbee5eae16
commit 34cbe0497c
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
7 changed files with 57 additions and 8 deletions

View File

@ -14,6 +14,7 @@ import (
type SessionState struct {
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
CreatedAt time.Time `json:"-"`
ExpiresOn time.Time `json:"-"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
@ -23,6 +24,7 @@ type SessionState struct {
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct {
*SessionState
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
}
@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool {
return false
}
// Age returns the age of a session
func (s *SessionState) Age() time.Duration {
if !s.CreatedAt.IsZero() {
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
}
return 0
}
// String constructs a summary of the session state
func (s *SessionState) String() string {
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
@ -43,6 +53,9 @@ func (s *SessionState) String() string {
if s.IDToken != "" {
o += " id_token:true"
}
if !s.CreatedAt.IsZero() {
o += fmt.Sprintf(" created:%s", s.CreatedAt)
}
if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
}
@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
}
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss}
if !ss.CreatedAt.IsZero() {
ssj.CreatedAt = &ss.CreatedAt
}
if !ss.ExpiresOn.IsZero() {
ssj.ExpiresOn = &ss.ExpiresOn
}
@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
if err == nil && ssj.SessionState != nil {
// Extract SessionState and ExpiresOn value from SessionStateJSON
// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
ss = ssj.SessionState
if ssj.CreatedAt != nil {
ss.CreatedAt = *ssj.CreatedAt
}
if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn
}

View File

@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) {
Email: "user@domain.com",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IDToken, ss.IDToken)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, nil, err)
assert.NotEqual(t, "user@domain.com", ss.User)
assert.NotEqual(t, s.Email, ss.Email)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.IDToken, ss.IDToken)
@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, nil, err)
assert.NotEqual(t, s.User, ss.User)
assert.NotEqual(t, s.Email, ss.Email)
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &sessions.SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
@ -147,6 +155,7 @@ type testCase struct {
// Currently only tests without cipher here because we have no way to mock
// the random generator used in EncodeSessionState.
func TestEncodeSessionState(t *testing.T) {
c := time.Now()
e := time.Now().Add(time.Duration(1) * time.Hour)
testCases := []testCase{
@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: c,
ExpiresOn: e,
RefreshToken: "refresh4321",
},
@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) {
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) {
created := time.Now()
createdJSON, _ := created.MarshalJSON()
createdString := string(createdJSON)
e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON()
eString := string(eJSON)
@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) {
Email: "user@domain.com",
User: "just-user",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
},
{
SessionState: sessions.SessionState{
@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) {
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: created,
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
Cipher: c,
},
{
@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) {
}
}
}
func TestSessionStateAge(t *testing.T) {
ss := &sessions.SessionState{}
// Created at unset so should be 0
assert.Equal(t, time.Duration(0), ss.Age())
// Set CreatedAt to 1 hour ago
ss.CreatedAt = time.Now().Add(-1 * time.Hour)
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
}

View File

@ -78,7 +78,7 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
for _, c := range req.Cookies() {
if cookieNameRegex.MatchString(c.Name) {
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clearCookie)
cookies = append(cookies, clearCookie)
@ -101,14 +101,14 @@ func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expira
if value != "" {
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
}
c := s.makeCookie(req, s.CookieName, value, expiration)
c := s.makeCookie(req, s.CookieName, value, expiration, now)
if len(c.Value) > 4096-len(s.CookieName) {
return splitCookie(c)
}
return []*http.Cookie{c}
}
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
return cookies.MakeCookie(
req,
name,
@ -118,7 +118,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string,
s.CookieHTTPOnly,
s.CookieSecure,
expiration,
time.Now(),
now,
)
}

View File

@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
RefreshToken: jsonResponse.RefreshToken,
Email: c.Email,

View File

@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
Email: email,
}

View File

@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
s.AccessToken = newSession.AccessToken
s.IDToken = newSession.IDToken
s.RefreshToken = newSession.RefreshToken
s.CreatedAt = newSession.CreatedAt
s.ExpiresOn = newSession.ExpiresOn
s.Email = newSession.Email
return
@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
CreatedAt: time.Now(),
ExpiresOn: token.Expiry,
Email: claims.Email,
User: claims.Subject,

View File

@ -8,6 +8,7 @@ import (
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
return
}
if a := v.Get("access_token"); a != "" {
s = &sessions.SessionState{AccessToken: a}
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
} else {
err = fmt.Errorf("no access token found %s", body)
}