Use encoding/json for SessionState serialization (#63)

* Use encoding/json for SessionState serialization

In order to make it easier to extend in future.

* Store only email and user in cookie when cipher is unavailable

This improves safety and robustness, and also preserves the existing
behaviour.

* Add TestEncodeSessionState/TestDecodeSessionState

Use the test vectors with JSON encoding just introduced.

* Support session state encoding in older versions

* Add test cases for legacy session state strings

* Add check for wrong expiration time in session state strings

* Avoid exposing time.Time zero value when encoding session state string

* Update CHANGELOG.md
This commit is contained in:
YAEGASHI Takeshi 2019-03-20 22:59:24 +09:00 committed by Joel Speed
parent a656435d00
commit 2070fae47c
3 changed files with 314 additions and 91 deletions

View File

@ -2,6 +2,10 @@
## Changes since v3.1.0 ## Changes since v3.1.0
- [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi)
- Use JSON to encode session state to be stored in browser cookies
- Implement legacy decode function to support existing cookies generated by older versions
- Add detailed table driven tests in session_state_test.go
- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer) - [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer)
- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer) - [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer)
- [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr) - [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr)

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -11,12 +12,18 @@ import (
// SessionState is used to store information about the currently authenticated user session // SessionState is used to store information about the currently authenticated user session
type SessionState struct { type SessionState struct {
AccessToken string AccessToken string `json:",omitempty"`
IDToken string IDToken string `json:",omitempty"`
ExpiresOn time.Time ExpiresOn time.Time `json:"-"`
RefreshToken string RefreshToken string `json:",omitempty"`
Email string Email string `json:",omitempty"`
User string User string `json:",omitempty"`
}
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct {
*SessionState
ExpiresOn *time.Time `json:",omitempty"`
} }
// IsExpired checks whether the session has expired // IsExpired checks whether the session has expired
@ -29,7 +36,7 @@ func (s *SessionState) IsExpired() bool {
// String constructs a summary of the session state // String constructs a summary of the session state
func (s *SessionState) String() string { func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.accountInfo()) o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
if s.AccessToken != "" { if s.AccessToken != "" {
o += " token:true" o += " token:true"
} }
@ -47,95 +54,145 @@ func (s *SessionState) String() string {
// EncodeSessionState returns string representation of the current session // EncodeSessionState returns string representation of the current session
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" { var ss SessionState
return s.accountInfo(), nil
}
return s.EncryptedString(c)
}
func (s *SessionState) accountInfo() string {
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
}
// EncryptedString encrypts the session state into a cookie string
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
var err error
if c == nil { if c == nil {
panic("error. missing cipher") // Store only Email and User when cipher is unavailable
} ss.Email = s.Email
a := s.AccessToken ss.User = s.User
if a != "" { } else {
if a, err = c.Encrypt(a); err != nil { ss = *s
return "", err var err error
if ss.AccessToken != "" {
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
if err != nil {
return "", err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Encrypt(ss.IDToken)
if err != nil {
return "", err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
if err != nil {
return "", err
}
} }
} }
i := s.IDToken // Embed SessionState and ExpiresOn pointer into SessionStateJSON
if i != "" { ssj := &SessionStateJSON{SessionState: &ss}
if i, err = c.Encrypt(i); err != nil { if !ss.ExpiresOn.IsZero() {
return "", err ssj.ExpiresOn = &ss.ExpiresOn
}
} }
r := s.RefreshToken b, err := json.Marshal(ssj)
if r != "" { return string(b), err
if r, err = c.Encrypt(r); err != nil {
return "", err
}
}
return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil
} }
func decodeSessionStatePlain(v string) (s *SessionState, err error) { // legacyDecodeSessionStatePlain decodes older plain session state string
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
chunks := strings.Split(v, " ") chunks := strings.Split(v, " ")
if len(chunks) != 2 { if len(chunks) != 2 {
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks))
} }
email := strings.TrimPrefix(chunks[0], "email:")
user := strings.TrimPrefix(chunks[1], "user:") user := strings.TrimPrefix(chunks[1], "user:")
if user == "" { email := strings.TrimPrefix(chunks[0], "email:")
user = strings.Split(email, "@")[0]
}
return &SessionState{User: user, Email: email}, nil return &SessionState{User: user, Email: email}, nil
} }
// DecodeSessionState decodes the session cookie string into a SessionState // legacyDecodeSessionState attempts to decode the session state string
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { // generated by v3.1.0 or older
if c == nil { func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
return decodeSessionStatePlain(v)
}
chunks := strings.Split(v, "|") chunks := strings.Split(v, "|")
if len(chunks) != 5 {
err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) if c == nil {
return if len(chunks) != 1 {
return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
}
return legacyDecodeSessionStatePlain(chunks[0])
} }
sessionState, err := decodeSessionStatePlain(chunks[0]) if len(chunks) != 4 && len(chunks) != 5 {
return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
}
i := 0
ss, err := legacyDecodeSessionStatePlain(chunks[i])
if err != nil { if err != nil {
return nil, err return nil, err
} }
if chunks[1] != "" { i++
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { ss.AccessToken = chunks[i]
return nil, err
} if len(chunks) == 5 {
// SessionState with IDToken in v3.1.0
i++
ss.IDToken = chunks[i]
} }
if chunks[2] != "" { i++
if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { ts, err := strconv.Atoi(chunks[i])
return nil, err if err != nil {
} return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
} }
ss.ExpiresOn = time.Unix(int64(ts), 0)
ts, _ := strconv.Atoi(chunks[3]) i++
sessionState.ExpiresOn = time.Unix(int64(ts), 0) ss.RefreshToken = chunks[i]
if chunks[4] != "" { return ss, nil
if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { }
return nil, err
} // DecodeSessionState decodes the session cookie string into a SessionState
} func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
var ssj SessionStateJSON
return sessionState, nil var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
if err == nil && ssj.SessionState != nil {
// Extract SessionState and ExpiresOn value from SessionStateJSON
ss = ssj.SessionState
if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn
}
} else {
// Try to decode a legacy string when json.Unmarshal failed
ss, err = legacyDecodeSessionState(v, c)
if err != nil {
return nil, err
}
}
if c == nil {
// Load only Email and User when cipher is unavailable
ss = &SessionState{
Email: ss.Email,
User: ss.User,
}
} else {
if ss.AccessToken != "" {
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
if err != nil {
return nil, err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Decrypt(ss.IDToken)
if err != nil {
return nil, err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
if err != nil {
return nil, err
}
}
}
if ss.User == "" {
ss.User = strings.Split(ss.Email, "@")[0]
}
return ss, nil
} }

