Use SessionStore for session in proxy

This commit is contained in:
Joel Speed 2019-05-07 16:13:55 +01:00
parent 34cbe0497c
commit c61f3a1c65
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
4 changed files with 100 additions and 93 deletions

View File

@ -456,27 +456,8 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
// ClearSessionCookie creates a cookie to unset the user's authentication cookie // ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session // stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
var cookies []*http.Cookie return p.sessionStore.Clear(rw, req)
// matches CookieName, CookieName_<number>
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName))
for _, c := range req.Cookies() {
if cookieNameRegex.MatchString(c.Name) {
clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clearCookie)
cookies = append(cookies, clearCookie)
}
}
// ugly hack because default domain changed
if p.CookieDomain == "" && len(cookies) > 0 {
clr2 := *cookies[0]
clr2.Domain = req.Host
http.SetCookie(rw, &clr2)
}
} }
// SetSessionCookie adds the user's session cookie to the response // SetSessionCookie adds the user's session cookie to the response
@ -487,35 +468,13 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
} }
// LoadCookiedSession reads the user's authentication details from the request // LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, time.Duration, error) { func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) {
var age time.Duration return p.sessionStore.Load(req)
c, err := loadCookie(req, p.CookieName)
if err != nil {
// always http.ErrNoCookie
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
}
val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire)
if !ok {
return nil, age, errors.New("Cookie Signature not valid")
}
session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
if err != nil {
return nil, age, err
}
age = time.Now().Truncate(time.Second).Sub(timestamp)
return session, age, nil
} }
// SaveSession creates a new session cookie value and sets this on the response // SaveSession creates a new session cookie value and sets this on the response
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher) return p.sessionStore.Save(rw, req, s)
if err != nil {
return err
}
p.SetSessionCookie(rw, req, value)
return nil
} }
// RobotsTxt disallows scraping pages from the OAuthProxy // RobotsTxt disallows scraping pages from the OAuthProxy
@ -835,12 +794,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
var saveSession, clearSession, revalidated bool var saveSession, clearSession, revalidated bool
remoteAddr := getRemoteAddr(req) remoteAddr := getRemoteAddr(req)
session, sessionAge, err := p.LoadCookiedSession(req) session, err := p.LoadCookiedSession(req)
if err != nil { if err != nil {
logger.Printf("Error loading cookied session: %s", err) logger.Printf("Error loading cookied session: %s", err)
} }
if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) { if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh) logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh)
saveSession = true saveSession = true
} }

View File

