diff --git a/CHANGELOG.md b/CHANGELOG.md index eb798f5..6b34f7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ ## Changes since v3.2.0 +- [#148](https://github.com/pusher/outh2_proxy/pull/148) Implement SessionStore interface within proxy (@JoelSpeed) - [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) - Allows for multiple different session storage implementations including client and server side - Adds tests suite for interface to ensure consistency across implementations diff --git a/oauthproxy.go b/oauthproxy.go index 02d9ac1..389b2a9 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -16,7 +16,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/logger" - "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/providers" "github.com/yhat/wsutil" ) @@ -29,10 +29,6 @@ const ( httpScheme = "http" httpsScheme = "https" - // Cookies are limited to 4kb including the length of the cookie name, - // the cookie name can be up to 256 bytes - maxCookieLength = 3840 - applicationJSON = "application/json" ) @@ -75,6 +71,7 @@ type OAuthProxy struct { redirectURL *url.URL // the url to receive requests at whitelistDomains []string provider providers.Provider + sessionStore sessionsapi.SessionStore ProxyPrefix string SignInMessage string HtpasswdFile *HtpasswdFile @@ -88,7 +85,6 @@ type OAuthProxy struct { PassAccessToken bool SetAuthorization bool PassAuthorization bool - CookieCipher *cookie.Cipher skipAuthRegex []string skipAuthPreflight bool compiledRegex []*regexp.Regexp @@ -218,15 +214,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh) - var cipher *cookie.Cipher - if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) { - var err error - cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) - if err != nil { - logger.Fatal("cookie-secret error: ", err) - } - } - return &OAuthProxy{ CookieName: opts.CookieName, CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), @@ -249,6 +236,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { ProxyPrefix: opts.ProxyPrefix, provider: opts.provider, + sessionStore: opts.sessionStore, serveMux: serveMux, redirectURL: redirectURL, whitelistDomains: opts.WhitelistDomains, @@ -263,7 +251,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { SetAuthorization: opts.SetAuthorization, PassAuthorization: opts.PassAuthorization, SkipProviderButton: opts.SkipProviderButton, - CookieCipher: cipher, templates: loadTemplates(opts.CustomTemplatesDir), Footer: opts.Footer, } @@ -293,7 +280,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { +func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } @@ -316,104 +303,6 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, er return } -// MakeSessionCookie creates an http.Cookie containing the authenticated user's -// authentication details -func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { - if value != "" { - value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) - } - c := p.makeCookie(req, p.CookieName, value, expiration, now) - if len(c.Value) > 4096-len(p.CookieName) { - return splitCookie(c) - } - return []*http.Cookie{c} -} - -func copyCookie(c *http.Cookie) *http.Cookie { - return &http.Cookie{ - Name: c.Name, - Value: c.Value, - Path: c.Path, - Domain: c.Domain, - Expires: c.Expires, - RawExpires: c.RawExpires, - MaxAge: c.MaxAge, - Secure: c.Secure, - HttpOnly: c.HttpOnly, - Raw: c.Raw, - Unparsed: c.Unparsed, - } -} - -// splitCookie reads the full cookie generated to store the session and splits -// it into a slice of cookies which fit within the 4kb cookie limit indexing -// the cookies from 0 -func splitCookie(c *http.Cookie) []*http.Cookie { - if len(c.Value) < maxCookieLength { - return []*http.Cookie{c} - } - cookies := []*http.Cookie{} - valueBytes := []byte(c.Value) - count := 0 - for len(valueBytes) > 0 { - new := copyCookie(c) - new.Name = fmt.Sprintf("%s_%d", c.Name, count) - count++ - if len(valueBytes) < maxCookieLength { - new.Value = string(valueBytes) - valueBytes = []byte{} - } else { - newValue := valueBytes[:maxCookieLength] - valueBytes = valueBytes[maxCookieLength:] - new.Value = string(newValue) - } - cookies = append(cookies, new) - } - return cookies -} - -// joinCookies takes a slice of cookies from the request and reconstructs the -// full session cookie -func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { - if len(cookies) == 0 { - return nil, fmt.Errorf("list of cookies must be > 0") - } - if len(cookies) == 1 { - return cookies[0], nil - } - c := copyCookie(cookies[0]) - for i := 1; i < len(cookies); i++ { - c.Value += cookies[i].Value - } - c.Name = strings.TrimRight(c.Name, "_0") - return c, nil -} - -// loadCookie retreieves the sessions state cookie from the http request. -// If a single cookie is present this will be returned, otherwise it attempts -// to reconstruct a cookie split up by splitCookie -func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { - c, err := req.Cookie(cookieName) - if err == nil { - return c, nil - } - cookies := []*http.Cookie{} - err = nil - count := 0 - for err == nil { - var c *http.Cookie - c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) - if err == nil { - cookies = append(cookies, c) - count++ - } - } - if len(cookies) == 0 { - return nil, fmt.Errorf("Could not find cookie %s", cookieName) - } - return joinCookies(cookies) -} - // MakeCSRFCookie creates a cookie for CSRF func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) @@ -454,66 +343,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va // ClearSessionCookie creates a cookie to unset the user's authentication cookie // stored in the user's session -func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { - var cookies []*http.Cookie - - // matches CookieName, CookieName_ - var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName)) - - for _, c := range req.Cookies() { - if cookieNameRegex.MatchString(c.Name) { - clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) - - http.SetCookie(rw, clearCookie) - cookies = append(cookies, clearCookie) - } - } - - // ugly hack because default domain changed - if p.CookieDomain == "" && len(cookies) > 0 { - clr2 := *cookies[0] - clr2.Domain = req.Host - http.SetCookie(rw, &clr2) - } -} - -// SetSessionCookie adds the user's session cookie to the response -func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { - for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) { - http.SetCookie(rw, c) - } +func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { + return p.sessionStore.Clear(rw, req) } // LoadCookiedSession reads the user's authentication details from the request -func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { - var age time.Duration - c, err := loadCookie(req, p.CookieName) - if err != nil { - // always http.ErrNoCookie - return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) - } - val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) - if !ok { - return nil, age, errors.New("Cookie Signature not valid") - } - - session, err := p.provider.SessionFromCookie(val, p.CookieCipher) - if err != nil { - return nil, age, err - } - - age = time.Now().Truncate(time.Second).Sub(timestamp) - return session, age, nil +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) { + return p.sessionStore.Load(req) } // SaveSession creates a new session cookie value and sets this on the response -func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { - value, err := p.provider.CookieForSession(s, p.CookieCipher) - if err != nil { - return err - } - p.SetSessionCookie(rw, req, value) - return nil +func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { + return p.sessionStore.Save(rw, req, s) } // RobotsTxt disallows scraping pages from the OAuthProxy @@ -694,7 +535,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { user, ok := p.ManualSignIn(rw, req) if ok { - session := &sessions.SessionState{User: user} + session := &sessionsapi.SessionState{User: user} p.SaveSession(rw, req, session) http.Redirect(rw, req, redirect, 302) } else { @@ -833,12 +674,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int var saveSession, clearSession, revalidated bool remoteAddr := getRemoteAddr(req) - session, sessionAge, err := p.LoadCookiedSession(req) + session, err := p.LoadCookiedSession(req) if err != nil { logger.Printf("Error loading cookied session: %s", err) } - if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh) + if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { + logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) saveSession = true } @@ -945,7 +786,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int // CheckBasicAuth checks the requests Authorization header for basic auth // credentials and authenticates these against the proxies HtpasswdFile -func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { +func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { if p.HtpasswdFile == nil { return nil, nil } @@ -967,7 +808,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, } if p.HtpasswdFile.Validate(pair[0], pair[1]) { logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") - return &sessions.SessionState{User: pair[0]}, nil + return &sessionsapi.SessionState{User: pair[0]}, nil } logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") return nil, nil diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 914e99f..1d09bbb 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -17,6 +17,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -600,10 +601,15 @@ type ProcessCookieTestOpts struct { providerValidateCookieResponse bool } -func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { +type OptionsModifier func(*Options) + +func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { var pcTest ProcessCookieTest pcTest.opts = NewOptions() + for _, modifier := range modifiers { + modifier(pcTest.opts) + } pcTest.opts.ClientID = "bazquux" pcTest.opts.ClientSecret = "xyzzyplugh" pcTest.opts.CookieSecret = "0123456789abcdefabcd" @@ -634,32 +640,34 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { }) } -func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { - return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) +func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { + return NewProcessCookieTest(ProcessCookieTestOpts{ + providerValidateCookieResponse: true, + }, modifiers...) } -func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { - value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) +func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error { + err := p.proxy.SaveSession(p.rw, p.req, s) if err != nil { return err } - for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { - p.req.AddCookie(c) + for _, cookie := range p.rw.Result().Cookies() { + p.req.AddCookie(cookie) } return nil } -func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { +func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) { return p.proxy.LoadCookiedSession(p.req) } func TestLoadCookiedSession(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, time.Now()) + startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()} + pcTest.SaveSession(startSession) - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "john.doe@example.com", session.User) @@ -669,7 +677,7 @@ func TestLoadCookiedSession(t *testing.T) { func TestProcessCookieNoCookieError(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) if session != nil { t.Errorf("expected nil session. got %#v", session) @@ -677,29 +685,31 @@ func TestProcessCookieNoCookieError(t *testing.T) { } func TestProcessCookieRefreshNotSet(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() - pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour + pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { + opts.CookieExpire = time.Duration(23) * time.Hour + }) reference := time.Now().Add(time.Duration(-2) * time.Hour) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, reference) + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + pcTest.SaveSession(startSession) - session, age, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) - if age < time.Duration(-2)*time.Hour { - t.Errorf("cookie too young %v", age) + if session.Age() < time.Duration(-2)*time.Hour { + t.Errorf("cookie too young %v", session.Age()) } assert.Equal(t, startSession.Email, session.Email) } func TestProcessCookieFailIfCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() - pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour + pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { + opts.CookieExpire = time.Duration(24) * time.Hour + }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, reference) + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + pcTest.SaveSession(startSession) - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) @@ -707,22 +717,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() - pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour + pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { + opts.CookieExpire = time.Duration(24) * time.Hour + }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, reference) + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + pcTest.SaveSession(startSession) pcTest.proxy.CookieRefresh = time.Hour - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) } } -func NewAuthOnlyEndpointTest() *ProcessCookieTest { - pcTest := NewProcessCookieTestWithDefaults() +func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { + pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) return pcTest @@ -731,8 +742,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { func TestAuthOnlyEndpointAccepted(t *testing.T) { test := NewAuthOnlyEndpointTest() startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - test.SaveSession(startSession, time.Now()) + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + test.SaveSession(startSession) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusAccepted, test.rw.Code) @@ -750,12 +761,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { - test := NewAuthOnlyEndpointTest() - test.proxy.CookieExpire = time.Duration(24) * time.Hour + test := NewAuthOnlyEndpointTest(func(opts *Options) { + opts.CookieExpire = time.Duration(24) * time.Hour + }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - test.SaveSession(startSession, reference) + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + test.SaveSession(startSession) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -766,8 +778,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test := NewAuthOnlyEndpointTest() startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - test.SaveSession(startSession, time.Now()) + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + test.SaveSession(startSession) test.validateUser = false test.proxy.ServeHTTP(test.rw, test.req) @@ -797,8 +809,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.opts.ProxyPrefix+"/auth", nil) startSession := &sessions.SessionState{ - User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} - pcTest.SaveSession(startSession, time.Now()) + User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} + pcTest.SaveSession(startSession) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -930,11 +942,11 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { state := &sessions.SessionState{ Email: "mbland@acm.org", AccessToken: "my_access_token"} - value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) + err = proxy.SaveSession(st.rw, req, state) if err != nil { panic(err) } - for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) { + for _, c := range st.rw.Result().Cookies() { req.AddCookie(c) } // This is used by the upstream to validate the signature. @@ -1068,7 +1080,12 @@ func TestAjaxForbiddendRequest(t *testing.T) { } func TestClearSplitCookie(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + opts := NewOptions() + opts.CookieName = "oauth2" + opts.CookieDomain = "abc" + store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) + assert.Equal(t, err, nil) + p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1092,7 +1109,12 @@ func TestClearSplitCookie(t *testing.T) { } func TestClearSingleCookie(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + opts := NewOptions() + opts.CookieName = "oauth2" + opts.CookieDomain = "abc" + store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions) + assert.Equal(t, err, nil) + p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) diff --git a/options.go b/options.go index 3639134..0460bce 100644 --- a/options.go +++ b/options.go @@ -17,8 +17,11 @@ import ( oidc "github.com/coreos/go-oidc" "github.com/dgrijalva/jwt-go" "github.com/mbland/hmacauth" + "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/pkg/apis/options" + sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions" "github.com/pusher/oauth2_proxy/providers" "gopkg.in/natefinch/lumberjack.v2" ) @@ -111,6 +114,7 @@ type Options struct { proxyURLs []*url.URL CompiledRegex []*regexp.Regexp provider providers.Provider + sessionStore sessionsapi.SessionStore signatureData *SignatureData oidcVerifier *oidc.IDTokenVerifier } @@ -136,6 +140,9 @@ func NewOptions() *Options { CookieExpire: time.Duration(168) * time.Hour, CookieRefresh: time.Duration(0), }, + SessionOptions: options.SessionOptions{ + Type: "cookie", + }, SetXAuthRequest: false, SkipAuthPreflight: false, PassBasicAuth: true, @@ -261,7 +268,8 @@ func (o *Options) Validate() error { } msgs = parseProviderInfo(o, msgs) - if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { + var cipher *cookie.Cipher + if o.PassAccessToken || o.SetAuthorization || o.PassAuthorization || (o.CookieRefresh != time.Duration(0)) { validCookieSecretSize := false for _, i := range []int{16, 24, 32} { if len(secretBytes(o.CookieSecret)) == i { @@ -283,9 +291,23 @@ func (o *Options) Validate() error { "pass_access_token == true or "+ "cookie_refresh != 0, but is %d bytes.%s", len(secretBytes(o.CookieSecret)), suffix)) + } else { + var err error + cipher, err = cookie.NewCipher(secretBytes(o.CookieSecret)) + if err != nil { + msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err)) + } } } + o.SessionOptions.Cipher = cipher + sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err)) + } else { + o.sessionStore = sessionStore + } + if o.CookieRefresh >= o.CookieExpire { msgs = append(msgs, fmt.Sprintf( "cookie_refresh (%s) must be less than "+ diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go index 56fd27a..7d33f14 100644 --- a/pkg/apis/options/sessions.go +++ b/pkg/apis/options/sessions.go @@ -1,8 +1,13 @@ package options +import ( + "github.com/pusher/oauth2_proxy/cookie" +) + // SessionOptions contains configuration options for the SessionStore providers. type SessionOptions struct { - Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` + Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` + Cipher *cookie.Cipher CookieStoreOptions } diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index f6efefd..e085976 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -14,6 +14,7 @@ import ( type SessionState struct { AccessToken string `json:",omitempty"` IDToken string `json:",omitempty"` + CreatedAt time.Time `json:"-"` ExpiresOn time.Time `json:"-"` RefreshToken string `json:",omitempty"` Email string `json:",omitempty"` @@ -23,6 +24,7 @@ type SessionState struct { // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value type SessionStateJSON struct { *SessionState + CreatedAt *time.Time `json:",omitempty"` ExpiresOn *time.Time `json:",omitempty"` } @@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool { return false } +// Age returns the age of a session +func (s *SessionState) Age() time.Duration { + if !s.CreatedAt.IsZero() { + return time.Now().Truncate(time.Second).Sub(s.CreatedAt) + } + return 0 +} + // String constructs a summary of the session state func (s *SessionState) String() string { o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) @@ -43,6 +53,9 @@ func (s *SessionState) String() string { if s.IDToken != "" { o += " id_token:true" } + if !s.CreatedAt.IsZero() { + o += fmt.Sprintf(" created:%s", s.CreatedAt) + } if !s.ExpiresOn.IsZero() { o += fmt.Sprintf(" expires:%s", s.ExpiresOn) } @@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { } // Embed SessionState and ExpiresOn pointer into SessionStateJSON ssj := &SessionStateJSON{SessionState: &ss} + if !ss.CreatedAt.IsZero() { + ssj.CreatedAt = &ss.CreatedAt + } if !ss.ExpiresOn.IsZero() { ssj.ExpiresOn = &ss.ExpiresOn } @@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { var ss *SessionState err := json.Unmarshal([]byte(v), &ssj) if err == nil && ssj.SessionState != nil { - // Extract SessionState and ExpiresOn value from SessionStateJSON + // Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON ss = ssj.SessionState + if ssj.CreatedAt != nil { + ss.CreatedAt = *ssj.CreatedAt + } if ssj.ExpiresOn != nil { ss.ExpiresOn = *ssj.ExpiresOn } diff --git a/pkg/apis/sessions/session_state_test.go b/pkg/apis/sessions/session_state_test.go index 83b21a4..a48344e 100644 --- a/pkg/apis/sessions/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) { Email: "user@domain.com", AccessToken: "token1234", IDToken: "rawtoken1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.IDToken, ss.IDToken) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) @@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, "user@domain.com", ss.User) assert.NotEqual(t, s.Email, ss.Email) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.IDToken, ss.IDToken) @@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { User: "just-user", Email: "user@domain.com", AccessToken: "token1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) @@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, nil, err) assert.NotEqual(t, s.User, ss.User) assert.NotEqual(t, s.Email, ss.Email) + assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) @@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { s := &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { User: "just-user", Email: "user@domain.com", AccessToken: "token1234", + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } @@ -147,6 +155,7 @@ type testCase struct { // Currently only tests without cipher here because we have no way to mock // the random generator used in EncodeSessionState. func TestEncodeSessionState(t *testing.T) { + c := time.Now() e := time.Now().Add(time.Duration(1) * time.Hour) testCases := []testCase{ @@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) { User: "just-user", AccessToken: "token1234", IDToken: "rawtoken1234", + CreatedAt: c, ExpiresOn: e, RefreshToken: "refresh4321", }, @@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) { // TestDecodeSessionState testssessions.DecodeSessionState with the test vector func TestDecodeSessionState(t *testing.T) { + created := time.Now() + createdJSON, _ := created.MarshalJSON() + createdString := string(createdJSON) e := time.Now().Add(time.Duration(1) * time.Hour) eJSON, _ := e.MarshalJSON() eString := string(eJSON) @@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) { Email: "user@domain.com", User: "just-user", }, - Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), }, { SessionState: sessions.SessionState{ @@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) { User: "just-user", AccessToken: "token1234", IDToken: "rawtoken1234", + CreatedAt: created, ExpiresOn: e, RefreshToken: "refresh4321", }, - Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString), Cipher: c, }, { @@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) { } } } + +func TestSessionStateAge(t *testing.T) { + ss := &sessions.SessionState{} + + // Created at unset so should be 0 + assert.Equal(t, time.Duration(0), ss.Age()) + + // Set CreatedAt to 1 hour ago + ss.CreatedAt = time.Now().Add(-1 * time.Hour) + assert.Equal(t, time.Hour, ss.Age().Round(time.Minute)) +} diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go index 936f08e..08e6a9b 100644 --- a/pkg/cookies/cookies.go +++ b/pkg/cookies/cookies.go @@ -7,6 +7,7 @@ import ( "time" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/options" ) // MakeCookie constructs a cookie from the given parameters, @@ -32,3 +33,9 @@ func MakeCookie(req *http.Request, name string, value string, path string, domai Expires: now.Add(expiration), } } + +// MakeCookieFromOptions constructs a cookie based on the givemn *options.CookieOptions, +// value and creation time +func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.CookieOptions, expiration time.Duration, now time.Time) *http.Cookie { + return MakeCookie(req, name, value, opts.CookiePath, opts.CookieDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now) +} diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 14c0b71..c40dd23 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -27,36 +27,33 @@ var _ sessions.SessionStore = &SessionStore{} // SessionStore is an implementation of the sessions.SessionStore // interface that stores sessions in client side cookies type SessionStore struct { - CookieCipher *cookie.Cipher - CookieDomain string - CookieExpire time.Duration - CookieHTTPOnly bool - CookieName string - CookiePath string - CookieSecret string - CookieSecure bool + CookieOptions *options.CookieOptions + CookieCipher *cookie.Cipher } // Save takes a sessions.SessionState and stores the information from it // within Cookies set on the HTTP response writer func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { + if ss.CreatedAt.IsZero() { + ss.CreatedAt = time.Now() + } value, err := utils.CookieForSession(ss, s.CookieCipher) if err != nil { return err } - s.setSessionCookie(rw, req, value) + s.setSessionCookie(rw, req, value, ss.CreatedAt) return nil } // Load reads sessions.SessionState information from Cookies within the // HTTP request object func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { - c, err := loadCookie(req, s.CookieName) + c, err := loadCookie(req, s.CookieOptions.CookieName) if err != nil { // always http.ErrNoCookie - return nil, fmt.Errorf("Cookie %q not present", s.CookieName) + return nil, fmt.Errorf("Cookie %q not present", s.CookieOptions.CookieName) } - val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) + val, _, ok := cookie.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire) if !ok { return nil, errors.New("Cookie Signature not valid") } @@ -74,11 +71,11 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { var cookies []*http.Cookie // matches CookieName, CookieName_ - var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName)) + var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName)) for _, c := range req.Cookies() { if cookieNameRegex.MatchString(c.Name) { - clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) + clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) http.SetCookie(rw, clearCookie) cookies = append(cookies, clearCookie) @@ -89,60 +86,42 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { } // setSessionCookie adds the user's session cookie to the response -func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { - for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { +func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { + for _, c := range s.makeSessionCookie(req, val, created) { http.SetCookie(rw, c) } } // makeSessionCookie creates an http.Cookie containing the authenticated user's // authentication details -func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { +func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie { if value != "" { - value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) + value = cookie.SignedValue(s.CookieOptions.CookieSecret, s.CookieOptions.CookieName, value, now) } - c := s.makeCookie(req, s.CookieName, value, expiration) - if len(c.Value) > 4096-len(s.CookieName) { + c := s.makeCookie(req, s.CookieOptions.CookieName, value, s.CookieOptions.CookieExpire, now) + if len(c.Value) > 4096-len(s.CookieOptions.CookieName) { return splitCookie(c) } return []*http.Cookie{c} } -func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { - return cookies.MakeCookie( +func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { + return cookies.MakeCookieFromOptions( req, name, value, - s.CookiePath, - s.CookieDomain, - s.CookieHTTPOnly, - s.CookieSecure, + s.CookieOptions, expiration, - time.Now(), + now, ) } // NewCookieSessionStore initialises a new instance of the SessionStore from // the configuration given -func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { - var cipher *cookie.Cipher - if len(cookieOpts.CookieSecret) > 0 { - var err error - cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) - if err != nil { - return nil, fmt.Errorf("unable to create cipher: %v", err) - } - } - +func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { return &SessionStore{ - CookieCipher: cipher, - CookieDomain: cookieOpts.CookieDomain, - CookieExpire: cookieOpts.CookieExpire, - CookieHTTPOnly: cookieOpts.CookieHTTPOnly, - CookieName: cookieOpts.CookieName, - CookiePath: cookieOpts.CookiePath, - CookieSecret: cookieOpts.CookieSecret, - CookieSecure: cookieOpts.CookieSecure, + CookieCipher: opts.Cipher, + CookieOptions: cookieOpts, }, nil } diff --git a/pkg/sessions/session_store.go b/pkg/sessions/session_store.go index cc074c7..ec84dd7 100644 --- a/pkg/sessions/session_store.go +++ b/pkg/sessions/session_store.go @@ -12,7 +12,7 @@ import ( func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { switch opts.Type { case options.CookieSessionStoreType: - return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) + return cookie.NewCookieSessionStore(opts, cookieOpts) default: return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) } diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 0ceea66..b407d58 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -5,16 +5,20 @@ import ( "encoding/base64" "net/http" "net/http/httptest" + "strconv" + "strings" "testing" "time" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/pkg/apis/options" sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/pkg/cookies" "github.com/pusher/oauth2_proxy/pkg/sessions" - "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" + sessionscookie "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" + "github.com/pusher/oauth2_proxy/pkg/sessions/utils" ) func TestSessionStore(t *testing.T) { @@ -72,6 +76,16 @@ var _ = Describe("NewSessionStore", func() { } }) + It("have a signature timestamp matching session.CreatedAt", func() { + for _, cookie := range cookies { + if cookie.Value != "" { + parts := strings.Split(cookie.Value, "|") + Expect(parts).To(HaveLen(3)) + Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix())))) + } + } + }) + }) } @@ -86,6 +100,10 @@ var _ = Describe("NewSessionStore", func() { Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) }) + It("Ensures the session CreatedAt is not zero", func() { + Expect(session.CreatedAt.IsZero()).To(BeFalse()) + }) + CheckCookieOptions() }) @@ -138,12 +156,15 @@ var _ = Describe("NewSessionStore", func() { // Can't compare time.Time using Equal() so remove ExpiresOn from sessions l := *loadedSession + l.CreatedAt = time.Time{} l.ExpiresOn = time.Time{} s := *session + s.CreatedAt = time.Time{} s.ExpiresOn = time.Time{} Expect(l).To(Equal(s)) // Compare time.Time separately + Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) } }) @@ -181,12 +202,16 @@ var _ = Describe("NewSessionStore", func() { SessionStoreInterfaceTests() }) - Context("with a cookie-secret set", func() { + Context("with a cipher", func() { BeforeEach(func() { secret := make([]byte, 32) _, err := rand.Read(secret) Expect(err).ToNot(HaveOccurred()) cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) + cipher, err := cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) + Expect(err).ToNot(HaveOccurred()) + Expect(cipher).ToNot(BeNil()) + opts.Cipher = cipher ss, err = sessions.NewSessionStore(opts, cookieOpts) Expect(err).ToNot(HaveOccurred()) @@ -231,7 +256,7 @@ var _ = Describe("NewSessionStore", func() { It("creates a cookie.SessionStore", func() { ss, err := sessions.NewSessionStore(opts, cookieOpts) Expect(err).NotTo(HaveOccurred()) - Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{})) + Expect(ss).To(BeAssignableToTypeOf(&sessionscookie.SessionStore{})) }) Context("the cookie.SessionStore", func() { diff --git a/providers/google.go b/providers/google.go index f79a131..6f29c2c 100644 --- a/providers/google.go +++ b/providers/google.go @@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), RefreshToken: jsonResponse.RefreshToken, Email: c.Email, diff --git a/providers/logingov.go b/providers/logingov.go index 60f4260..95d7aa9 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, + CreatedAt: time.Now(), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), Email: email, } diff --git a/providers/oidc.go b/providers/oidc.go index bacabdf..08ea082 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) s.AccessToken = newSession.AccessToken s.IDToken = newSession.IDToken s.RefreshToken = newSession.RefreshToken + s.CreatedAt = newSession.CreatedAt s.ExpiresOn = newSession.ExpiresOn s.Email = newSession.Email return @@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok AccessToken: token.AccessToken, IDToken: rawIDToken, RefreshToken: token.RefreshToken, + CreatedAt: time.Now(), ExpiresOn: token.Expiry, Email: claims.Email, User: claims.Subject, diff --git a/providers/provider_default.go b/providers/provider_default.go index cd78251..4716014 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -8,6 +8,7 @@ import ( "io/ioutil" "net/http" "net/url" + "time" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" @@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat return } if a := v.Get("access_token"); a != "" { - s = &sessions.SessionState{AccessToken: a} + s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} } else { err = fmt.Errorf("no access token found %s", body) }