From 2070fae47c6fb00c56ae4610e05913a96507e915 Mon Sep 17 00:00:00 2001 From: YAEGASHI Takeshi Date: Wed, 20 Mar 2019 22:59:24 +0900 Subject: [PATCH] Use encoding/json for SessionState serialization (#63) * Use encoding/json for SessionState serialization In order to make it easier to extend in future. * Store only email and user in cookie when cipher is unavailable This improves safety and robustness, and also preserves the existing behaviour. * Add TestEncodeSessionState/TestDecodeSessionState Use the test vectors with JSON encoding just introduced. * Support session state encoding in older versions * Add test cases for legacy session state strings * Add check for wrong expiration time in session state strings * Avoid exposing time.Time zero value when encoding session state string * Update CHANGELOG.md --- CHANGELOG.md | 4 + providers/session_state.go | 199 ++++++++++++++++++++----------- providers/session_state_test.go | 202 ++++++++++++++++++++++++++++---- 3 files changed, 314 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a16aa3..9823907 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Changes since v3.1.0 +- [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi) + - Use JSON to encode session state to be stored in browser cookies + - Implement legacy decode function to support existing cookies generated by older versions + - Add detailed table driven tests in session_state_test.go - [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer) - [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer) - [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr) diff --git a/providers/session_state.go b/providers/session_state.go index 2862cdd..4741b4a 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -1,6 +1,7 @@ package providers import ( + "encoding/json" "fmt" "strconv" "strings" @@ -11,12 +12,18 @@ import ( // SessionState is used to store information about the currently authenticated user session type SessionState struct { - AccessToken string - IDToken string - ExpiresOn time.Time - RefreshToken string - Email string - User string + AccessToken string `json:",omitempty"` + IDToken string `json:",omitempty"` + ExpiresOn time.Time `json:"-"` + RefreshToken string `json:",omitempty"` + Email string `json:",omitempty"` + User string `json:",omitempty"` +} + +// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value +type SessionStateJSON struct { + *SessionState + ExpiresOn *time.Time `json:",omitempty"` } // IsExpired checks whether the session has expired @@ -29,7 +36,7 @@ func (s *SessionState) IsExpired() bool { // String constructs a summary of the session state func (s *SessionState) String() string { - o := fmt.Sprintf("Session{%s", s.accountInfo()) + o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) if s.AccessToken != "" { o += " token:true" } @@ -47,95 +54,145 @@ func (s *SessionState) String() string { // EncodeSessionState returns string representation of the current session func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { - if c == nil || s.AccessToken == "" { - return s.accountInfo(), nil - } - return s.EncryptedString(c) -} - -func (s *SessionState) accountInfo() string { - return fmt.Sprintf("email:%s user:%s", s.Email, s.User) -} - -// EncryptedString encrypts the session state into a cookie string -func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { - var err error + var ss SessionState if c == nil { - panic("error. missing cipher") - } - a := s.AccessToken - if a != "" { - if a, err = c.Encrypt(a); err != nil { - return "", err + // Store only Email and User when cipher is unavailable + ss.Email = s.Email + ss.User = s.User + } else { + ss = *s + var err error + if ss.AccessToken != "" { + ss.AccessToken, err = c.Encrypt(ss.AccessToken) + if err != nil { + return "", err + } + } + if ss.IDToken != "" { + ss.IDToken, err = c.Encrypt(ss.IDToken) + if err != nil { + return "", err + } + } + if ss.RefreshToken != "" { + ss.RefreshToken, err = c.Encrypt(ss.RefreshToken) + if err != nil { + return "", err + } } } - i := s.IDToken - if i != "" { - if i, err = c.Encrypt(i); err != nil { - return "", err - } + // Embed SessionState and ExpiresOn pointer into SessionStateJSON + ssj := &SessionStateJSON{SessionState: &ss} + if !ss.ExpiresOn.IsZero() { + ssj.ExpiresOn = &ss.ExpiresOn } - r := s.RefreshToken - if r != "" { - if r, err = c.Encrypt(r); err != nil { - return "", err - } - } - return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil + b, err := json.Marshal(ssj) + return string(b), err } -func decodeSessionStatePlain(v string) (s *SessionState, err error) { +// legacyDecodeSessionStatePlain decodes older plain session state string +func legacyDecodeSessionStatePlain(v string) (*SessionState, error) { chunks := strings.Split(v, " ") if len(chunks) != 2 { - return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) + return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks)) } - email := strings.TrimPrefix(chunks[0], "email:") user := strings.TrimPrefix(chunks[1], "user:") - if user == "" { - user = strings.Split(email, "@")[0] - } + email := strings.TrimPrefix(chunks[0], "email:") return &SessionState{User: user, Email: email}, nil } -// DecodeSessionState decodes the session cookie string into a SessionState -func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { - if c == nil { - return decodeSessionStatePlain(v) - } - +// legacyDecodeSessionState attempts to decode the session state string +// generated by v3.1.0 or older +func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { chunks := strings.Split(v, "|") - if len(chunks) != 5 { - err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) - return + + if c == nil { + if len(chunks) != 1 { + return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks)) + } + return legacyDecodeSessionStatePlain(chunks[0]) } - sessionState, err := decodeSessionStatePlain(chunks[0]) + if len(chunks) != 4 && len(chunks) != 5 { + return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks)) + } + + i := 0 + ss, err := legacyDecodeSessionStatePlain(chunks[i]) if err != nil { return nil, err } - if chunks[1] != "" { - if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { - return nil, err - } + i++ + ss.AccessToken = chunks[i] + + if len(chunks) == 5 { + // SessionState with IDToken in v3.1.0 + i++ + ss.IDToken = chunks[i] } - if chunks[2] != "" { - if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { - return nil, err - } + i++ + ts, err := strconv.Atoi(chunks[i]) + if err != nil { + return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err) } + ss.ExpiresOn = time.Unix(int64(ts), 0) - ts, _ := strconv.Atoi(chunks[3]) - sessionState.ExpiresOn = time.Unix(int64(ts), 0) + i++ + ss.RefreshToken = chunks[i] - if chunks[4] != "" { - if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { - return nil, err - } - } - - return sessionState, nil + return ss, nil +} + +// DecodeSessionState decodes the session cookie string into a SessionState +func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { + var ssj SessionStateJSON + var ss *SessionState + err := json.Unmarshal([]byte(v), &ssj) + if err == nil && ssj.SessionState != nil { + // Extract SessionState and ExpiresOn value from SessionStateJSON + ss = ssj.SessionState + if ssj.ExpiresOn != nil { + ss.ExpiresOn = *ssj.ExpiresOn + } + } else { + // Try to decode a legacy string when json.Unmarshal failed + ss, err = legacyDecodeSessionState(v, c) + if err != nil { + return nil, err + } + } + if c == nil { + // Load only Email and User when cipher is unavailable + ss = &SessionState{ + Email: ss.Email, + User: ss.User, + } + } else { + if ss.AccessToken != "" { + ss.AccessToken, err = c.Decrypt(ss.AccessToken) + if err != nil { + return nil, err + } + } + if ss.IDToken != "" { + ss.IDToken, err = c.Decrypt(ss.IDToken) + if err != nil { + return nil, err + } + } + if ss.RefreshToken != "" { + ss.RefreshToken, err = c.Decrypt(ss.RefreshToken) + if err != nil { + return nil, err + } + } + } + if ss.User == "" { + ss.User = strings.Split(ss.Email, "@")[0] + } + return ss, nil } diff --git a/providers/session_state_test.go b/providers/session_state_test.go index 504228f..9557eea 100644 --- a/providers/session_state_test.go +++ b/providers/session_state_test.go @@ -2,7 +2,6 @@ package providers import ( "fmt" - "strings" "testing" "time" @@ -27,7 +26,6 @@ func TestSessionStateSerialization(t *testing.T) { } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) @@ -65,7 +63,6 @@ func TestSessionStateSerializationWithUser(t *testing.T) { } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) @@ -96,8 +93,6 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) - expected := fmt.Sprintf("email:%s user:", s.Email) - assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) @@ -118,8 +113,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) - expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User) - assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) @@ -130,19 +123,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { assert.Equal(t, "", ss.RefreshToken) } -func TestSessionStateAccountInfo(t *testing.T) { - s := &SessionState{ - Email: "user@domain.com", - User: "just-user", - } - expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User) - assert.Equal(t, expected, s.accountInfo()) - - s.Email = "" - expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User) - assert.Equal(t, expected, s.accountInfo()) -} - func TestExpired(t *testing.T) { s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} assert.Equal(t, true, s.IsExpired()) @@ -153,3 +133,185 @@ func TestExpired(t *testing.T) { s = &SessionState{} assert.Equal(t, false, s.IsExpired()) } + +type testCase struct { + SessionState + Encoded string + Cipher *cookie.Cipher + Error bool +} + +// TestEncodeSessionState tests EncodeSessionState with the test vector +// +// Currently only tests without cipher here because we have no way to mock +// the random generator used in EncodeSessionState. +func TestEncodeSessionState(t *testing.T) { + e := time.Now().Add(time.Duration(1) * time.Hour) + + testCases := []testCase{ + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + }, + } + + for i, tc := range testCases { + encoded, err := tc.EncodeSessionState(tc.Cipher) + t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) + if tc.Error { + assert.Error(t, err) + assert.Empty(t, encoded) + continue + } + assert.NoError(t, err) + assert.JSONEq(t, tc.Encoded, encoded) + } +} + +// TestDecodeSessionState tests DecodeSessionState with the test vector +func TestDecodeSessionState(t *testing.T) { + e := time.Now().Add(time.Duration(1) * time.Hour) + eJSON, _ := e.MarshalJSON() + eString := string(eJSON) + eUnix := e.Unix() + + c, err := cookie.NewCipher([]byte(secret)) + assert.NoError(t, err) + + testCases := []testCase{ + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "user", + }, + Encoded: `{"Email":"user@domain.com"}`, + }, + { + SessionState: SessionState{ + User: "just-user", + }, + Encoded: `{"User":"just-user"}`, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + Cipher: c, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + Cipher: c, + }, + { + Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, + Cipher: c, + Error: true, + }, + { + Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, + Cipher: c, + Error: true, + }, + { + SessionState: SessionState{ + User: "just-user", + Email: "user@domain.com", + }, + Encoded: "email:user@domain.com user:just-user", + }, + { + Encoded: "email:user@domain.com user:just-user||||", + Error: true, + }, + { + Encoded: "email:user@domain.com user:just-user", + Cipher: c, + Error: true, + }, + { + Encoded: "email:user@domain.com user:just-user|||99999999999999999999|", + Cipher: c, + Error: true, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix), + Cipher: c, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix), + Cipher: c, + }, + } + + for i, tc := range testCases { + ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) + t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) + if tc.Error { + assert.Error(t, err) + assert.Nil(t, ss) + continue + } + assert.NoError(t, err) + if assert.NotNil(t, ss) { + assert.Equal(t, tc.User, ss.User) + assert.Equal(t, tc.Email, ss.Email) + assert.Equal(t, tc.AccessToken, ss.AccessToken) + assert.Equal(t, tc.RefreshToken, ss.RefreshToken) + assert.Equal(t, tc.IDToken, ss.IDToken) + assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + } + } +}