244 lines
5.9 KiB
Go
244 lines
5.9 KiB
Go
package sessions
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pusher/oauth2_proxy/cookie"
|
|
)
|
|
|
|
// SessionState holds the tokens and identity details of the currently
// authenticated user session.
//
// The two timestamps are tagged `json:"-"`, so encoding/json skips them when a
// SessionState is marshalled directly. Every string field is left out of the
// JSON output when it is empty.
type SessionState struct {
	// AccessToken is the session's access token.
	AccessToken string `json:",omitempty"`
	// IDToken is the session's ID token.
	IDToken string `json:",omitempty"`
	// CreatedAt is when the session was created; the zero value means unknown.
	CreatedAt time.Time `json:"-"`
	// ExpiresOn is when the session expires; the zero value means no expiry.
	ExpiresOn time.Time `json:"-"`
	// RefreshToken is the session's refresh token.
	RefreshToken string `json:",omitempty"`
	// Email is the e-mail address of the session's user.
	Email string `json:",omitempty"`
	// User is the user name of the session's user.
	User string `json:",omitempty"`
}
|
|
|
|
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
|
type SessionStateJSON struct {
|
|
*SessionState
|
|
CreatedAt *time.Time `json:",omitempty"`
|
|
ExpiresOn *time.Time `json:",omitempty"`
|
|
}
|
|
|
|
// IsExpired checks whether the session has expired
|
|
func (s *SessionState) IsExpired() bool {
|
|
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// Age returns the age of a session
|
|
func (s *SessionState) Age() time.Duration {
|
|
if !s.CreatedAt.IsZero() {
|
|
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// String constructs a summary of the session state
|
|
func (s *SessionState) String() string {
|
|
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
|
|
if s.AccessToken != "" {
|
|
o += " token:true"
|
|
}
|
|
if s.IDToken != "" {
|
|
o += " id_token:true"
|
|
}
|
|
if !s.CreatedAt.IsZero() {
|
|
o += fmt.Sprintf(" created:%s", s.CreatedAt)
|
|
}
|
|
if !s.ExpiresOn.IsZero() {
|
|
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
|
}
|
|
if s.RefreshToken != "" {
|
|
o += " refresh_token:true"
|
|
}
|
|
return o + "}"
|
|
}
|
|
|
|
// EncodeSessionState returns string representation of the current session
|
|
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
|
var ss SessionState
|
|
if c == nil {
|
|
// Store only Email and User when cipher is unavailable
|
|
ss.Email = s.Email
|
|
ss.User = s.User
|
|
} else {
|
|
ss = *s
|
|
var err error
|
|
if ss.Email != "" {
|
|
ss.Email, err = c.Encrypt(ss.Email)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if ss.User != "" {
|
|
ss.User, err = c.Encrypt(ss.User)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if ss.AccessToken != "" {
|
|
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if ss.IDToken != "" {
|
|
ss.IDToken, err = c.Encrypt(ss.IDToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
if ss.RefreshToken != "" {
|
|
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
}
|
|
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
|
ssj := &SessionStateJSON{SessionState: &ss}
|
|
if !ss.CreatedAt.IsZero() {
|
|
ssj.CreatedAt = &ss.CreatedAt
|
|
}
|
|
if !ss.ExpiresOn.IsZero() {
|
|
ssj.ExpiresOn = &ss.ExpiresOn
|
|
}
|
|
b, err := json.Marshal(ssj)
|
|
return string(b), err
|
|
}
|
|
|
|
// legacyDecodeSessionStatePlain decodes older plain session state string
|
|
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
|
|
chunks := strings.Split(v, " ")
|
|
if len(chunks) != 2 {
|
|
return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks))
|
|
}
|
|
|
|
user := strings.TrimPrefix(chunks[1], "user:")
|
|
email := strings.TrimPrefix(chunks[0], "email:")
|
|
|
|
return &SessionState{User: user, Email: email}, nil
|
|
}
|
|
|
|
// legacyDecodeSessionState attempts to decode the session state string
|
|
// generated by v3.1.0 or older
|
|
func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
|
chunks := strings.Split(v, "|")
|
|
|
|
if c == nil {
|
|
if len(chunks) != 1 {
|
|
return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
|
|
}
|
|
return legacyDecodeSessionStatePlain(chunks[0])
|
|
}
|
|
|
|
if len(chunks) != 4 && len(chunks) != 5 {
|
|
return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
|
|
}
|
|
|
|
i := 0
|
|
ss, err := legacyDecodeSessionStatePlain(chunks[i])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
i++
|
|
ss.AccessToken = chunks[i]
|
|
|
|
if len(chunks) == 5 {
|
|
// SessionState with IDToken in v3.1.0
|
|
i++
|
|
ss.IDToken = chunks[i]
|
|
}
|
|
|
|
i++
|
|
ts, err := strconv.Atoi(chunks[i])
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
|
|
}
|
|
ss.ExpiresOn = time.Unix(int64(ts), 0)
|
|
|
|
i++
|
|
ss.RefreshToken = chunks[i]
|
|
|
|
return ss, nil
|
|
}
|
|
|
|
// DecodeSessionState decodes the session cookie string into a SessionState
|
|
func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
|
var ssj SessionStateJSON
|
|
var ss *SessionState
|
|
err := json.Unmarshal([]byte(v), &ssj)
|
|
if err == nil && ssj.SessionState != nil {
|
|
// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
|
|
ss = ssj.SessionState
|
|
if ssj.CreatedAt != nil {
|
|
ss.CreatedAt = *ssj.CreatedAt
|
|
}
|
|
if ssj.ExpiresOn != nil {
|
|
ss.ExpiresOn = *ssj.ExpiresOn
|
|
}
|
|
} else {
|
|
// Try to decode a legacy string when json.Unmarshal failed
|
|
ss, err = legacyDecodeSessionState(v, c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if c == nil {
|
|
// Load only Email and User when cipher is unavailable
|
|
ss = &SessionState{
|
|
Email: ss.Email,
|
|
User: ss.User,
|
|
}
|
|
} else {
|
|
// Backward compatibility with using unecrypted Email
|
|
if ss.Email != "" {
|
|
decryptedEmail, errEmail := c.Decrypt(ss.Email)
|
|
if errEmail == nil {
|
|
ss.Email = decryptedEmail
|
|
}
|
|
}
|
|
// Backward compatibility with using unecrypted User
|
|
if ss.User != "" {
|
|
decryptedUser, errUser := c.Decrypt(ss.User)
|
|
if errUser == nil {
|
|
ss.User = decryptedUser
|
|
}
|
|
}
|
|
if ss.AccessToken != "" {
|
|
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if ss.IDToken != "" {
|
|
ss.IDToken, err = c.Decrypt(ss.IDToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
if ss.RefreshToken != "" {
|
|
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
}
|
|
if ss.User == "" {
|
|
ss.User = ss.Email
|
|
}
|
|
return ss, nil
|
|
}
|