From 8b3a3853eba73e41584ac48c5493aef1d409f25b Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Mon, 6 May 2019 23:16:01 +0100 Subject: [PATCH] Implement LoadSession in Cookie SessionStore --- pkg/sessions/cookie/session_store.go | 108 +++++++++++++++++++++++++-- pkg/sessions/utils/utils.go | 41 ++++++++++ 2 files changed, 143 insertions(+), 6 deletions(-) create mode 100644 pkg/sessions/utils/utils.go diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 44799f9..2c034b2 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -1,11 +1,16 @@ package cookie import ( + "errors" "fmt" "net/http" + "strings" + "time" + "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/pkg/apis/options" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions/utils" ) // Ensure CookieSessionStore implements the interface @@ -13,28 +18,119 @@ var _ sessions.SessionStore = &SessionStore{} // SessionStore is an implementation of the sessions.SessionStore // interface that stores sessions in client side cookies -type SessionStore struct{} +type SessionStore struct { + CookieCipher *cookie.Cipher + CookieExpire time.Duration + CookieName string + CookieSecret string +} // SaveSession takes a sessions.SessionState and stores the information from it // within Cookies set on the HTTP response writer -func (c *SessionStore) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { +func (s *SessionStore) SaveSession(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { return fmt.Errorf("method not implemented") } // LoadSession reads sessions.SessionState information from Cookies within the // HTTP request object -func (c *SessionStore) LoadSession(req *http.Request) (*sessions.SessionState, error) { - return nil, fmt.Errorf("method not implemented") +func (s *SessionStore) LoadSession(req *http.Request) (*sessions.SessionState, error) { + c, err := loadCookie(req, s.CookieName) + if err != nil { + // always http.ErrNoCookie + return nil, fmt.Errorf("Cookie %q not present", s.CookieName) + } + val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) + if !ok { + return nil, errors.New("Cookie Signature not valid") + } + + session, err := utils.SessionFromCookie(val, s.CookieCipher) + if err != nil { + return nil, err + } + return session, nil } // ClearSession clears any saved session information by writing a cookie to // clear the session -func (c *SessionStore) ClearSession(rw http.ResponseWriter, req *http.Request) error { +func (s *SessionStore) ClearSession(rw http.ResponseWriter, req *http.Request) error { return fmt.Errorf("method not implemented") } // NewCookieSessionStore initialises a new instance of the SessionStore from // the configuration given func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { - return &SessionStore{}, fmt.Errorf("method not implemented") + var cipher *cookie.Cipher + if len(cookieOpts.CookieSecret) > 0 { + var err error + cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) + if err != nil { + return nil, fmt.Errorf("unable to create cipher: %v", err) + } + } + + return &SessionStore{ + CookieCipher: cipher, + CookieExpire: cookieOpts.CookieExpire, + CookieName: cookieOpts.CookieName, + CookieSecret: cookieOpts.CookieSecret, + }, nil +} + +// loadCookie retreieves the sessions state cookie from the http request. +// If a single cookie is present this will be returned, otherwise it attempts +// to reconstruct a cookie split up by splitCookie +func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { + c, err := req.Cookie(cookieName) + if err == nil { + return c, nil + } + cookies := []*http.Cookie{} + err = nil + count := 0 + for err == nil { + var c *http.Cookie + c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) + if err == nil { + cookies = append(cookies, c) + count++ + } + } + if len(cookies) == 0 { + return nil, fmt.Errorf("Could not find cookie %s", cookieName) + } + return joinCookies(cookies) +} + +// joinCookies takes a slice of cookies from the request and reconstructs the +// full session cookie +func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { + if len(cookies) == 0 { + return nil, fmt.Errorf("list of cookies must be > 0") + } + if len(cookies) == 1 { + return cookies[0], nil + } + c := copyCookie(cookies[0]) + for i := 1; i < len(cookies); i++ { + c.Value += cookies[i].Value + } + c.Name = strings.TrimRight(c.Name, "_0") + return c, nil +} + +func copyCookie(c *http.Cookie) *http.Cookie { + return &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: c.Expires, + RawExpires: c.RawExpires, + MaxAge: c.MaxAge, + Secure: c.Secure, + HttpOnly: c.HttpOnly, + Raw: c.Raw, + Unparsed: c.Unparsed, + } } diff --git a/pkg/sessions/utils/utils.go b/pkg/sessions/utils/utils.go new file mode 100644 index 0000000..051e9cc --- /dev/null +++ b/pkg/sessions/utils/utils.go @@ -0,0 +1,41 @@ +package utils + +import ( + "encoding/base64" + + "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" +) + +// CookieForSession serializes a session state for storage in a cookie +func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { + return s.EncodeSessionState(c) +} + +// SessionFromCookie deserializes a session from a cookie value +func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { + return sessions.DecodeSessionState(v, c) +} + +// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary +func SecretBytes(secret string) []byte { + b, err := base64.URLEncoding.DecodeString(addPadding(secret)) + if err == nil { + return []byte(addPadding(string(b))) + } + return []byte(secret) +} + +func addPadding(secret string) string { + padding := len(secret) % 4 + switch padding { + case 1: + return secret + "===" + case 2: + return secret + "==" + case 3: + return secret + "=" + default: + return secret + } +}