diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index f6efefd..e085976 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -14,6 +14,7 @@ import ( type SessionState struct { AccessToken string `json:",omitempty"` IDToken string `json:",omitempty"` + CreatedAt time.Time `json:"-"` ExpiresOn time.Time `json:"-"` RefreshToken string `json:",omitempty"` Email string `json:",omitempty"` @@ -23,6 +24,7 @@ type SessionState struct { // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value type SessionStateJSON struct { *SessionState + CreatedAt *time.Time `json:",omitempty"` ExpiresOn *time.Time `json:",omitempty"` } @@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool { return false } +// Age returns the age of a session +func (s *SessionState) Age() time.Duration { + if !s.CreatedAt.IsZero() { + return time.Now().Truncate(time.Second).Sub(s.CreatedAt) + } + return 0 +} + // String constructs a summary of the session state func (s *SessionState) String() string { o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) @@ -43,6 +53,9 @@ func (s *SessionState) String() string { if s.IDToken != "" { o += " id_token:true" } + if !s.CreatedAt.IsZero() { + o += fmt.Sprintf(" created:%s", s.CreatedAt) + } if !s.ExpiresOn.IsZero() { o += fmt.Sprintf(" expires:%s", s.ExpiresOn) } @@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { } // Embed SessionState and ExpiresOn pointer into SessionStateJSON ssj := &SessionStateJSON{SessionState: &ss} + if !ss.CreatedAt.IsZero() { + ssj.CreatedAt = &ss.CreatedAt + } if !ss.ExpiresOn.IsZero() { ssj.ExpiresOn = &ss.ExpiresOn } @@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { var ss *SessionState err := json.Unmarshal([]byte(v), &ssj) if err == nil && ssj.SessionState != nil { - // Extract SessionState and ExpiresOn value from SessionStateJSON + // Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON ss = ssj.SessionState + if ssj.CreatedAt != nil { + ss.CreatedAt = *ssj.CreatedAt + } if ssj.ExpiresOn != nil { ss.ExpiresOn = *ssj.ExpiresOn } diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 83b21a4..a48344e 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) { Email: "user@domain.com", AccessToken: "token1234", IDToken: "rawtoken1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.IDToken, ss.IDToken) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) @@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, "user@domain.com", ss.User) assert.NotEqual(t, s.Email, ss.Email) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.IDToken, ss.IDToken) @@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { User: "just-user", Email: "user@domain.com", AccessToken: "token1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) @@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, s.User, ss.User) assert.NotEqual(t, s.Email, ss.Email) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) @@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { s := &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { User: "just-user", Email: "user@domain.com", AccessToken: "token1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -147,6 +155,7 @@ type testCase struct { // Currently only tests without cipher here because we have no way to mock // the random generator used in EncodeSessionState. func TestEncodeSessionState(t *testing.T) { + c := time.Now() e := time.Now().Add(time.Duration(1) * time.Hour) testCases := []testCase{ @@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) { User: "just-user", AccessToken: "token1234", IDToken: "rawtoken1234", + CreatedAt: c, ExpiresOn: e, RefreshToken: "refresh4321", }, @@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) { // TestDecodeSessionState testssessions.DecodeSessionState with the test vector func TestDecodeSessionState(t *testing.T) { + created := time.Now() + createdJSON, _ := created.MarshalJSON() + createdString := string(createdJSON) e := time.Now().Add(time.Duration(1) * time.Hour) eJSON, _ := e.MarshalJSON() eString := string(eJSON) @@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) { Email: "user@domain.com", User: "just-user", }, - Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), }, { SessionState: sessions.SessionState{ @@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) { User: "just-user", AccessToken: "token1234", IDToken: "rawtoken1234", + CreatedAt: created, ExpiresOn: e, RefreshToken: "refresh4321", }, - Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), Cipher: c, }, { @@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) { } } } + +func TestSessionStateAge(t *testing.T) { + ss := &sessions.SessionState{} + + // Created at unset so should be 0 + assert.Equal(t, time.Duration(0), ss.Age()) + + // Set CreatedAt to 1 hour ago + ss.CreatedAt = time.Now().Add(-1 * time.Hour) + assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) +} diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 0fc4a64..4a3d323 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -78,7 +78,7 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { for _, c := range req.Cookies() { if cookieNameRegex.MatchString(c.Name) { - clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) + clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) http.SetCookie(rw, clearCookie) cookies = append(cookies, clearCookie) @@ -101,14 +101,14 @@ func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expira if value != "" { value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) } - c := s.makeCookie(req, s.CookieName, value, expiration) + c := s.makeCookie(req, s.CookieName, value, expiration, now) if len(c.Value) > 4096-len(s.CookieName) { return splitCookie(c) } return []*http.Cookie{c} } -func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { +func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { return cookies.MakeCookie( req, name, @@ -118,7 +118,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string, s.CookieHTTPOnly, s.CookieSecure, expiration, - time.Now(), + now, ) } diff --git a/providers/google.go b/providers/google.go index f79a131..6f29c2c 100644 --- a/providers/google.go +++ b/providers/google.go @@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), RefreshToken: jsonResponse.RefreshToken, Email: c.Email, diff --git a/providers/logingov.go b/providers/logingov.go index 60f4260..95d7aa9 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), Email: email, } diff --git a/providers/oidc.go b/providers/oidc.go index bacabdf..08ea082 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) s.AccessToken = newSession.AccessToken s.IDToken = newSession.IDToken s.RefreshToken = newSession.RefreshToken + s.CreatedAt = newSession.CreatedAt s.ExpiresOn = newSession.ExpiresOn s.Email = newSession.Email return @@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok AccessToken: token.AccessToken, IDToken: rawIDToken, RefreshToken: token.RefreshToken, + CreatedAt: time.Now(), ExpiresOn: token.Expiry, Email: claims.Email, User: claims.Subject, diff --git a/providers/provider_default.go b/providers/provider_default.go index cd78251..4716014 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/http" "net/url" + "time" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" @@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat return } if a := v.Get("access_token"); a != "" { - s = &sessions.SessionState{AccessToken: a} + s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} } else { err = fmt.Errorf("no access token found %s", body) }