Merge pull request #148 from pusher/proxy-session-store
Proxy session store
This commit is contained in:
commit
10e240c8bf
@ -10,6 +10,7 @@
|
||||
|
||||
## Changes since v3.2.0
|
||||
|
||||
- [#148](https://github.com/pusher/outh2_proxy/pull/148) Implement SessionStore interface within proxy (@JoelSpeed)
|
||||
- [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed)
|
||||
- Allows for multiple different session storage implementations including client and server side
|
||||
- Adds tests suite for interface to ensure consistency across implementations
|
||||
|
191
oauthproxy.go
191
oauthproxy.go
@ -16,7 +16,7 @@ import (
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/yhat/wsutil"
|
||||
)
|
||||
@ -29,10 +29,6 @@ const (
|
||||
httpScheme = "http"
|
||||
httpsScheme = "https"
|
||||
|
||||
// Cookies are limited to 4kb including the length of the cookie name,
|
||||
// the cookie name can be up to 256 bytes
|
||||
maxCookieLength = 3840
|
||||
|
||||
applicationJSON = "application/json"
|
||||
)
|
||||
|
||||
@ -75,6 +71,7 @@ type OAuthProxy struct {
|
||||
redirectURL *url.URL // the url to receive requests at
|
||||
whitelistDomains []string
|
||||
provider providers.Provider
|
||||
sessionStore sessionsapi.SessionStore
|
||||
ProxyPrefix string
|
||||
SignInMessage string
|
||||
HtpasswdFile *HtpasswdFile
|
||||
@ -88,7 +85,6 @@ type OAuthProxy struct {
|
||||
PassAccessToken bool
|
||||
SetAuthorization bool
|
||||
PassAuthorization bool
|
||||
CookieCipher *cookie.Cipher
|
||||
skipAuthRegex []string
|
||||
skipAuthPreflight bool
|
||||
compiledRegex []*regexp.Regexp
|
||||
@ -218,15 +214,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
|
||||
logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh)
|
||||
|
||||
var cipher *cookie.Cipher
|
||||
if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) {
|
||||
var err error
|
||||
cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
|
||||
if err != nil {
|
||||
logger.Fatal("cookie-secret error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &OAuthProxy{
|
||||
CookieName: opts.CookieName,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
||||
@ -249,6 +236,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
|
||||
ProxyPrefix: opts.ProxyPrefix,
|
||||
provider: opts.provider,
|
||||
sessionStore: opts.sessionStore,
|
||||
serveMux: serveMux,
|
||||
redirectURL: redirectURL,
|
||||
whitelistDomains: opts.WhitelistDomains,
|
||||
@ -263,7 +251,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
SetAuthorization: opts.SetAuthorization,
|
||||
PassAuthorization: opts.PassAuthorization,
|
||||
SkipProviderButton: opts.SkipProviderButton,
|
||||
CookieCipher: cipher,
|
||||
templates: loadTemplates(opts.CustomTemplatesDir),
|
||||
Footer: opts.Footer,
|
||||
}
|
||||
@ -293,7 +280,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
|
||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
|
||||
func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
@ -316,104 +303,6 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, er
|
||||
return
|
||||
}
|
||||
|
||||
// MakeSessionCookie creates an http.Cookie containing the authenticated user's
|
||||
// authentication details
|
||||
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
||||
if value != "" {
|
||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
||||
}
|
||||
c := p.makeCookie(req, p.CookieName, value, expiration, now)
|
||||
if len(c.Value) > 4096-len(p.CookieName) {
|
||||
return splitCookie(c)
|
||||
}
|
||||
return []*http.Cookie{c}
|
||||
}
|
||||
|
||||
func copyCookie(c *http.Cookie) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: c.Name,
|
||||
Value: c.Value,
|
||||
Path: c.Path,
|
||||
Domain: c.Domain,
|
||||
Expires: c.Expires,
|
||||
RawExpires: c.RawExpires,
|
||||
MaxAge: c.MaxAge,
|
||||
Secure: c.Secure,
|
||||
HttpOnly: c.HttpOnly,
|
||||
Raw: c.Raw,
|
||||
Unparsed: c.Unparsed,
|
||||
}
|
||||
}
|
||||
|
||||
// splitCookie reads the full cookie generated to store the session and splits
|
||||
// it into a slice of cookies which fit within the 4kb cookie limit indexing
|
||||
// the cookies from 0
|
||||
func splitCookie(c *http.Cookie) []*http.Cookie {
|
||||
if len(c.Value) < maxCookieLength {
|
||||
return []*http.Cookie{c}
|
||||
}
|
||||
cookies := []*http.Cookie{}
|
||||
valueBytes := []byte(c.Value)
|
||||
count := 0
|
||||
for len(valueBytes) > 0 {
|
||||
new := copyCookie(c)
|
||||
new.Name = fmt.Sprintf("%s_%d", c.Name, count)
|
||||
count++
|
||||
if len(valueBytes) < maxCookieLength {
|
||||
new.Value = string(valueBytes)
|
||||
valueBytes = []byte{}
|
||||
} else {
|
||||
newValue := valueBytes[:maxCookieLength]
|
||||
valueBytes = valueBytes[maxCookieLength:]
|
||||
new.Value = string(newValue)
|
||||
}
|
||||
cookies = append(cookies, new)
|
||||
}
|
||||
return cookies
|
||||
}
|
||||
|
||||
// joinCookies takes a slice of cookies from the request and reconstructs the
|
||||
// full session cookie
|
||||
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
||||
if len(cookies) == 0 {
|
||||
return nil, fmt.Errorf("list of cookies must be > 0")
|
||||
}
|
||||
if len(cookies) == 1 {
|
||||
return cookies[0], nil
|
||||
}
|
||||
c := copyCookie(cookies[0])
|
||||
for i := 1; i < len(cookies); i++ {
|
||||
c.Value += cookies[i].Value
|
||||
}
|
||||
c.Name = strings.TrimRight(c.Name, "_0")
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// loadCookie retreieves the sessions state cookie from the http request.
|
||||
// If a single cookie is present this will be returned, otherwise it attempts
|
||||
// to reconstruct a cookie split up by splitCookie
|
||||
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
||||
c, err := req.Cookie(cookieName)
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
cookies := []*http.Cookie{}
|
||||
err = nil
|
||||
count := 0
|
||||
for err == nil {
|
||||
var c *http.Cookie
|
||||
c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count))
|
||||
if err == nil {
|
||||
cookies = append(cookies, c)
|
||||
count++
|
||||
}
|
||||
}
|
||||
if len(cookies) == 0 {
|
||||
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
|
||||
}
|
||||
return joinCookies(cookies)
|
||||
}
|
||||
|
||||
// MakeCSRFCookie creates a cookie for CSRF
|
||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||
@ -454,66 +343,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
|
||||
|
||||
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
|
||||
// stored in the user's session
|
||||
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
|
||||
var cookies []*http.Cookie
|
||||
|
||||
// matches CookieName, CookieName_<number>
|
||||
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName))
|
||||
|
||||
for _, c := range req.Cookies() {
|
||||
if cookieNameRegex.MatchString(c.Name) {
|
||||
clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
|
||||
|
||||
http.SetCookie(rw, clearCookie)
|
||||
cookies = append(cookies, clearCookie)
|
||||
}
|
||||
}
|
||||
|
||||
// ugly hack because default domain changed
|
||||
if p.CookieDomain == "" && len(cookies) > 0 {
|
||||
clr2 := *cookies[0]
|
||||
clr2.Domain = req.Host
|
||||
http.SetCookie(rw, &clr2)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSessionCookie adds the user's session cookie to the response
|
||||
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) {
|
||||
http.SetCookie(rw, c)
|
||||
}
|
||||
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
|
||||
return p.sessionStore.Clear(rw, req)
|
||||
}
|
||||
|
||||
// LoadCookiedSession reads the user's authentication details from the request
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
|
||||
var age time.Duration
|
||||
c, err := loadCookie(req, p.CookieName)
|
||||
if err != nil {
|
||||
// always http.ErrNoCookie
|
||||
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
|
||||
}
|
||||
val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire)
|
||||
if !ok {
|
||||
return nil, age, errors.New("Cookie Signature not valid")
|
||||
}
|
||||
|
||||
session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
|
||||
if err != nil {
|
||||
return nil, age, err
|
||||
}
|
||||
|
||||
age = time.Now().Truncate(time.Second).Sub(timestamp)
|
||||
return session, age, nil
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
return p.sessionStore.Load(req)
|
||||
}
|
||||
|
||||
// SaveSession creates a new session cookie value and sets this on the response
|
||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetSessionCookie(rw, req, value)
|
||||
return nil
|
||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error {
|
||||
return p.sessionStore.Save(rw, req, s)
|
||||
}
|
||||
|
||||
// RobotsTxt disallows scraping pages from the OAuthProxy
|
||||
@ -694,7 +535,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
user, ok := p.ManualSignIn(rw, req)
|
||||
if ok {
|
||||
session := &sessions.SessionState{User: user}
|
||||
session := &sessionsapi.SessionState{User: user}
|
||||
p.SaveSession(rw, req, session)
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
} else {
|
||||
@ -833,12 +674,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
||||
var saveSession, clearSession, revalidated bool
|
||||
remoteAddr := getRemoteAddr(req)
|
||||
|
||||
session, sessionAge, err := p.LoadCookiedSession(req)
|
||||
session, err := p.LoadCookiedSession(req)
|
||||
if err != nil {
|
||||
logger.Printf("Error loading cookied session: %s", err)
|
||||
}
|
||||
if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh)
|
||||
if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh)
|
||||
saveSession = true
|
||||
}
|
||||
|
||||
@ -945,7 +786,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
||||
|
||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||
// credentials and authenticates these against the proxies HtpasswdFile
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||
if p.HtpasswdFile == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@ -967,7 +808,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState,
|
||||
}
|
||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||
return &sessions.SessionState{User: pair[0]}, nil
|
||||
return &sessionsapi.SessionState{User: pair[0]}, nil
|
||||
}
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||
return nil, nil
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -600,10 +601,15 @@ type ProcessCookieTestOpts struct {
|
||||
providerValidateCookieResponse bool
|
||||
}
|
||||
|
||||
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
|
||||
type OptionsModifier func(*Options)
|
||||
|
||||
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest {
|
||||
var pcTest ProcessCookieTest
|
||||
|
||||
pcTest.opts = NewOptions()
|
||||
for _, modifier := range modifiers {
|
||||
modifier(pcTest.opts)
|
||||
}
|
||||
pcTest.opts.ClientID = "bazquux"
|
||||
pcTest.opts.ClientSecret = "xyzzyplugh"
|
||||
pcTest.opts.CookieSecret = "0123456789abcdefabcd"
|
||||
@ -634,32 +640,34 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
||||
})
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
|
||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest {
|
||||
return NewProcessCookieTest(ProcessCookieTestOpts{
|
||||
providerValidateCookieResponse: true,
|
||||
}, modifiers...)
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
|
||||
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
||||
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
|
||||
err := p.proxy.SaveSession(p.rw, p.req, s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) {
|
||||
p.req.AddCookie(c)
|
||||
for _, cookie := range p.rw.Result().Cookies() {
|
||||
p.req.AddCookie(cookie)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
|
||||
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
|
||||
return p.proxy.LoadCookiedSession(p.req)
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
|
||||
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
|
||||
pcTest.SaveSession(startSession)
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
session, err := pcTest.LoadCookiedSession()
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, startSession.Email, session.Email)
|
||||
assert.Equal(t, "john.doe@example.com", session.User)
|
||||
@ -669,7 +677,7 @@ func TestLoadCookiedSession(t *testing.T) {
|
||||
func TestProcessCookieNoCookieError(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
session, err := pcTest.LoadCookiedSession()
|
||||
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
|
||||
if session != nil {
|
||||
t.Errorf("expected nil session. got %#v", session)
|
||||
@ -677,29 +685,31 @@ func TestProcessCookieNoCookieError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
|
||||
opts.CookieExpire = time.Duration(23) * time.Hour
|
||||
})
|
||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||
pcTest.SaveSession(startSession)
|
||||
|
||||
session, age, err := pcTest.LoadCookiedSession()
|
||||
session, err := pcTest.LoadCookiedSession()
|
||||
assert.Equal(t, nil, err)
|
||||
if age < time.Duration(-2)*time.Hour {
|
||||
t.Errorf("cookie too young %v", age)
|
||||
if session.Age() < time.Duration(-2)*time.Hour {
|
||||
t.Errorf("cookie too young %v", session.Age())
|
||||
}
|
||||
assert.Equal(t, startSession.Email, session.Email)
|
||||
}
|
||||
|
||||
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
|
||||
opts.CookieExpire = time.Duration(24) * time.Hour
|
||||
})
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||
pcTest.SaveSession(startSession)
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
session, err := pcTest.LoadCookiedSession()
|
||||
assert.NotEqual(t, nil, err)
|
||||
if session != nil {
|
||||
t.Errorf("expected nil session %#v", session)
|
||||
@ -707,22 +717,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
|
||||
opts.CookieExpire = time.Duration(24) * time.Hour
|
||||
})
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||
pcTest.SaveSession(startSession)
|
||||
|
||||
pcTest.proxy.CookieRefresh = time.Hour
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
session, err := pcTest.LoadCookiedSession()
|
||||
assert.NotEqual(t, nil, err)
|
||||
if session != nil {
|
||||
t.Errorf("expected nil session %#v", session)
|
||||
}
|
||||
}
|
||||
|
||||
func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {
|
||||
pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
||||
pcTest.req, _ = http.NewRequest("GET",
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
return pcTest
|
||||
@ -731,8 +742,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
|
||||
test.SaveSession(startSession)
|
||||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
||||
@ -750,12 +761,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
test := NewAuthOnlyEndpointTest(func(opts *Options) {
|
||||
opts.CookieExpire = time.Duration(24) * time.Hour
|
||||
})
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, reference)
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||
test.SaveSession(startSession)
|
||||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||
@ -766,8 +778,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
|
||||
test.SaveSession(startSession)
|
||||
test.validateUser = false
|
||||
|
||||
test.proxy.ServeHTTP(test.rw, test.req)
|
||||
@ -797,8 +809,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
|
||||
startSession := &sessions.SessionState{
|
||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
|
||||
pcTest.SaveSession(startSession)
|
||||
|
||||
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
||||
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
||||
@ -930,11 +942,11 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
||||
|
||||
state := &sessions.SessionState{
|
||||
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
||||
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
||||
err = proxy.SaveSession(st.rw, req, state)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) {
|
||||
for _, c := range st.rw.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
// This is used by the upstream to validate the signature.
|
||||
@ -1068,7 +1080,12 @@ func TestAjaxForbiddendRequest(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClearSplitCookie(t *testing.T) {
|
||||
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
|
||||
opts := NewOptions()
|
||||
opts.CookieName = "oauth2"
|
||||
opts.CookieDomain = "abc"
|
||||
store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions)
|
||||
assert.Equal(t, err, nil)
|
||||
p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
|
||||
var rw = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("get", "/", nil)
|
||||
|
||||
@ -1092,7 +1109,12 @@ func TestClearSplitCookie(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClearSingleCookie(t *testing.T) {
|
||||
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
|
||||
opts := NewOptions()
|
||||
opts.CookieName = "oauth2"
|
||||
opts.CookieDomain = "abc"
|
||||
store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions)
|
||||
assert.Equal(t, err, nil)
|
||||
p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
|
||||
var rw = httptest.NewRecorder()
|
||||
req := httptest.NewRequest("get", "/", nil)
|
||||
|
||||
|
24
options.go
24
options.go
@ -17,8 +17,11 @@ import (
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
@ -111,6 +114,7 @@ type Options struct {
|
||||
proxyURLs []*url.URL
|
||||
CompiledRegex []*regexp.Regexp
|
||||
provider providers.Provider
|
||||
sessionStore sessionsapi.SessionStore
|
||||
signatureData *SignatureData
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
}
|
||||
@ -136,6 +140,9 @@ func NewOptions() *Options {
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(0),
|
||||
},
|
||||
SessionOptions: options.SessionOptions{
|
||||
Type: "cookie",
|
||||
},
|
||||
SetXAuthRequest: false,
|
||||
SkipAuthPreflight: false,
|
||||
PassBasicAuth: true,
|
||||
@ -261,7 +268,8 @@ func (o *Options) Validate() error {
|
||||
}
|
||||
msgs = parseProviderInfo(o, msgs)
|
||||
|
||||
if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
|
||||
var cipher *cookie.Cipher
|
||||
if o.PassAccessToken || o.SetAuthorization || o.PassAuthorization || (o.CookieRefresh != time.Duration(0)) {
|
||||
validCookieSecretSize := false
|
||||
for _, i := range []int{16, 24, 32} {
|
||||
if len(secretBytes(o.CookieSecret)) == i {
|
||||
@ -283,9 +291,23 @@ func (o *Options) Validate() error {
|
||||
"pass_access_token == true or "+
|
||||
"cookie_refresh != 0, but is %d bytes.%s",
|
||||
len(secretBytes(o.CookieSecret)), suffix))
|
||||
} else {
|
||||
var err error
|
||||
cipher, err = cookie.NewCipher(secretBytes(o.CookieSecret))
|
||||
if err != nil {
|
||||
msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o.SessionOptions.Cipher = cipher
|
||||
sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions)
|
||||
if err != nil {
|
||||
msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err))
|
||||
} else {
|
||||
o.sessionStore = sessionStore
|
||||
}
|
||||
|
||||
if o.CookieRefresh >= o.CookieExpire {
|
||||
msgs = append(msgs, fmt.Sprintf(
|
||||
"cookie_refresh (%s) must be less than "+
|
||||
|
@ -1,8 +1,13 @@
|
||||
package options
|
||||
|
||||
import (
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
)
|
||||
|
||||
// SessionOptions contains configuration options for the SessionStore providers.
|
||||
type SessionOptions struct {
|
||||
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
|
||||
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
|
||||
Cipher *cookie.Cipher
|
||||
CookieStoreOptions
|
||||
}
|
||||
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
type SessionState struct {
|
||||
AccessToken string `json:",omitempty"`
|
||||
IDToken string `json:",omitempty"`
|
||||
CreatedAt time.Time `json:"-"`
|
||||
ExpiresOn time.Time `json:"-"`
|
||||
RefreshToken string `json:",omitempty"`
|
||||
Email string `json:",omitempty"`
|
||||
@ -23,6 +24,7 @@ type SessionState struct {
|
||||
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
||||
type SessionStateJSON struct {
|
||||
*SessionState
|
||||
CreatedAt *time.Time `json:",omitempty"`
|
||||
ExpiresOn *time.Time `json:",omitempty"`
|
||||
}
|
||||
|
||||
@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Age returns the age of a session
|
||||
func (s *SessionState) Age() time.Duration {
|
||||
if !s.CreatedAt.IsZero() {
|
||||
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// String constructs a summary of the session state
|
||||
func (s *SessionState) String() string {
|
||||
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
|
||||
@ -43,6 +53,9 @@ func (s *SessionState) String() string {
|
||||
if s.IDToken != "" {
|
||||
o += " id_token:true"
|
||||
}
|
||||
if !s.CreatedAt.IsZero() {
|
||||
o += fmt.Sprintf(" created:%s", s.CreatedAt)
|
||||
}
|
||||
if !s.ExpiresOn.IsZero() {
|
||||
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
||||
}
|
||||
@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
||||
}
|
||||
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
||||
ssj := &SessionStateJSON{SessionState: &ss}
|
||||
if !ss.CreatedAt.IsZero() {
|
||||
ssj.CreatedAt = &ss.CreatedAt
|
||||
}
|
||||
if !ss.ExpiresOn.IsZero() {
|
||||
ssj.ExpiresOn = &ss.ExpiresOn
|
||||
}
|
||||
@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
||||
var ss *SessionState
|
||||
err := json.Unmarshal([]byte(v), &ssj)
|
||||
if err == nil && ssj.SessionState != nil {
|
||||
// Extract SessionState and ExpiresOn value from SessionStateJSON
|
||||
// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
|
||||
ss = ssj.SessionState
|
||||
if ssj.CreatedAt != nil {
|
||||
ss.CreatedAt = *ssj.CreatedAt
|
||||
}
|
||||
if ssj.ExpiresOn != nil {
|
||||
ss.ExpiresOn = *ssj.ExpiresOn
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.IDToken, ss.IDToken)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||
assert.NotEqual(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||
assert.NotEqual(t, s.IDToken, ss.IDToken)
|
||||
@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, s.User, ss.User)
|
||||
assert.NotEqual(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
||||
@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
s := &sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
@ -147,6 +155,7 @@ type testCase struct {
|
||||
// Currently only tests without cipher here because we have no way to mock
|
||||
// the random generator used in EncodeSessionState.
|
||||
func TestEncodeSessionState(t *testing.T) {
|
||||
c := time.Now()
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
|
||||
testCases := []testCase{
|
||||
@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: c,
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||
func TestDecodeSessionState(t *testing.T) {
|
||||
created := time.Now()
|
||||
createdJSON, _ := created.MarshalJSON()
|
||||
createdString := string(createdJSON)
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
eString := string(eJSON)
|
||||
@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
},
|
||||
{
|
||||
SessionState: sessions.SessionState{
|
||||
@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
CreatedAt: created,
|
||||
ExpiresOn: e,
|
||||
RefreshToken: "refresh4321",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStateAge(t *testing.T) {
|
||||
ss := &sessions.SessionState{}
|
||||
|
||||
// Created at unset so should be 0
|
||||
assert.Equal(t, time.Duration(0), ss.Age())
|
||||
|
||||
// Set CreatedAt to 1 hour ago
|
||||
ss.CreatedAt = time.Now().Add(-1 * time.Hour)
|
||||
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
)
|
||||
|
||||
// MakeCookie constructs a cookie from the given parameters,
|
||||
@ -32,3 +33,9 @@ func MakeCookie(req *http.Request, name string, value string, path string, domai
|
||||
Expires: now.Add(expiration),
|
||||
}
|
||||
}
|
||||
|
||||
// MakeCookieFromOptions constructs a cookie based on the givemn *options.CookieOptions,
|
||||
// value and creation time
|
||||
func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.CookieOptions, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return MakeCookie(req, name, value, opts.CookiePath, opts.CookieDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now)
|
||||
}
|
||||
|
@ -27,36 +27,33 @@ var _ sessions.SessionStore = &SessionStore{}
|
||||
// SessionStore is an implementation of the sessions.SessionStore
|
||||
// interface that stores sessions in client side cookies
|
||||
type SessionStore struct {
|
||||
CookieCipher *cookie.Cipher
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieName string
|
||||
CookiePath string
|
||||
CookieSecret string
|
||||
CookieSecure bool
|
||||
CookieOptions *options.CookieOptions
|
||||
CookieCipher *cookie.Cipher
|
||||
}
|
||||
|
||||
// Save takes a sessions.SessionState and stores the information from it
|
||||
// within Cookies set on the HTTP response writer
|
||||
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
||||
if ss.CreatedAt.IsZero() {
|
||||
ss.CreatedAt = time.Now()
|
||||
}
|
||||
value, err := utils.CookieForSession(ss, s.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.setSessionCookie(rw, req, value)
|
||||
s.setSessionCookie(rw, req, value, ss.CreatedAt)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load reads sessions.SessionState information from Cookies within the
|
||||
// HTTP request object
|
||||
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||
c, err := loadCookie(req, s.CookieName)
|
||||
c, err := loadCookie(req, s.CookieOptions.CookieName)
|
||||
if err != nil {
|
||||
// always http.ErrNoCookie
|
||||
return nil, fmt.Errorf("Cookie %q not present", s.CookieName)
|
||||
return nil, fmt.Errorf("Cookie %q not present", s.CookieOptions.CookieName)
|
||||
}
|
||||
val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire)
|
||||
val, _, ok := cookie.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire)
|
||||
if !ok {
|
||||
return nil, errors.New("Cookie Signature not valid")
|
||||
}
|
||||
@ -74,11 +71,11 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||
var cookies []*http.Cookie
|
||||
|
||||
// matches CookieName, CookieName_<number>
|
||||
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName))
|
||||
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName))
|
||||
|
||||
for _, c := range req.Cookies() {
|
||||
if cookieNameRegex.MatchString(c.Name) {
|
||||
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
|
||||
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
|
||||
|
||||
http.SetCookie(rw, clearCookie)
|
||||
cookies = append(cookies, clearCookie)
|
||||
@ -89,60 +86,42 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||
}
|
||||
|
||||
// setSessionCookie adds the user's session cookie to the response
|
||||
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) {
|
||||
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) {
|
||||
for _, c := range s.makeSessionCookie(req, val, created) {
|
||||
http.SetCookie(rw, c)
|
||||
}
|
||||
}
|
||||
|
||||
// makeSessionCookie creates an http.Cookie containing the authenticated user's
|
||||
// authentication details
|
||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
|
||||
if value != "" {
|
||||
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
|
||||
value = cookie.SignedValue(s.CookieOptions.CookieSecret, s.CookieOptions.CookieName, value, now)
|
||||
}
|
||||
c := s.makeCookie(req, s.CookieName, value, expiration)
|
||||
if len(c.Value) > 4096-len(s.CookieName) {
|
||||
c := s.makeCookie(req, s.CookieOptions.CookieName, value, s.CookieOptions.CookieExpire, now)
|
||||
if len(c.Value) > 4096-len(s.CookieOptions.CookieName) {
|
||||
return splitCookie(c)
|
||||
}
|
||||
return []*http.Cookie{c}
|
||||
}
|
||||
|
||||
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
|
||||
return cookies.MakeCookie(
|
||||
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return cookies.MakeCookieFromOptions(
|
||||
req,
|
||||
name,
|
||||
value,
|
||||
s.CookiePath,
|
||||
s.CookieDomain,
|
||||
s.CookieHTTPOnly,
|
||||
s.CookieSecure,
|
||||
s.CookieOptions,
|
||||
expiration,
|
||||
time.Now(),
|
||||
now,
|
||||
)
|
||||
}
|
||||
|
||||
// NewCookieSessionStore initialises a new instance of the SessionStore from
|
||||
// the configuration given
|
||||
func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||
var cipher *cookie.Cipher
|
||||
if len(cookieOpts.CookieSecret) > 0 {
|
||||
var err error
|
||||
cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create cipher: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||
return &SessionStore{
|
||||
CookieCipher: cipher,
|
||||
CookieDomain: cookieOpts.CookieDomain,
|
||||
CookieExpire: cookieOpts.CookieExpire,
|
||||
CookieHTTPOnly: cookieOpts.CookieHTTPOnly,
|
||||
CookieName: cookieOpts.CookieName,
|
||||
CookiePath: cookieOpts.CookiePath,
|
||||
CookieSecret: cookieOpts.CookieSecret,
|
||||
CookieSecure: cookieOpts.CookieSecure,
|
||||
CookieCipher: opts.Cipher,
|
||||
CookieOptions: cookieOpts,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ import (
|
||||
func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||
switch opts.Type {
|
||||
case options.CookieSessionStoreType:
|
||||
return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts)
|
||||
return cookie.NewCookieSessionStore(opts, cookieOpts)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown session store type '%s'", opts.Type)
|
||||
}
|
||||
|
@ -5,16 +5,20 @@ import (
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||
sessionscookie "github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions/utils"
|
||||
)
|
||||
|
||||
func TestSessionStore(t *testing.T) {
|
||||
@ -72,6 +76,16 @@ var _ = Describe("NewSessionStore", func() {
|
||||
}
|
||||
})
|
||||
|
||||
It("have a signature timestamp matching session.CreatedAt", func() {
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Value != "" {
|
||||
parts := strings.Split(cookie.Value, "|")
|
||||
Expect(parts).To(HaveLen(3))
|
||||
Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix()))))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
@ -86,6 +100,10 @@ var _ = Describe("NewSessionStore", func() {
|
||||
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
It("Ensures the session CreatedAt is not zero", func() {
|
||||
Expect(session.CreatedAt.IsZero()).To(BeFalse())
|
||||
})
|
||||
|
||||
CheckCookieOptions()
|
||||
})
|
||||
|
||||
@ -138,12 +156,15 @@ var _ = Describe("NewSessionStore", func() {
|
||||
|
||||
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
|
||||
l := *loadedSession
|
||||
l.CreatedAt = time.Time{}
|
||||
l.ExpiresOn = time.Time{}
|
||||
s := *session
|
||||
s.CreatedAt = time.Time{}
|
||||
s.ExpiresOn = time.Time{}
|
||||
Expect(l).To(Equal(s))
|
||||
|
||||
// Compare time.Time separately
|
||||
Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue())
|
||||
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
|
||||
}
|
||||
})
|
||||
@ -181,12 +202,16 @@ var _ = Describe("NewSessionStore", func() {
|
||||
SessionStoreInterfaceTests()
|
||||
})
|
||||
|
||||
Context("with a cookie-secret set", func() {
|
||||
Context("with a cipher", func() {
|
||||
BeforeEach(func() {
|
||||
secret := make([]byte, 32)
|
||||
_, err := rand.Read(secret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret)
|
||||
cipher, err := cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(cipher).ToNot(BeNil())
|
||||
opts.Cipher = cipher
|
||||
|
||||
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@ -231,7 +256,7 @@ var _ = Describe("NewSessionStore", func() {
|
||||
It("creates a cookie.SessionStore", func() {
|
||||
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{}))
|
||||
Expect(ss).To(BeAssignableToTypeOf(&sessionscookie.SessionStore{}))
|
||||
})
|
||||
|
||||
Context("the cookie.SessionStore", func() {
|
||||
|
@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
Email: c.Email,
|
||||
|
@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
Email: email,
|
||||
}
|
||||
|
@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
|
||||
s.AccessToken = newSession.AccessToken
|
||||
s.IDToken = newSession.IDToken
|
||||
s.RefreshToken = newSession.RefreshToken
|
||||
s.CreatedAt = newSession.CreatedAt
|
||||
s.ExpiresOn = newSession.ExpiresOn
|
||||
s.Email = newSession.Email
|
||||
return
|
||||
@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: token.Expiry,
|
||||
Email: claims.Email,
|
||||
User: claims.Subject,
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
|
||||
return
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
s = &sessions.SessionState{AccessToken: a}
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", body)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user