Update sessions state

This commit is contained in:
Joel Speed 2018-01-27 10:53:17 +00:00
parent 2ac5cf6b56
commit 161028d61e
No known key found for this signature in database
GPG Key ID: 83695B8B3A376982
2 changed files with 26 additions and 8 deletions

View File

@ -30,6 +30,9 @@ func (s *SessionState) String() string {
if s.AccessToken != "" { if s.AccessToken != "" {
o += " token:true" o += " token:true"
} }
if s.IdToken != "" {
o += " id_token:true"
}
if !s.ExpiresOn.IsZero() { if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn) o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
} }
@ -61,13 +64,19 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
return "", err return "", err
} }
} }
i := s.IdToken
if i != "" {
if i, err = c.Encrypt(i); err != nil {
return "", err
}
}
r := s.RefreshToken r := s.RefreshToken
if r != "" { if r != "" {
if r, err = c.Encrypt(r); err != nil { if r, err = c.Encrypt(r); err != nil {
return "", err return "", err
} }
} }
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil
} }
func decodeSessionStatePlain(v string) (s *SessionState, err error) { func decodeSessionStatePlain(v string) (s *SessionState, err error) {
@ -91,8 +100,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
} }
chunks := strings.Split(v, "|") chunks := strings.Split(v, "|")
if len(chunks) != 4 { if len(chunks) != 5 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks))
return return
} }
@ -107,11 +116,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
} }
} }
ts, _ := strconv.Atoi(chunks[2]) if chunks[2] != "" {
if sessionState.IdToken, err = c.Decrypt(chunks[2]); err != nil {
return nil, err
}
}
ts, _ := strconv.Atoi(chunks[3])
sessionState.ExpiresOn = time.Unix(int64(ts), 0) sessionState.ExpiresOn = time.Unix(int64(ts), 0)
if chunks[3] != "" { if chunks[4] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil {
return nil, err return nil, err
} }
} }

View File

@ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) {
s := &SessionState{ s := &SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
IdToken: "rawtoken1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321", RefreshToken: "refresh4321",
} }
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|")) assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
@ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, "user", ss.User) assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IdToken, ss.IdToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
@ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.IdToken, ss.IdToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
} }
@ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|")) assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)