diff --git a/pkg/sessions/redis/redis_store.go b/pkg/sessions/redis/redis_store.go index c240071..ec55a03 100644 --- a/pkg/sessions/redis/redis_store.go +++ b/pkg/sessions/redis/redis_store.go @@ -17,7 +17,6 @@ import ( "github.com/pusher/oauth2_proxy/pkg/apis/options" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/pkg/cookies" - "github.com/pusher/oauth2_proxy/pkg/sessions/utils" ) // TicketData is a structure representing the ticket used in server session storage @@ -29,46 +28,25 @@ type TicketData struct { // SessionStore is an implementation of the sessions.SessionStore // interface that stores sessions in redis type SessionStore struct { - CookieCipher *cookie.Cipher - CookieDomain string - CookieExpire time.Duration - CookieHTTPOnly bool - CookieName string - CookiePath string - CookieSecret string - CookieSecure bool - Client *redis.Client + CookieCipher *cookie.Cipher + CookieOptions *options.CookieOptions + Client *redis.Client } // NewRedisSessionStore initialises a new instance of the SessionStore from // the configuration given -func NewRedisSessionStore(opts options.RedisStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { - opt, err := redis.ParseURL(opts.RedisConnectionURL) +func NewRedisSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { + opt, err := redis.ParseURL(opts.RedisStoreOptions.RedisConnectionURL) if err != nil { return nil, fmt.Errorf("unable to parse redis url: %s", err) } - var cookieCipher *cookie.Cipher - if len(cookieOpts.CookieSecret) > 0 { - var err error - cookieCipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) - if err != nil { - return nil, fmt.Errorf("unable to create cookieCipher: %v", err) - } - } - client := redis.NewClient(opt) rs := &SessionStore{ - Client: client, - CookieCipher: cookieCipher, - CookieDomain: cookieOpts.CookieDomain, - CookieExpire: cookieOpts.CookieExpire, - CookieHTTPOnly: cookieOpts.CookieHTTPOnly, - CookieName: cookieOpts.CookieName, - CookiePath: cookieOpts.CookiePath, - CookieSecret: cookieOpts.CookieSecret, - CookieSecure: cookieOpts.CookieSecure, + Client: client, + CookieCipher: opts.Cipher, + CookieOptions: cookieOpts, } return rs, nil @@ -79,7 +57,7 @@ func NewRedisSessionStore(opts options.RedisStoreOptions, cookieOpts *options.Co func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { // Old sessions that we are refreshing would have a request cookie // New sessions don't, so we ignore the error. storeValue will check requestCookie - requestCookie, _ := req.Cookie(store.CookieName) + requestCookie, _ := req.Cookie(store.CookieOptions.CookieName) value, err := s.EncodeSessionState(store.CookieCipher) if err != nil { return err @@ -89,15 +67,12 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se return err } - ticketCookie := cookies.MakeCookie( + ticketCookie := cookies.MakeCookieFromOptions( req, - store.CookieName, + store.CookieOptions.CookieName, ticketString, - store.CookiePath, - store.CookieDomain, - store.CookieHTTPOnly, - store.CookieSecure, - store.CookieExpire, + store.CookieOptions, + store.CookieOptions.CookieExpire, time.Now(), ) @@ -108,7 +83,7 @@ func (store *SessionStore) Save(rw http.ResponseWriter, req *http.Request, s *se // Load reads sessions.SessionState information from a ticket // cookie within the HTTP request object func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { - requestCookie, err := req.Cookie(store.CookieName) + requestCookie, err := req.Cookie(store.CookieOptions.CookieName) if err != nil { return nil, fmt.Errorf("error loading session: %s", err) } @@ -122,12 +97,12 @@ func (store *SessionStore) Load(req *http.Request) (*sessions.SessionState, erro // LoadSessionFromString loads the session based on the ticket value func (store *SessionStore) LoadSessionFromString(value string) (*sessions.SessionState, error) { - ticket, err := decodeTicket(store.CookieName, value) + ticket, err := decodeTicket(store.CookieOptions.CookieName, value) if err != nil { return nil, err } - result, err := store.Client.Get(ticket.asHandle(store.CookieName)).Result() + result, err := store.Client.Get(ticket.asHandle(store.CookieOptions.CookieName)).Result() if err != nil { return nil, err } @@ -151,17 +126,14 @@ func (store *SessionStore) LoadSessionFromString(value string) (*sessions.Sessio // Clear clears any saved session information for a given ticket cookie // from redis, and then clears the session func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { - requestCookie, _ := req.Cookie(store.CookieName) + requestCookie, _ := req.Cookie(store.CookieOptions.CookieName) // We go ahead and clear the cookie first, always. - clearCookie := cookies.MakeCookie( + clearCookie := cookies.MakeCookieFromOptions( req, - store.CookieName, + store.CookieOptions.CookieName, "", - store.CookiePath, - store.CookieDomain, - store.CookieHTTPOnly, - store.CookieSecure, + store.CookieOptions, time.Hour*-1, time.Now(), ) @@ -169,9 +141,9 @@ func (store *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro // We only return an error if we had an issue with redis // If there's an issue decoding the ticket, ignore it - ticket, _ := decodeTicket(store.CookieName, requestCookie.Value) + ticket, _ := decodeTicket(store.CookieOptions.CookieName, requestCookie.Value) if ticket != nil { - deleted, err := store.Client.Del(ticket.asHandle(store.CookieName)).Result() + deleted, err := store.Client.Del(ticket.asHandle(store.CookieOptions.CookieName)).Result() fmt.Println("delted %n", deleted) if err != nil { return fmt.Errorf("error clearing cookie from redis: %s", err) @@ -184,7 +156,7 @@ func (store *SessionStore) storeValue(value string, expiresOn time.Time, request var ticket *TicketData if requestCookie != nil { var err error - ticket, err = decodeTicket(store.CookieName, requestCookie.Value) + ticket, err = decodeTicket(store.CookieOptions.CookieName, requestCookie.Value) if err != nil { return "", err } @@ -206,13 +178,13 @@ func (store *SessionStore) storeValue(value string, expiresOn time.Time, request stream := cipher.NewCFBEncrypter(block, ticket.Secret) stream.XORKeyStream(ciphertext, []byte(value)) - handle := ticket.asHandle(store.CookieName) + handle := ticket.asHandle(store.CookieOptions.CookieName) expires := expiresOn.Sub(time.Now()) err = store.Client.Set(handle, ciphertext, expires).Err() if err != nil { return "", err } - return ticket.encodeTicket(store.CookieName), nil + return ticket.encodeTicket(store.CookieOptions.CookieName), nil } func newTicket() (*TicketData, error) { diff --git a/pkg/sessions/session_store.go b/pkg/sessions/session_store.go index e8e229a..17ef21c 100644 --- a/pkg/sessions/session_store.go +++ b/pkg/sessions/session_store.go @@ -15,7 +15,7 @@ func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOpt case options.CookieSessionStoreType: return cookie.NewCookieSessionStore(opts, cookieOpts) case options.RedisSessionStoreType: - return redis.NewRedisSessionStore(opts.RedisStoreOptions, cookieOpts) + return redis.NewRedisSessionStore(opts, cookieOpts) default: return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) }