From c61f3a1c657dabc0542ea7519e57e2951d963ac7 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Tue, 7 May 2019 16:13:55 +0100 Subject: [PATCH] Use SessionStore for session in proxy --- oauthproxy.go | 57 ++------------ oauthproxy_test.go | 108 +++++++++++++++++---------- pkg/sessions/cookie/session_store.go | 9 ++- pkg/sessions/session_store_test.go | 19 +++++ 4 files changed, 100 insertions(+), 93 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index f7633ac..f5ae591 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -456,27 +456,8 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va // ClearSessionCookie creates a cookie to unset the user's authentication cookie // stored in the user's session -func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { - var cookies []*http.Cookie - - // matches CookieName, CookieName_ - var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName)) - - for _, c := range req.Cookies() { - if cookieNameRegex.MatchString(c.Name) { - clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) - - http.SetCookie(rw, clearCookie) - cookies = append(cookies, clearCookie) - } - } - - // ugly hack because default domain changed - if p.CookieDomain == "" && len(cookies) > 0 { - clr2 := *cookies[0] - clr2.Domain = req.Host - http.SetCookie(rw, &clr2) - } +func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error { + return p.sessionStore.Clear(rw, req) } // SetSessionCookie adds the user's session cookie to the response @@ -487,35 +468,13 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, } // LoadCookiedSession reads the user's authentication details from the request -func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, time.Duration, error) { - var age time.Duration - c, err := loadCookie(req, p.CookieName) - if err != nil { - // always http.ErrNoCookie - return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) - } - val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire) - if !ok { - return nil, age, errors.New("Cookie Signature not valid") - } - - session, err := p.provider.SessionFromCookie(val, p.CookieCipher) - if err != nil { - return nil, age, err - } - - age = time.Now().Truncate(time.Second).Sub(timestamp) - return session, age, nil +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) { + return p.sessionStore.Load(req) } // SaveSession creates a new session cookie value and sets this on the response func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { - value, err := p.provider.CookieForSession(s, p.CookieCipher) - if err != nil { - return err - } - p.SetSessionCookie(rw, req, value) - return nil + return p.sessionStore.Save(rw, req, s) } // RobotsTxt disallows scraping pages from the OAuthProxy @@ -835,12 +794,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int var saveSession, clearSession, revalidated bool remoteAddr := getRemoteAddr(req) - session, sessionAge, err := p.LoadCookiedSession(req) + session, err := p.LoadCookiedSession(req) if err != nil { logger.Printf("Error loading cookied session: %s", err) } - if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { - logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh) + if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { + logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh) saveSession = true } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 914e99f..32837eb 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -17,6 +17,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -600,10 +601,15 @@ type ProcessCookieTestOpts struct { providerValidateCookieResponse bool } -func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { +type OptionsModifier func(*Options) + +func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest { var pcTest ProcessCookieTest pcTest.opts = NewOptions() + for _, modifier := range modifiers { + modifier(pcTest.opts) + } pcTest.opts.ClientID = "bazquux" pcTest.opts.ClientSecret = "xyzzyplugh" pcTest.opts.CookieSecret = "0123456789abcdefabcd" @@ -634,32 +640,38 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { }) } +func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest { + return NewProcessCookieTest(ProcessCookieTestOpts{ + providerValidateCookieResponse: true, + }, modifiers...) +} + func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) } -func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { - value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) +func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error { + err := p.proxy.SaveSession(p.rw, p.req, s) if err != nil { return err } - for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { - p.req.AddCookie(c) + for _, cookie := range p.rw.Result().Cookies() { + p.req.AddCookie(cookie) } return nil } -func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { +func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) { return p.proxy.LoadCookiedSession(p.req) } func TestLoadCookiedSession(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, time.Now()) + startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()} + pcTest.SaveSession(startSession) - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "john.doe@example.com", session.User) @@ -669,7 +681,7 @@ func TestLoadCookiedSession(t *testing.T) { func TestProcessCookieNoCookieError(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) if session != nil { t.Errorf("expected nil session. got %#v", session) @@ -677,29 +689,31 @@ func TestProcessCookieNoCookieError(t *testing.T) { } func TestProcessCookieRefreshNotSet(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() - pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour + pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { + opts.CookieExpire = time.Duration(23) * time.Hour + }) reference := time.Now().Add(time.Duration(-2) * time.Hour) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, reference) + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + pcTest.SaveSession(startSession) - session, age, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) - if age < time.Duration(-2)*time.Hour { - t.Errorf("cookie too young %v", age) + if session.Age() < time.Duration(-2)*time.Hour { + t.Errorf("cookie too young %v", session.Age()) } assert.Equal(t, startSession.Email, session.Email) } func TestProcessCookieFailIfCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() - pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour + pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { + opts.CookieExpire = time.Duration(24) * time.Hour + }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, reference) + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + pcTest.SaveSession(startSession) - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) @@ -707,22 +721,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { - pcTest := NewProcessCookieTestWithDefaults() - pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour + pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) { + opts.CookieExpire = time.Duration(24) * time.Hour + }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pcTest.SaveSession(startSession, reference) + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + pcTest.SaveSession(startSession) pcTest.proxy.CookieRefresh = time.Hour - session, _, err := pcTest.LoadCookiedSession() + session, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) } } -func NewAuthOnlyEndpointTest() *ProcessCookieTest { - pcTest := NewProcessCookieTestWithDefaults() +func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest { + pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...) pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) return pcTest @@ -731,8 +746,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { func TestAuthOnlyEndpointAccepted(t *testing.T) { test := NewAuthOnlyEndpointTest() startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - test.SaveSession(startSession, time.Now()) + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + test.SaveSession(startSession) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusAccepted, test.rw.Code) @@ -750,12 +765,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) { } func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { - test := NewAuthOnlyEndpointTest() - test.proxy.CookieExpire = time.Duration(24) * time.Hour + test := NewAuthOnlyEndpointTest(func(opts *Options) { + opts.CookieExpire = time.Duration(24) * time.Hour + }) reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - test.SaveSession(startSession, reference) + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference} + test.SaveSession(startSession) test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -766,8 +782,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test := NewAuthOnlyEndpointTest() startSession := &sessions.SessionState{ - Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - test.SaveSession(startSession, time.Now()) + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + test.SaveSession(startSession) test.validateUser = false test.proxy.ServeHTTP(test.rw, test.req) @@ -797,8 +813,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.opts.ProxyPrefix+"/auth", nil) startSession := &sessions.SessionState{ - User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} - pcTest.SaveSession(startSession, time.Now()) + User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()} + pcTest.SaveSession(startSession) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) @@ -1068,7 +1084,12 @@ func TestAjaxForbiddendRequest(t *testing.T) { } func TestClearSplitCookie(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + opts := NewOptions() + opts.CookieName = "oauth2" + opts.CookieDomain = "abc" + store, err := cookie.NewCookieSessionStore(opts.SessionOptions.CookieStoreOptions, &opts.CookieOptions) + assert.Equal(t, err, nil) + p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) @@ -1092,7 +1113,12 @@ func TestClearSplitCookie(t *testing.T) { } func TestClearSingleCookie(t *testing.T) { - p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + opts := NewOptions() + opts.CookieName = "oauth2" + opts.CookieDomain = "abc" + store, err := cookie.NewCookieSessionStore(opts.SessionOptions.CookieStoreOptions, &opts.CookieOptions) + assert.Equal(t, err, nil) + p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store} var rw = httptest.NewRecorder() req := httptest.NewRequest("get", "/", nil) diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 4a3d323..8285888 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -40,11 +40,14 @@ type SessionStore struct { // Save takes a sessions.SessionState and stores the information from it // within Cookies set on the HTTP response writer func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { + if ss.CreatedAt.IsZero() { + ss.CreatedAt = time.Now() + } value, err := utils.CookieForSession(ss, s.CookieCipher) if err != nil { return err } - s.setSessionCookie(rw, req, value) + s.setSessionCookie(rw, req, value, ss.CreatedAt) return nil } @@ -89,8 +92,8 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { } // setSessionCookie adds the user's session cookie to the response -func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { - for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { +func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) { + for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, created) { http.SetCookie(rw, c) } } diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 4857912..7c63795 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -5,6 +5,8 @@ import ( "encoding/base64" "net/http" "net/http/httptest" + "strconv" + "strings" "testing" "time" @@ -72,6 +74,16 @@ var _ = Describe("NewSessionStore", func() { } }) + It("have a signature timestamp matching session.CreatedAt", func() { + for _, cookie := range cookies { + if cookie.Value != "" { + parts := strings.Split(cookie.Value, "|") + Expect(parts).To(HaveLen(3)) + Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix())))) + } + } + }) + }) } @@ -86,6 +98,10 @@ var _ = Describe("NewSessionStore", func() { Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) }) + It("Ensures the session CreatedAt is not zero", func() { + Expect(session.CreatedAt.IsZero()).To(BeFalse()) + }) + CheckCookieOptions() }) @@ -138,12 +154,15 @@ var _ = Describe("NewSessionStore", func() { // Can't compare time.Time using Equal() so remove ExpiresOn from sessions l := *loadedSession + l.CreatedAt = time.Time{} l.ExpiresOn = time.Time{} s := *session + s.CreatedAt = time.Time{} s.ExpiresOn = time.Time{} Expect(l).To(Equal(s)) // Compare time.Time separately + Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue()) Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) } })