Use encoding/json for SessionState serialization (#63)
* Use encoding/json for SessionState serialization In order to make it easier to extend in future. * Store only email and user in cookie when cipher is unavailable This improves safety and robustness, and also preserves the existing behaviour. * Add TestEncodeSessionState/TestDecodeSessionState Use the test vectors with JSON encoding just introduced. * Support session state encoding in older versions * Add test cases for legacy session state strings * Add check for wrong expiration time in session state strings * Avoid exposing time.Time zero value when encoding session state string * Update CHANGELOG.md
This commit is contained in:
parent
a656435d00
commit
2070fae47c
@ -2,6 +2,10 @@
|
||||
|
||||
## Changes since v3.1.0
|
||||
|
||||
- [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi)
|
||||
- Use JSON to encode session state to be stored in browser cookies
|
||||
- Implement legacy decode function to support existing cookies generated by older versions
|
||||
- Add detailed table driven tests in session_state_test.go
|
||||
- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer)
|
||||
- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer)
|
||||
- [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -11,12 +12,18 @@ import (
|
||||
|
||||
// SessionState is used to store information about the currently authenticated user session
|
||||
type SessionState struct {
|
||||
AccessToken string
|
||||
IDToken string
|
||||
ExpiresOn time.Time
|
||||
RefreshToken string
|
||||
Email string
|
||||
User string
|
||||
AccessToken string `json:",omitempty"`
|
||||
IDToken string `json:",omitempty"`
|
||||
ExpiresOn time.Time `json:"-"`
|
||||
RefreshToken string `json:",omitempty"`
|
||||
Email string `json:",omitempty"`
|
||||
User string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
||||
type SessionStateJSON struct {
|
||||
*SessionState
|
||||
ExpiresOn *time.Time `json:",omitempty"`
|
||||
}
|
||||
|
||||
// IsExpired checks whether the session has expired
|
||||
@ -29,7 +36,7 @@ func (s *SessionState) IsExpired() bool {
|
||||
|
||||
// String constructs a summary of the session state
|
||||
func (s *SessionState) String() string {
|
||||
o := fmt.Sprintf("Session{%s", s.accountInfo())
|
||||
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
|
||||
if s.AccessToken != "" {
|
||||
o += " token:true"
|
||||
}
|
||||
@ -47,95 +54,145 @@ func (s *SessionState) String() string {
|
||||
|
||||
// EncodeSessionState returns string representation of the current session
|
||||
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
||||
if c == nil || s.AccessToken == "" {
|
||||
return s.accountInfo(), nil
|
||||
}
|
||||
return s.EncryptedString(c)
|
||||
}
|
||||
|
||||
func (s *SessionState) accountInfo() string {
|
||||
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
||||
}
|
||||
|
||||
// EncryptedString encrypts the session state into a cookie string
|
||||
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
||||
var err error
|
||||
var ss SessionState
|
||||
if c == nil {
|
||||
panic("error. missing cipher")
|
||||
}
|
||||
a := s.AccessToken
|
||||
if a != "" {
|
||||
if a, err = c.Encrypt(a); err != nil {
|
||||
// Store only Email and User when cipher is unavailable
|
||||
ss.Email = s.Email
|
||||
ss.User = s.User
|
||||
} else {
|
||||
ss = *s
|
||||
var err error
|
||||
if ss.AccessToken != "" {
|
||||
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
i := s.IDToken
|
||||
if i != "" {
|
||||
if i, err = c.Encrypt(i); err != nil {
|
||||
if ss.IDToken != "" {
|
||||
ss.IDToken, err = c.Encrypt(ss.IDToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
r := s.RefreshToken
|
||||
if r != "" {
|
||||
if r, err = c.Encrypt(r); err != nil {
|
||||
if ss.RefreshToken != "" {
|
||||
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil
|
||||
}
|
||||
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
||||
ssj := &SessionStateJSON{SessionState: &ss}
|
||||
if !ss.ExpiresOn.IsZero() {
|
||||
ssj.ExpiresOn = &ss.ExpiresOn
|
||||
}
|
||||
b, err := json.Marshal(ssj)
|
||||
return string(b), err
|
||||
}
|
||||
|
||||
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
|
||||
// legacyDecodeSessionStatePlain decodes older plain session state string
|
||||
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
|
||||
chunks := strings.Split(v, " ")
|
||||
if len(chunks) != 2 {
|
||||
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
|
||||
return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks))
|
||||
}
|
||||
|
||||
email := strings.TrimPrefix(chunks[0], "email:")
|
||||
user := strings.TrimPrefix(chunks[1], "user:")
|
||||
if user == "" {
|
||||
user = strings.Split(email, "@")[0]
|
||||
}
|
||||
email := strings.TrimPrefix(chunks[0], "email:")
|
||||
|
||||
return &SessionState{User: user, Email: email}, nil
|
||||
}
|
||||
|
||||
// DecodeSessionState decodes the session cookie string into a SessionState
|
||||
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
||||
if c == nil {
|
||||
return decodeSessionStatePlain(v)
|
||||
}
|
||||
|
||||
// legacyDecodeSessionState attempts to decode the session state string
|
||||
// generated by v3.1.0 or older
|
||||
func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
||||
chunks := strings.Split(v, "|")
|
||||
if len(chunks) != 5 {
|
||||
err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks))
|
||||
return
|
||||
|
||||
if c == nil {
|
||||
if len(chunks) != 1 {
|
||||
return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
|
||||
}
|
||||
return legacyDecodeSessionStatePlain(chunks[0])
|
||||
}
|
||||
|
||||
sessionState, err := decodeSessionStatePlain(chunks[0])
|
||||
if len(chunks) != 4 && len(chunks) != 5 {
|
||||
return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
|
||||
}
|
||||
|
||||
i := 0
|
||||
ss, err := legacyDecodeSessionStatePlain(chunks[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if chunks[1] != "" {
|
||||
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i++
|
||||
ss.AccessToken = chunks[i]
|
||||
|
||||
if len(chunks) == 5 {
|
||||
// SessionState with IDToken in v3.1.0
|
||||
i++
|
||||
ss.IDToken = chunks[i]
|
||||
}
|
||||
|
||||
if chunks[2] != "" {
|
||||
if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i++
|
||||
ts, err := strconv.Atoi(chunks[i])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
|
||||
}
|
||||
ss.ExpiresOn = time.Unix(int64(ts), 0)
|
||||
|
||||
ts, _ := strconv.Atoi(chunks[3])
|
||||
sessionState.ExpiresOn = time.Unix(int64(ts), 0)
|
||||
i++
|
||||
ss.RefreshToken = chunks[i]
|
||||
|
||||
if chunks[4] != "" {
|
||||
if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return sessionState, nil
|
||||
return ss, nil
|
||||
}
|
||||
|
||||
// DecodeSessionState decodes the session cookie string into a SessionState
|
||||
func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
||||
var ssj SessionStateJSON
|
||||
var ss *SessionState
|
||||
err := json.Unmarshal([]byte(v), &ssj)
|
||||
if err == nil && ssj.SessionState != nil {
|
||||
// Extract SessionState and ExpiresOn value from SessionStateJSON
|
||||
ss = ssj.SessionState
|
||||
if ssj.ExpiresOn != nil {
|
||||
ss.ExpiresOn = *ssj.ExpiresOn
|
||||
}
|
||||
} else {
|
||||
// Try to decode a legacy string when json.Unmarshal failed
|
||||
ss, err = legacyDecodeSessionState(v, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if c == nil {
|
||||
// Load only Email and User when cipher is unavailable
|
||||
ss = &SessionState{
|
||||
Email: ss.Email,
|
||||
User: ss.User,
|
||||
}
|
||||
} else {
|
||||
if ss.AccessToken != "" {
|
||||
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if ss.IDToken != "" {
|
||||
ss.IDToken, err = c.Decrypt(ss.IDToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if ss.RefreshToken != "" {
|
||||
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if ss.User == "" {
|
||||
ss.User = strings.Split(ss.Email, "@")[0]
|
||||
}
|
||||
return ss, nil
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -27,7 +26,6 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 4, strings.Count(encoded, "|"))
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
@ -65,7 +63,6 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 4, strings.Count(encoded, "|"))
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
@ -96,8 +93,6 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(nil)
|
||||
assert.Equal(t, nil, err)
|
||||
expected := fmt.Sprintf("email:%s user:", s.Email)
|
||||
assert.Equal(t, expected, encoded)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
@ -118,8 +113,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(nil)
|
||||
assert.Equal(t, nil, err)
|
||||
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
||||
assert.Equal(t, expected, encoded)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
@ -130,19 +123,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
}
|
||||
|
||||
func TestSessionStateAccountInfo(t *testing.T) {
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
}
|
||||
expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
||||
assert.Equal(t, expected, s.accountInfo())
|
||||
|
||||
s.Email = ""
|
||||
expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
||||
assert.Equal(t, expected, s.accountInfo())
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||
assert.Equal(t, true, s.IsExpired())
|
||||
@ -153,3 +133,185 @@ func TestExpired(t *testing.T) {
|
||||
s = &SessionState{}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
SessionState
|
||||
Encoded string
|
||||
Cipher *cookie.Cipher
|
||||
Error bool
|
||||
}
|
||||
|
||||
// TestEncodeSessionState tests EncodeSessionState with the test vector
|
||||
//
|
||||
// Currently only tests without cipher here because we have no way to mock
|
||||
// the random generator used in EncodeSessionState.
|
||||
func TestEncodeSessionState(t *testing.T) {
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, encoded)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.JSONEq(t, tc.Encoded, encoded)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeSessionState tests DecodeSessionState with the test vector
|
||||
func TestDecodeSessionState(t *testing.T) {
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
eString := string(eJSON)
|
||||
eUnix := e.Unix()
|
||||
|
||||
c, err := cookie.NewCipher([]byte(secret))
|
||||
assert.NoError(t, err)
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`,
|
||||
Cipher: c,
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`,
|
||||
Cipher: c,
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
},
|
||||
Encoded: "email:user@domain.com user:just-user",
|
||||
},
|
||||
{
|
||||
Encoded: "email:user@domain.com user:just-user||||",
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
Encoded: "email:user@domain.com user:just-user",
|
||||
Cipher: c,
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
Encoded: "email:user@domain.com user:just-user|||99999999999999999999|",
|
||||
Cipher: c,
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
|
||||
Cipher: c,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
if assert.NotNil(t, ss) {
|
||||
assert.Equal(t, tc.User, ss.User)
|
||||
assert.Equal(t, tc.Email, ss.Email)
|
||||
assert.Equal(t, tc.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
|
||||
assert.Equal(t, tc.IDToken, ss.IDToken)
|
||||
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user