Add CreatedAt to SessionState
This commit is contained in:
parent
fbee5eae16
commit
34cbe0497c
@ -14,6 +14,7 @@ import (
|
||||
type SessionState struct {
|
||||
AccessToken string `json:",omitempty"`
|
||||
IDToken string `json:",omitempty"`
|
||||
CreatedAt time.Time `json:"-"`
|
||||
ExpiresOn time.Time `json:"-"`
|
||||
RefreshToken string `json:",omitempty"`
|
||||
Email string `json:",omitempty"`
|
||||
@ -23,6 +24,7 @@ type SessionState struct {
|
||||
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
||||
type SessionStateJSON struct {
|
||||
*SessionState
|
||||
CreatedAt *time.Time `json:",omitempty"`
|
||||
ExpiresOn *time.Time `json:",omitempty"`
|
||||
}
|
||||
|
||||
@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Age returns the age of a session
|
||||
func (s *SessionState) Age() time.Duration {
|
||||
if !s.CreatedAt.IsZero() {
|
||||
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// String constructs a summary of the session state
|
||||
func (s *SessionState) String() string {
|
||||
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
|
||||
@ -43,6 +53,9 @@ func (s *SessionState) String() string {
|
||||
if s.IDToken != "" {
|
||||
o += " id_token:true"
|
||||
}
|
||||
if !s.CreatedAt.IsZero() {
|
||||
o += fmt.Sprintf(" created:%s", s.CreatedAt)
|
||||
}
|
||||
if !s.ExpiresOn.IsZero() {
|
||||
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
||||
}
|
||||
@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
||||
}
|
||||
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
||||
ssj := &SessionStateJSON{SessionState: &ss}
|
||||
if !ss.CreatedAt.IsZero() {
|
||||
ssj.CreatedAt = &ss.CreatedAt
|
||||
}
|
||||
if !ss.ExpiresOn.IsZero() {
|
||||
ssj.ExpiresOn = &ss.ExpiresOn
|
||||
}
|
||||
@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
||||
var ss *SessionState
|
||||
err := json.Unmarshal([]byte(v), &ssj)
|
||||
if err == nil && ssj.SessionState != nil {
|
||||
// Extract SessionState and ExpiresOn value from SessionStateJSON
|
||||
// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
|
||||
ss = ssj.SessionState
|
||||
if ssj.CreatedAt != nil {
|
||||
ss.CreatedAt = *ssj.CreatedAt
|
||||
}
|
||||
if ssj.ExpiresOn != nil {
|
||||
ss.ExpiresOn = *ssj.ExpiresOn
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.IDToken, ss.IDToken)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||
assert.NotEqual(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||
assert.NotEqual(t, s.IDToken, ss.IDToken)
|
||||
@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, s.User, ss.User)
|
||||
assert.NotEqual(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
||||
@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
s := &sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -147,6 +155,7 @@ type testCase struct {
|
||||
// Currently only tests without cipher here because we have no way to mock
|
||||
// the random generator used in EncodeSessionState.
|
||||
func TestEncodeSessionState(t *testing.T) {
|
||||
c := time.Now()
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
|
||||
testCases := []testCase{
|
||||
@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: c,
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||
func TestDecodeSessionState(t *testing.T) {
|
||||
created := time.Now()
|
||||
createdJSON, _ := created.MarshalJSON()
|
||||
createdString := string(createdJSON)
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
eString := string(eJSON)
|
||||
@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: created,
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStateAge(t *testing.T) {
|
||||
ss := &sessions.SessionState{}
|
||||
|
||||
// Created at unset so should be 0
|
||||
assert.Equal(t, time.Duration(0), ss.Age())
|
||||
|
||||
// Set CreatedAt to 1 hour ago
|
||||
ss.CreatedAt = time.Now().Add(-1 * time.Hour)
|
||||
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
||||
}
|
||||
|
@ -78,7 +78,7 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||
|
||||
for _, c := range req.Cookies() {
|
||||
if cookieNameRegex.MatchString(c.Name) {
|
||||
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
|
||||
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
|
||||
|
||||
http.SetCookie(rw, clearCookie)
|
||||
cookies = append(cookies, clearCookie)
|
||||
@ -101,14 +101,14 @@ func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expira
|
||||
if value != "" {
|
||||
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
|
||||
}
|
||||
c := s.makeCookie(req, s.CookieName, value, expiration)
|
||||
c := s.makeCookie(req, s.CookieName, value, expiration, now)
|
||||
if len(c.Value) > 4096-len(s.CookieName) {
|
||||
return splitCookie(c)
|
||||
}
|
||||
return []*http.Cookie{c}
|
||||
}
|
||||
|
||||
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
|
||||
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cookies.MakeCookie(
|
||||
req,
|
||||
name,
|
||||
@ -118,7 +118,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string,
|
||||
s.CookieHTTPOnly,
|
||||
s.CookieSecure,
|
||||
expiration,
|
||||
time.Now(),
|
||||
now,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
Email: c.Email,
|
||||
|
@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
Email: email,
|
||||
}
|
||||
|
@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
|
||||
s.AccessToken = newSession.AccessToken
|
||||
s.IDToken = newSession.IDToken
|
||||
s.RefreshToken = newSession.RefreshToken
|
||||
s.CreatedAt = newSession.CreatedAt
|
||||
s.ExpiresOn = newSession.ExpiresOn
|
||||
s.Email = newSession.Email
|
||||
return
|
||||
@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: token.Expiry,
|
||||
Email: claims.Email,
|
||||
User: claims.Subject,
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
|
||||
return
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
s = &sessions.SessionState{AccessToken: a}
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", body)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user