From 6aa35a9ecfe99dc9c78839da60b98b2bd250549a Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 27 Jan 2018 10:53:17 +0000 Subject: [PATCH] Update sessions state --- providers/session_state.go | 27 +++++++++++++++++++++------ providers/session_state_test.go | 7 +++++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/providers/session_state.go b/providers/session_state.go index 8195029..2862cdd 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -33,6 +33,9 @@ func (s *SessionState) String() string { if s.AccessToken != "" { o += " token:true" } + if s.IDToken != "" { + o += " id_token:true" + } if !s.ExpiresOn.IsZero() { o += fmt.Sprintf(" expires:%s", s.ExpiresOn) } @@ -66,13 +69,19 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { return "", err } } + i := s.IDToken + if i != "" { + if i, err = c.Encrypt(i); err != nil { + return "", err + } + } r := s.RefreshToken if r != "" { if r, err = c.Encrypt(r); err != nil { return "", err } } - return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil + return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil } func decodeSessionStatePlain(v string) (s *SessionState, err error) { @@ -97,8 +106,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) } chunks := strings.Split(v, "|") - if len(chunks) != 4 { - err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) + if len(chunks) != 5 { + err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) return } @@ -113,11 +122,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) } } - ts, _ := strconv.Atoi(chunks[2]) + if chunks[2] != "" { + if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { + return nil, err + } + } + + ts, _ := strconv.Atoi(chunks[3]) sessionState.ExpiresOn = time.Unix(int64(ts), 0) - if chunks[3] != "" { - if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { + if chunks[4] != "" { + if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { return nil, err } } diff --git a/providers/session_state_test.go b/providers/session_state_test.go index f34f292..504228f 100644 --- a/providers/session_state_test.go +++ b/providers/session_state_test.go @@ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) { s := &SessionState{ Email: "user@domain.com", AccessToken: "token1234", + IDToken: "rawtoken1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 3, strings.Count(encoded, "|")) + assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) @@ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.IDToken, ss.IDToken) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) @@ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) + assert.NotEqual(t, s.IDToken, ss.IDToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) } @@ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 3, strings.Count(encoded, "|")) + assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss)