// oauth2_proxy/providers/session_state.go
//
// Web viewer header (not part of the program): 199 lines, 4.8 KiB, Go.
// "Raw", "Normal View" and "History" were page controls.

package providers
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/pusher/oauth2_proxy/cookie"
)
// SessionState holds the data kept for the currently authenticated user
// session.
//
// When marshalled directly to JSON, every string field is written under its
// Go field name and left out when empty. ExpiresOn is never written by this
// type (tag "-"); SessionStateJSON carries the expiry instead so that a zero
// time does not appear in the output.
type SessionState struct {
	AccessToken  string    `json:",omitempty"`
	IDToken      string    `json:",omitempty"`
	ExpiresOn    time.Time `json:"-"`
	RefreshToken string    `json:",omitempty"`
	Email        string    `json:",omitempty"`
	User         string    `json:",omitempty"`
}
// SessionStateJSON is the JSON form of a SessionState.
//
// The embedded *SessionState contributes its tagged fields to the top level
// of the JSON object. ExpiresOn is a pointer here so that a nil value is
// dropped by omitempty, which keeps the time.Time zero value out of the
// encoded session.
type SessionStateJSON struct {
	*SessionState
	ExpiresOn *time.Time `json:",omitempty"`
}
// IsExpired reports whether the session has expired.
//
// A session whose ExpiresOn is the zero time has no recorded expiry and is
// never reported as expired. Otherwise the session is expired once ExpiresOn
// lies before the current time.
func (s *SessionState) IsExpired() bool {
	return !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now())
}
// String constructs a summary of the session state
func (s *SessionState) String() string {
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
if s.AccessToken != "" {
o += " token:true"
}
2018-01-27 10:53:17 +00:00
if s.IDToken != "" {
o += " id_token:true"
}
if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
}
if s.RefreshToken != "" {
o += " refresh_token:true"
}
return o + "}"
}
// EncodeSessionState returns string representation of the current session
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
var ss SessionState
if c == nil {
// Store only Email and User when cipher is unavailable
ss.Email = s.Email
ss.User = s.User
} else {
ss = *s
var err error
if ss.AccessToken != "" {
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
if err != nil {
return "", err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Encrypt(ss.IDToken)
if err != nil {
return "", err
}
2018-01-27 10:53:17 +00:00
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
if err != nil {
return "", err
}
}
}
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss}
if !ss.ExpiresOn.IsZero() {
ssj.ExpiresOn = &ss.ExpiresOn
}
b, err := json.Marshal(ssj)
return string(b), err
}
// legacyDecodeSessionStatePlain decodes the older plain session state string,
// which consists of exactly two space-separated chunks: the email followed by
// the user. A leading "email:" on the first chunk and a leading "user:" on
// the second are stripped when present. Any other number of chunks yields an
// error.
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
	parts := strings.Split(v, " ")
	if n := len(parts); n != 2 {
		return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", n)
	}

	state := &SessionState{}
	state.Email = strings.TrimPrefix(parts[0], "email:")
	state.User = strings.TrimPrefix(parts[1], "user:")
	return state, nil
}
// legacyDecodeSessionState attempts to decode the session state string
// generated by v3.1.0 or older
func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
chunks := strings.Split(v, "|")
if c == nil {
if len(chunks) != 1 {
return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
}
return legacyDecodeSessionStatePlain(chunks[0])
}
if len(chunks) != 4 && len(chunks) != 5 {
return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
}
i := 0
ss, err := legacyDecodeSessionStatePlain(chunks[i])
if err != nil {
return nil, err
}
i++
ss.AccessToken = chunks[i]
if len(chunks) == 5 {
// SessionState with IDToken in v3.1.0
i++
ss.IDToken = chunks[i]
}
i++
ts, err := strconv.Atoi(chunks[i])
if err != nil {
return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
2018-01-27 10:53:17 +00:00
}
ss.ExpiresOn = time.Unix(int64(ts), 0)
i++
ss.RefreshToken = chunks[i]
2018-01-27 10:53:17 +00:00
return ss, nil
}
// DecodeSessionState decodes the session cookie string into a SessionState.
//
// The value is first parsed as JSON. If that fails, or yields no embedded
// session fields, it is handed to legacyDecodeSessionState instead, and that
// function's error (if any) is returned.
//
// When c is nil only Email and User survive; tokens and expiry are dropped.
// When c is non-nil every non-empty token is replaced by the result of
// c.Decrypt, and the first decryption error aborts decoding.
//
// If User is empty after decoding, it is set to the part of Email before the
// first "@".
func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
	var (
		wrapper SessionStateJSON
		ss      *SessionState
	)

	if jsonErr := json.Unmarshal([]byte(v), &wrapper); jsonErr == nil && wrapper.SessionState != nil {
		// JSON form: lift the expiry from the wrapper back onto the session.
		ss = wrapper.SessionState
		if wrapper.ExpiresOn != nil {
			ss.ExpiresOn = *wrapper.ExpiresOn
		}
	} else {
		// Not usable as JSON: fall back to the legacy string format.
		legacy, err := legacyDecodeSessionState(v, c)
		if err != nil {
			return nil, err
		}
		ss = legacy
	}

	if c == nil {
		// Without a cipher only Email and User are loaded.
		ss = &SessionState{Email: ss.Email, User: ss.User}
	} else {
		// Same order as the fields are declared: access, ID, refresh.
		for _, token := range []*string{&ss.AccessToken, &ss.IDToken, &ss.RefreshToken} {
			if *token == "" {
				continue
			}
			plain, err := c.Decrypt(*token)
			if err != nil {
				return nil, err
			}
			*token = plain
		}
	}

	if ss.User == "" {
		ss.User = strings.Split(ss.Email, "@")[0]
	}
	return ss, nil
}