@ -17,6 +17,7 @@ import (
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -600,10 +601,15 @@ type ProcessCookieTestOpts struct {
providerValidateCookieResponse bool providerValidateCookieResponse bool
} }
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { type OptionsModifier func(*Options)
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest {
var pcTest ProcessCookieTest var pcTest ProcessCookieTest
pcTest.opts = NewOptions() pcTest.opts = NewOptions()
for _, modifier := range modifiers {
modifier(pcTest.opts)
}
pcTest.opts.ClientID = "bazquux" pcTest.opts.ClientID = "bazquux"
pcTest.opts.ClientSecret = "xyzzyplugh" pcTest.opts.ClientSecret = "xyzzyplugh"
pcTest.opts.CookieSecret = "0123456789abcdefabcd" pcTest.opts.CookieSecret = "0123456789abcdefabcd"
@ -634,32 +640,38 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
}) })
} }
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest {
return NewProcessCookieTest(ProcessCookieTestOpts{
providerValidateCookieResponse: true,
}, modifiers...)
}
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
} }
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) err := p.proxy.SaveSession(p.rw, p.req, s)
if err != nil { if err != nil {
return err return err
} }
for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { for _, cookie := range p.rw.Result().Cookies() {
p.req.AddCookie(c) p.req.AddCookie(cookie)
} }
return nil return nil
} }
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
return p.proxy.LoadCookiedSession(p.req) return p.proxy.LoadCookiedSession(p.req)
} }
func TestLoadCookiedSession(t *testing.T) { func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
pcTest.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession)
session, _, err := pcTest.LoadCookiedSession() session, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, startSession.Email, session.Email)
assert.Equal(t, "john.doe@example.com", session.User) assert.Equal(t, "john.doe@example.com", session.User)
@ -669,7 +681,7 @@ func TestLoadCookiedSession(t *testing.T) {
func TestProcessCookieNoCookieError(t *testing.T) { func TestProcessCookieNoCookieError(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
session, _, err := pcTest.LoadCookiedSession() session, err := pcTest.LoadCookiedSession()
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
if session != nil { if session != nil {
t.Errorf("expected nil session. got %#v", session) t.Errorf("expected nil session. got %#v", session)
@ -677,29 +689,31 @@ func TestProcessCookieNoCookieError(t *testing.T) {
} }
func TestProcessCookieRefreshNotSet(t *testing.T) { func TestProcessCookieRefreshNotSet(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour opts.CookieExpire = time.Duration(23) * time.Hour
})
reference := time.Now().Add(time.Duration(-2) * time.Hour) reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession)
session, age, err := pcTest.LoadCookiedSession() session, err := pcTest.LoadCookiedSession()
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
if age < time.Duration(-2)*time.Hour { if session.Age() < time.Duration(-2)*time.Hour {
t.Errorf("cookie too young %v", age) t.Errorf("cookie too young %v", session.Age())
} }
assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, startSession.Email, session.Email)
} }
func TestProcessCookieFailIfCookieExpired(t *testing.T) { func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour opts.CookieExpire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession)
session, _, err := pcTest.LoadCookiedSession() session, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expected nil session %#v", session) t.Errorf("expected nil session %#v", session)
@ -707,22 +721,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
} }
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour opts.CookieExpire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession)
pcTest.proxy.CookieRefresh = time.Hour pcTest.proxy.CookieRefresh = time.Hour
session, _, err := pcTest.LoadCookiedSession() session, err := pcTest.LoadCookiedSession()
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil { if session != nil {
t.Errorf("expected nil session %#v", session) t.Errorf("expected nil session %#v", session)
} }
} }
func NewAuthOnlyEndpointTest() *ProcessCookieTest { func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...)
pcTest.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
return pcTest return pcTest
@ -731,8 +746,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
func TestAuthOnlyEndpointAccepted(t *testing.T) { func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
startSession := &sessions.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req) test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusAccepted, test.rw.Code) assert.Equal(t, http.StatusAccepted, test.rw.Code)
@ -750,12 +765,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
} }
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest(func(opts *Options) {
test.proxy.CookieExpire = time.Duration(24) * time.Hour opts.CookieExpire = time.Duration(24) * time.Hour
})
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &sessions.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
test.SaveSession(startSession, reference) test.SaveSession(startSession)
test.proxy.ServeHTTP(test.rw, test.req) test.proxy.ServeHTTP(test.rw, test.req)
assert.Equal(t, http.StatusUnauthorized, test.rw.Code) assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
@ -766,8 +782,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
startSession := &sessions.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession)
test.validateUser = false test.validateUser = false
test.proxy.ServeHTTP(test.rw, test.req) test.proxy.ServeHTTP(test.rw, test.req)
@ -797,8 +813,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &sessions.SessionState{ startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
pcTest.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession)
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
@ -1068,7 +1084,12 @@ func TestAjaxForbiddendRequest(t *testing.T) {
} }
func TestClearSplitCookie(t *testing.T) { func TestClearSplitCookie(t *testing.T) {
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} opts := NewOptions()
opts.CookieName = "oauth2"
opts.CookieDomain = "abc"
store, err := cookie.NewCookieSessionStore(opts.SessionOptions.CookieStoreOptions, &opts.CookieOptions)
assert.Equal(t, err, nil)
p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
var rw = httptest.NewRecorder() var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil) req := httptest.NewRequest("get", "/", nil)
@ -1092,7 +1113,12 @@ func TestClearSplitCookie(t *testing.T) {
} }
func TestClearSingleCookie(t *testing.T) { func TestClearSingleCookie(t *testing.T) {
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} opts := NewOptions()
opts.CookieName = "oauth2"
opts.CookieDomain = "abc"
store, err := cookie.NewCookieSessionStore(opts.SessionOptions.CookieStoreOptions, &opts.CookieOptions)
assert.Equal(t, err, nil)
p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
var rw = httptest.NewRecorder() var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil) req := httptest.NewRequest("get", "/", nil)

View File

@ -40,11 +40,14 @@ type SessionStore struct {
// Save takes a sessions.SessionState and stores the information from it // Save takes a sessions.SessionState and stores the information from it
// within Cookies set on the HTTP response writer // within Cookies set on the HTTP response writer
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
if ss.CreatedAt.IsZero() {
ss.CreatedAt = time.Now()
}
value, err := utils.CookieForSession(ss, s.CookieCipher) value, err := utils.CookieForSession(ss, s.CookieCipher)
if err != nil { if err != nil {
return err return err
} }
s.setSessionCookie(rw, req, value) s.setSessionCookie(rw, req, value, ss.CreatedAt)
return nil return nil
} }
@ -89,8 +92,8 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
} }
// setSessionCookie adds the user's session cookie to the response // setSessionCookie adds the user's session cookie to the response
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) {
for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, created) {
http.SetCookie(rw, c) http.SetCookie(rw, c)
} }
} }

View File

@ -5,6 +5,8 @@ import (
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv"
"strings"
"testing" "testing"
"time" "time"
@ -72,6 +74,16 @@ var _ = Describe("NewSessionStore", func() {
} }
}) })
It("have a signature timestamp matching session.CreatedAt", func() {
for _, cookie := range cookies {
if cookie.Value != "" {
parts := strings.Split(cookie.Value, "|")
Expect(parts).To(HaveLen(3))
Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix()))))
}
}
})
}) })
} }
@ -86,6 +98,10 @@ var _ = Describe("NewSessionStore", func() {
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
}) })
It("Ensures the session CreatedAt is not zero", func() {
Expect(session.CreatedAt.IsZero()).To(BeFalse())
})
CheckCookieOptions() CheckCookieOptions()
}) })
@ -138,12 +154,15 @@ var _ = Describe("NewSessionStore", func() {
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions // Can't compare time.Time using Equal() so remove ExpiresOn from sessions
l := *loadedSession l := *loadedSession
l.CreatedAt = time.Time{}
l.ExpiresOn = time.Time{} l.ExpiresOn = time.Time{}
s := *session s := *session
s.CreatedAt = time.Time{}
s.ExpiresOn = time.Time{} s.ExpiresOn = time.Time{}
Expect(l).To(Equal(s)) Expect(l).To(Equal(s))
// Compare time.Time separately // Compare time.Time separately
Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue())
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
} }
}) })