View File

@ -2,7 +2,6 @@ package providers
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
@ -27,7 +26,6 @@ func TestSessionStateSerialization(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
@ -65,7 +63,6 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
@ -96,8 +93,6 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(nil) encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
expected := fmt.Sprintf("email:%s user:", s.Email)
assert.Equal(t, expected, encoded)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := DecodeSessionState(encoded, nil)
@ -118,8 +113,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(nil) encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
assert.Equal(t, expected, encoded)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := DecodeSessionState(encoded, nil)
@ -130,19 +123,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
assert.Equal(t, "", ss.RefreshToken) assert.Equal(t, "", ss.RefreshToken)
} }
func TestSessionStateAccountInfo(t *testing.T) {
s := &SessionState{
Email: "user@domain.com",
User: "just-user",
}
expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
s.Email = ""
expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
}
func TestExpired(t *testing.T) { func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired()) assert.Equal(t, true, s.IsExpired())
@ -153,3 +133,185 @@ func TestExpired(t *testing.T) {
s = &SessionState{} s = &SessionState{}
assert.Equal(t, false, s.IsExpired()) assert.Equal(t, false, s.IsExpired())
} }
type testCase struct {
SessionState
Encoded string
Cipher *cookie.Cipher
Error bool
}
// TestEncodeSessionState tests EncodeSessionState with the test vector
//
// Currently only tests without cipher here because we have no way to mock
// the random generator used in EncodeSessionState.
func TestEncodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
testCases := []testCase{
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
}
for i, tc := range testCases {
encoded, err := tc.EncodeSessionState(tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
if tc.Error {
assert.Error(t, err)
assert.Empty(t, encoded)
continue
}
assert.NoError(t, err)
assert.JSONEq(t, tc.Encoded, encoded)
}
}
// TestDecodeSessionState tests DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON()
eString := string(eJSON)
eUnix := e.Unix()
c, err := cookie.NewCipher([]byte(secret))
assert.NoError(t, err)
testCases := []testCase{
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "user",
},
Encoded: `{"Email":"user@domain.com"}`,
},
{
SessionState: SessionState{
User: "just-user",
},
Encoded: `{"User":"just-user"}`,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
Cipher: c,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
Cipher: c,
},
{
Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`,
Cipher: c,
Error: true,
},
{
Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`,
Cipher: c,
Error: true,
},
{
SessionState: SessionState{
User: "just-user",
Email: "user@domain.com",
},
Encoded: "email:user@domain.com user:just-user",
},
{
Encoded: "email:user@domain.com user:just-user||||",
Error: true,
},
{
Encoded: "email:user@domain.com user:just-user",
Cipher: c,
Error: true,
},
{
Encoded: "email:user@domain.com user:just-user|||99999999999999999999|",
Cipher: c,
Error: true,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
Cipher: c,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
Cipher: c,
},
}
for i, tc := range testCases {
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
if tc.Error {
assert.Error(t, err)
assert.Nil(t, ss)
continue
}
assert.NoError(t, err)
if assert.NotNil(t, ss) {
assert.Equal(t, tc.User, ss.User)
assert.Equal(t, tc.Email, ss.Email)
assert.Equal(t, tc.AccessToken, ss.AccessToken)
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
assert.Equal(t, tc.IDToken, ss.IDToken)
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
}
}
}