Merge pull request #148 from pusher/proxy-session-store
Proxy session store
This commit is contained in:
commit
10e240c8bf
@ -10,6 +10,7 @@
|
|||||||
|
|
||||||
## Changes since v3.2.0
|
## Changes since v3.2.0
|
||||||
|
|
||||||
|
- [#148](https://github.com/pusher/outh2_proxy/pull/148) Implement SessionStore interface within proxy (@JoelSpeed)
|
||||||
- [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed)
|
- [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed)
|
||||||
- Allows for multiple different session storage implementations including client and server side
|
- Allows for multiple different session storage implementations including client and server side
|
||||||
- Adds tests suite for interface to ensure consistency across implementations
|
- Adds tests suite for interface to ensure consistency across implementations
|
||||||
|
191
oauthproxy.go
191
oauthproxy.go
@ -16,7 +16,7 @@ import (
|
|||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"github.com/yhat/wsutil"
|
"github.com/yhat/wsutil"
|
||||||
)
|
)
|
||||||
@ -29,10 +29,6 @@ const (
|
|||||||
httpScheme = "http"
|
httpScheme = "http"
|
||||||
httpsScheme = "https"
|
httpsScheme = "https"
|
||||||
|
|
||||||
// Cookies are limited to 4kb including the length of the cookie name,
|
|
||||||
// the cookie name can be up to 256 bytes
|
|
||||||
maxCookieLength = 3840
|
|
||||||
|
|
||||||
applicationJSON = "application/json"
|
applicationJSON = "application/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -75,6 +71,7 @@ type OAuthProxy struct {
|
|||||||
redirectURL *url.URL // the url to receive requests at
|
redirectURL *url.URL // the url to receive requests at
|
||||||
whitelistDomains []string
|
whitelistDomains []string
|
||||||
provider providers.Provider
|
provider providers.Provider
|
||||||
|
sessionStore sessionsapi.SessionStore
|
||||||
ProxyPrefix string
|
ProxyPrefix string
|
||||||
SignInMessage string
|
SignInMessage string
|
||||||
HtpasswdFile *HtpasswdFile
|
HtpasswdFile *HtpasswdFile
|
||||||
@ -88,7 +85,6 @@ type OAuthProxy struct {
|
|||||||
PassAccessToken bool
|
PassAccessToken bool
|
||||||
SetAuthorization bool
|
SetAuthorization bool
|
||||||
PassAuthorization bool
|
PassAuthorization bool
|
||||||
CookieCipher *cookie.Cipher
|
|
||||||
skipAuthRegex []string
|
skipAuthRegex []string
|
||||||
skipAuthPreflight bool
|
skipAuthPreflight bool
|
||||||
compiledRegex []*regexp.Regexp
|
compiledRegex []*regexp.Regexp
|
||||||
@ -218,15 +214,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
|
|
||||||
logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh)
|
logger.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s path:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, opts.CookiePath, refresh)
|
||||||
|
|
||||||
var cipher *cookie.Cipher
|
|
||||||
if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) {
|
|
||||||
var err error
|
|
||||||
cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal("cookie-secret error: ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &OAuthProxy{
|
return &OAuthProxy{
|
||||||
CookieName: opts.CookieName,
|
CookieName: opts.CookieName,
|
||||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
||||||
@ -249,6 +236,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
|
|
||||||
ProxyPrefix: opts.ProxyPrefix,
|
ProxyPrefix: opts.ProxyPrefix,
|
||||||
provider: opts.provider,
|
provider: opts.provider,
|
||||||
|
sessionStore: opts.sessionStore,
|
||||||
serveMux: serveMux,
|
serveMux: serveMux,
|
||||||
redirectURL: redirectURL,
|
redirectURL: redirectURL,
|
||||||
whitelistDomains: opts.WhitelistDomains,
|
whitelistDomains: opts.WhitelistDomains,
|
||||||
@ -263,7 +251,6 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
SetAuthorization: opts.SetAuthorization,
|
SetAuthorization: opts.SetAuthorization,
|
||||||
PassAuthorization: opts.PassAuthorization,
|
PassAuthorization: opts.PassAuthorization,
|
||||||
SkipProviderButton: opts.SkipProviderButton,
|
SkipProviderButton: opts.SkipProviderButton,
|
||||||
CookieCipher: cipher,
|
|
||||||
templates: loadTemplates(opts.CustomTemplatesDir),
|
templates: loadTemplates(opts.CustomTemplatesDir),
|
||||||
Footer: opts.Footer,
|
Footer: opts.Footer,
|
||||||
}
|
}
|
||||||
@ -293,7 +280,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
|
|||||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
|
func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("missing code")
|
return nil, errors.New("missing code")
|
||||||
}
|
}
|
||||||
@ -316,104 +303,6 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, er
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// MakeSessionCookie creates an http.Cookie containing the authenticated user's
|
|
||||||
// authentication details
|
|
||||||
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
|
||||||
if value != "" {
|
|
||||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
|
||||||
}
|
|
||||||
c := p.makeCookie(req, p.CookieName, value, expiration, now)
|
|
||||||
if len(c.Value) > 4096-len(p.CookieName) {
|
|
||||||
return splitCookie(c)
|
|
||||||
}
|
|
||||||
return []*http.Cookie{c}
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyCookie(c *http.Cookie) *http.Cookie {
|
|
||||||
return &http.Cookie{
|
|
||||||
Name: c.Name,
|
|
||||||
Value: c.Value,
|
|
||||||
Path: c.Path,
|
|
||||||
Domain: c.Domain,
|
|
||||||
Expires: c.Expires,
|
|
||||||
RawExpires: c.RawExpires,
|
|
||||||
MaxAge: c.MaxAge,
|
|
||||||
Secure: c.Secure,
|
|
||||||
HttpOnly: c.HttpOnly,
|
|
||||||
Raw: c.Raw,
|
|
||||||
Unparsed: c.Unparsed,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitCookie reads the full cookie generated to store the session and splits
|
|
||||||
// it into a slice of cookies which fit within the 4kb cookie limit indexing
|
|
||||||
// the cookies from 0
|
|
||||||
func splitCookie(c *http.Cookie) []*http.Cookie {
|
|
||||||
if len(c.Value) < maxCookieLength {
|
|
||||||
return []*http.Cookie{c}
|
|
||||||
}
|
|
||||||
cookies := []*http.Cookie{}
|
|
||||||
valueBytes := []byte(c.Value)
|
|
||||||
count := 0
|
|
||||||
for len(valueBytes) > 0 {
|
|
||||||
new := copyCookie(c)
|
|
||||||
new.Name = fmt.Sprintf("%s_%d", c.Name, count)
|
|
||||||
count++
|
|
||||||
if len(valueBytes) < maxCookieLength {
|
|
||||||
new.Value = string(valueBytes)
|
|
||||||
valueBytes = []byte{}
|
|
||||||
} else {
|
|
||||||
newValue := valueBytes[:maxCookieLength]
|
|
||||||
valueBytes = valueBytes[maxCookieLength:]
|
|
||||||
new.Value = string(newValue)
|
|
||||||
}
|
|
||||||
cookies = append(cookies, new)
|
|
||||||
}
|
|
||||||
return cookies
|
|
||||||
}
|
|
||||||
|
|
||||||
// joinCookies takes a slice of cookies from the request and reconstructs the
|
|
||||||
// full session cookie
|
|
||||||
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
|
||||||
if len(cookies) == 0 {
|
|
||||||
return nil, fmt.Errorf("list of cookies must be > 0")
|
|
||||||
}
|
|
||||||
if len(cookies) == 1 {
|
|
||||||
return cookies[0], nil
|
|
||||||
}
|
|
||||||
c := copyCookie(cookies[0])
|
|
||||||
for i := 1; i < len(cookies); i++ {
|
|
||||||
c.Value += cookies[i].Value
|
|
||||||
}
|
|
||||||
c.Name = strings.TrimRight(c.Name, "_0")
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadCookie retreieves the sessions state cookie from the http request.
|
|
||||||
// If a single cookie is present this will be returned, otherwise it attempts
|
|
||||||
// to reconstruct a cookie split up by splitCookie
|
|
||||||
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
|
||||||
c, err := req.Cookie(cookieName)
|
|
||||||
if err == nil {
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
cookies := []*http.Cookie{}
|
|
||||||
err = nil
|
|
||||||
count := 0
|
|
||||||
for err == nil {
|
|
||||||
var c *http.Cookie
|
|
||||||
c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count))
|
|
||||||
if err == nil {
|
|
||||||
cookies = append(cookies, c)
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(cookies) == 0 {
|
|
||||||
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
|
|
||||||
}
|
|
||||||
return joinCookies(cookies)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakeCSRFCookie creates a cookie for CSRF
|
// MakeCSRFCookie creates a cookie for CSRF
|
||||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||||
@ -454,66 +343,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
|
|||||||
|
|
||||||
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
|
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
|
||||||
// stored in the user's session
|
// stored in the user's session
|
||||||
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) error {
|
||||||
var cookies []*http.Cookie
|
return p.sessionStore.Clear(rw, req)
|
||||||
|
|
||||||
// matches CookieName, CookieName_<number>
|
|
||||||
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName))
|
|
||||||
|
|
||||||
for _, c := range req.Cookies() {
|
|
||||||
if cookieNameRegex.MatchString(c.Name) {
|
|
||||||
clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
|
|
||||||
|
|
||||||
http.SetCookie(rw, clearCookie)
|
|
||||||
cookies = append(cookies, clearCookie)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ugly hack because default domain changed
|
|
||||||
if p.CookieDomain == "" && len(cookies) > 0 {
|
|
||||||
clr2 := *cookies[0]
|
|
||||||
clr2.Domain = req.Host
|
|
||||||
http.SetCookie(rw, &clr2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSessionCookie adds the user's session cookie to the response
|
|
||||||
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
|
||||||
for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) {
|
|
||||||
http.SetCookie(rw, c)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadCookiedSession reads the user's authentication details from the request
|
// LoadCookiedSession reads the user's authentication details from the request
|
||||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
|
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
var age time.Duration
|
return p.sessionStore.Load(req)
|
||||||
c, err := loadCookie(req, p.CookieName)
|
|
||||||
if err != nil {
|
|
||||||
// always http.ErrNoCookie
|
|
||||||
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
|
|
||||||
}
|
|
||||||
val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire)
|
|
||||||
if !ok {
|
|
||||||
return nil, age, errors.New("Cookie Signature not valid")
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
|
|
||||||
if err != nil {
|
|
||||||
return nil, age, err
|
|
||||||
}
|
|
||||||
|
|
||||||
age = time.Now().Truncate(time.Second).Sub(timestamp)
|
|
||||||
return session, age, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession creates a new session cookie value and sets this on the response
|
// SaveSession creates a new session cookie value and sets this on the response
|
||||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error {
|
||||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
return p.sessionStore.Save(rw, req, s)
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.SetSessionCookie(rw, req, value)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsTxt disallows scraping pages from the OAuthProxy
|
// RobotsTxt disallows scraping pages from the OAuthProxy
|
||||||
@ -694,7 +535,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
user, ok := p.ManualSignIn(rw, req)
|
user, ok := p.ManualSignIn(rw, req)
|
||||||
if ok {
|
if ok {
|
||||||
session := &sessions.SessionState{User: user}
|
session := &sessionsapi.SessionState{User: user}
|
||||||
p.SaveSession(rw, req, session)
|
p.SaveSession(rw, req, session)
|
||||||
http.Redirect(rw, req, redirect, 302)
|
http.Redirect(rw, req, redirect, 302)
|
||||||
} else {
|
} else {
|
||||||
@ -833,12 +674,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
|||||||
var saveSession, clearSession, revalidated bool
|
var saveSession, clearSession, revalidated bool
|
||||||
remoteAddr := getRemoteAddr(req)
|
remoteAddr := getRemoteAddr(req)
|
||||||
|
|
||||||
session, sessionAge, err := p.LoadCookiedSession(req)
|
session, err := p.LoadCookiedSession(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("Error loading cookied session: %s", err)
|
logger.Printf("Error loading cookied session: %s", err)
|
||||||
}
|
}
|
||||||
if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
if session != nil && session.Age() > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
|
||||||
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", sessionAge, session, p.CookieRefresh)
|
logger.Printf("Refreshing %s old session cookie for %s (refresh after %s)", session.Age(), session, p.CookieRefresh)
|
||||||
saveSession = true
|
saveSession = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -945,7 +786,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
|||||||
|
|
||||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||||
// credentials and authenticates these against the proxies HtpasswdFile
|
// credentials and authenticates these against the proxies HtpasswdFile
|
||||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
|
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) {
|
||||||
if p.HtpasswdFile == nil {
|
if p.HtpasswdFile == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -967,7 +808,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState,
|
|||||||
}
|
}
|
||||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||||
return &sessions.SessionState{User: pair[0]}, nil
|
return &sessionsapi.SessionState{User: pair[0]}, nil
|
||||||
}
|
}
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -600,10 +601,15 @@ type ProcessCookieTestOpts struct {
|
|||||||
providerValidateCookieResponse bool
|
providerValidateCookieResponse bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
|
type OptionsModifier func(*Options)
|
||||||
|
|
||||||
|
func NewProcessCookieTest(opts ProcessCookieTestOpts, modifiers ...OptionsModifier) *ProcessCookieTest {
|
||||||
var pcTest ProcessCookieTest
|
var pcTest ProcessCookieTest
|
||||||
|
|
||||||
pcTest.opts = NewOptions()
|
pcTest.opts = NewOptions()
|
||||||
|
for _, modifier := range modifiers {
|
||||||
|
modifier(pcTest.opts)
|
||||||
|
}
|
||||||
pcTest.opts.ClientID = "bazquux"
|
pcTest.opts.ClientID = "bazquux"
|
||||||
pcTest.opts.ClientSecret = "xyzzyplugh"
|
pcTest.opts.ClientSecret = "xyzzyplugh"
|
||||||
pcTest.opts.CookieSecret = "0123456789abcdefabcd"
|
pcTest.opts.CookieSecret = "0123456789abcdefabcd"
|
||||||
@ -634,32 +640,34 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
|
func NewProcessCookieTestWithOptionsModifiers(modifiers ...OptionsModifier) *ProcessCookieTest {
|
||||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
||||||
|
providerValidateCookieResponse: true,
|
||||||
|
}, modifiers...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
|
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState) error {
|
||||||
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
err := p.proxy.SaveSession(p.rw, p.req, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) {
|
for _, cookie := range p.rw.Result().Cookies() {
|
||||||
p.req.AddCookie(c)
|
p.req.AddCookie(cookie)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
|
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, error) {
|
||||||
return p.proxy.LoadCookiedSession(p.req)
|
return p.proxy.LoadCookiedSession(p.req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadCookiedSession(t *testing.T) {
|
func TestLoadCookiedSession(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
|
|
||||||
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token", CreatedAt: time.Now()}
|
||||||
pcTest.SaveSession(startSession, time.Now())
|
pcTest.SaveSession(startSession)
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, err := pcTest.LoadCookiedSession()
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, startSession.Email, session.Email)
|
assert.Equal(t, startSession.Email, session.Email)
|
||||||
assert.Equal(t, "john.doe@example.com", session.User)
|
assert.Equal(t, "john.doe@example.com", session.User)
|
||||||
@ -669,7 +677,7 @@ func TestLoadCookiedSession(t *testing.T) {
|
|||||||
func TestProcessCookieNoCookieError(t *testing.T) {
|
func TestProcessCookieNoCookieError(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, err := pcTest.LoadCookiedSession()
|
||||||
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
|
assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
|
||||||
if session != nil {
|
if session != nil {
|
||||||
t.Errorf("expected nil session. got %#v", session)
|
t.Errorf("expected nil session. got %#v", session)
|
||||||
@ -677,29 +685,31 @@ func TestProcessCookieNoCookieError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
|
||||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
opts.CookieExpire = time.Duration(23) * time.Hour
|
||||||
|
})
|
||||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||||
|
|
||||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession)
|
||||||
|
|
||||||
session, age, err := pcTest.LoadCookiedSession()
|
session, err := pcTest.LoadCookiedSession()
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
if age < time.Duration(-2)*time.Hour {
|
if session.Age() < time.Duration(-2)*time.Hour {
|
||||||
t.Errorf("cookie too young %v", age)
|
t.Errorf("cookie too young %v", session.Age())
|
||||||
}
|
}
|
||||||
assert.Equal(t, startSession.Email, session.Email)
|
assert.Equal(t, startSession.Email, session.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
|
||||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
opts.CookieExpire = time.Duration(24) * time.Hour
|
||||||
|
})
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession)
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, err := pcTest.LoadCookiedSession()
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
if session != nil {
|
if session != nil {
|
||||||
t.Errorf("expected nil session %#v", session)
|
t.Errorf("expected nil session %#v", session)
|
||||||
@ -707,22 +717,23 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithOptionsModifiers(func(opts *Options) {
|
||||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
opts.CookieExpire = time.Duration(24) * time.Hour
|
||||||
|
})
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession)
|
||||||
|
|
||||||
pcTest.proxy.CookieRefresh = time.Hour
|
pcTest.proxy.CookieRefresh = time.Hour
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, err := pcTest.LoadCookiedSession()
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
if session != nil {
|
if session != nil {
|
||||||
t.Errorf("expected nil session %#v", session)
|
t.Errorf("expected nil session %#v", session)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
func NewAuthOnlyEndpointTest(modifiers ...OptionsModifier) *ProcessCookieTest {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithOptionsModifiers(modifiers...)
|
||||||
pcTest.req, _ = http.NewRequest("GET",
|
pcTest.req, _ = http.NewRequest("GET",
|
||||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||||
return pcTest
|
return pcTest
|
||||||
@ -731,8 +742,8 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
|||||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
startSession := &sessions.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
|
||||||
test.SaveSession(startSession, time.Now())
|
test.SaveSession(startSession)
|
||||||
|
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
assert.Equal(t, http.StatusAccepted, test.rw.Code)
|
||||||
@ -750,12 +761,13 @@ func TestAuthOnlyEndpointUnauthorizedOnNoCookieSetError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest(func(opts *Options) {
|
||||||
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
opts.CookieExpire = time.Duration(24) * time.Hour
|
||||||
|
})
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &sessions.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: reference}
|
||||||
test.SaveSession(startSession, reference)
|
test.SaveSession(startSession)
|
||||||
|
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
assert.Equal(t, http.StatusUnauthorized, test.rw.Code)
|
||||||
@ -766,8 +778,8 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|||||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
startSession := &sessions.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()}
|
||||||
test.SaveSession(startSession, time.Now())
|
test.SaveSession(startSession)
|
||||||
test.validateUser = false
|
test.validateUser = false
|
||||||
|
|
||||||
test.proxy.ServeHTTP(test.rw, test.req)
|
test.proxy.ServeHTTP(test.rw, test.req)
|
||||||
@ -797,8 +809,8 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|||||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||||
|
|
||||||
startSession := &sessions.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token", CreatedAt: time.Now()}
|
||||||
pcTest.SaveSession(startSession, time.Now())
|
pcTest.SaveSession(startSession)
|
||||||
|
|
||||||
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req)
|
||||||
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
assert.Equal(t, http.StatusAccepted, pcTest.rw.Code)
|
||||||
@ -930,11 +942,11 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
|||||||
|
|
||||||
state := &sessions.SessionState{
|
state := &sessions.SessionState{
|
||||||
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
||||||
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
err = proxy.SaveSession(st.rw, req, state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) {
|
for _, c := range st.rw.Result().Cookies() {
|
||||||
req.AddCookie(c)
|
req.AddCookie(c)
|
||||||
}
|
}
|
||||||
// This is used by the upstream to validate the signature.
|
// This is used by the upstream to validate the signature.
|
||||||
@ -1068,7 +1080,12 @@ func TestAjaxForbiddendRequest(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClearSplitCookie(t *testing.T) {
|
func TestClearSplitCookie(t *testing.T) {
|
||||||
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
|
opts := NewOptions()
|
||||||
|
opts.CookieName = "oauth2"
|
||||||
|
opts.CookieDomain = "abc"
|
||||||
|
store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions)
|
||||||
|
assert.Equal(t, err, nil)
|
||||||
|
p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
|
||||||
var rw = httptest.NewRecorder()
|
var rw = httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("get", "/", nil)
|
req := httptest.NewRequest("get", "/", nil)
|
||||||
|
|
||||||
@ -1092,7 +1109,12 @@ func TestClearSplitCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClearSingleCookie(t *testing.T) {
|
func TestClearSingleCookie(t *testing.T) {
|
||||||
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
|
opts := NewOptions()
|
||||||
|
opts.CookieName = "oauth2"
|
||||||
|
opts.CookieDomain = "abc"
|
||||||
|
store, err := cookie.NewCookieSessionStore(&opts.SessionOptions, &opts.CookieOptions)
|
||||||
|
assert.Equal(t, err, nil)
|
||||||
|
p := OAuthProxy{CookieName: opts.CookieName, CookieDomain: opts.CookieDomain, sessionStore: store}
|
||||||
var rw = httptest.NewRecorder()
|
var rw = httptest.NewRecorder()
|
||||||
req := httptest.NewRequest("get", "/", nil)
|
req := httptest.NewRequest("get", "/", nil)
|
||||||
|
|
||||||
|
24
options.go
24
options.go
@ -17,8 +17,11 @@ import (
|
|||||||
oidc "github.com/coreos/go-oidc"
|
oidc "github.com/coreos/go-oidc"
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
|
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
)
|
)
|
||||||
@ -111,6 +114,7 @@ type Options struct {
|
|||||||
proxyURLs []*url.URL
|
proxyURLs []*url.URL
|
||||||
CompiledRegex []*regexp.Regexp
|
CompiledRegex []*regexp.Regexp
|
||||||
provider providers.Provider
|
provider providers.Provider
|
||||||
|
sessionStore sessionsapi.SessionStore
|
||||||
signatureData *SignatureData
|
signatureData *SignatureData
|
||||||
oidcVerifier *oidc.IDTokenVerifier
|
oidcVerifier *oidc.IDTokenVerifier
|
||||||
}
|
}
|
||||||
@ -136,6 +140,9 @@ func NewOptions() *Options {
|
|||||||
CookieExpire: time.Duration(168) * time.Hour,
|
CookieExpire: time.Duration(168) * time.Hour,
|
||||||
CookieRefresh: time.Duration(0),
|
CookieRefresh: time.Duration(0),
|
||||||
},
|
},
|
||||||
|
SessionOptions: options.SessionOptions{
|
||||||
|
Type: "cookie",
|
||||||
|
},
|
||||||
SetXAuthRequest: false,
|
SetXAuthRequest: false,
|
||||||
SkipAuthPreflight: false,
|
SkipAuthPreflight: false,
|
||||||
PassBasicAuth: true,
|
PassBasicAuth: true,
|
||||||
@ -261,7 +268,8 @@ func (o *Options) Validate() error {
|
|||||||
}
|
}
|
||||||
msgs = parseProviderInfo(o, msgs)
|
msgs = parseProviderInfo(o, msgs)
|
||||||
|
|
||||||
if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) {
|
var cipher *cookie.Cipher
|
||||||
|
if o.PassAccessToken || o.SetAuthorization || o.PassAuthorization || (o.CookieRefresh != time.Duration(0)) {
|
||||||
validCookieSecretSize := false
|
validCookieSecretSize := false
|
||||||
for _, i := range []int{16, 24, 32} {
|
for _, i := range []int{16, 24, 32} {
|
||||||
if len(secretBytes(o.CookieSecret)) == i {
|
if len(secretBytes(o.CookieSecret)) == i {
|
||||||
@ -283,8 +291,22 @@ func (o *Options) Validate() error {
|
|||||||
"pass_access_token == true or "+
|
"pass_access_token == true or "+
|
||||||
"cookie_refresh != 0, but is %d bytes.%s",
|
"cookie_refresh != 0, but is %d bytes.%s",
|
||||||
len(secretBytes(o.CookieSecret)), suffix))
|
len(secretBytes(o.CookieSecret)), suffix))
|
||||||
|
} else {
|
||||||
|
var err error
|
||||||
|
cipher, err = cookie.NewCipher(secretBytes(o.CookieSecret))
|
||||||
|
if err != nil {
|
||||||
|
msgs = append(msgs, fmt.Sprintf("cookie-secret error: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
o.SessionOptions.Cipher = cipher
|
||||||
|
sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions)
|
||||||
|
if err != nil {
|
||||||
|
msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err))
|
||||||
|
} else {
|
||||||
|
o.sessionStore = sessionStore
|
||||||
|
}
|
||||||
|
|
||||||
if o.CookieRefresh >= o.CookieExpire {
|
if o.CookieRefresh >= o.CookieExpire {
|
||||||
msgs = append(msgs, fmt.Sprintf(
|
msgs = append(msgs, fmt.Sprintf(
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
package options
|
package options
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
)
|
||||||
|
|
||||||
// SessionOptions contains configuration options for the SessionStore providers.
|
// SessionOptions contains configuration options for the SessionStore providers.
|
||||||
type SessionOptions struct {
|
type SessionOptions struct {
|
||||||
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
|
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
|
||||||
|
Cipher *cookie.Cipher
|
||||||
CookieStoreOptions
|
CookieStoreOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
type SessionState struct {
|
type SessionState struct {
|
||||||
AccessToken string `json:",omitempty"`
|
AccessToken string `json:",omitempty"`
|
||||||
IDToken string `json:",omitempty"`
|
IDToken string `json:",omitempty"`
|
||||||
|
CreatedAt time.Time `json:"-"`
|
||||||
ExpiresOn time.Time `json:"-"`
|
ExpiresOn time.Time `json:"-"`
|
||||||
RefreshToken string `json:",omitempty"`
|
RefreshToken string `json:",omitempty"`
|
||||||
Email string `json:",omitempty"`
|
Email string `json:",omitempty"`
|
||||||
@ -23,6 +24,7 @@ type SessionState struct {
|
|||||||
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
||||||
type SessionStateJSON struct {
|
type SessionStateJSON struct {
|
||||||
*SessionState
|
*SessionState
|
||||||
|
CreatedAt *time.Time `json:",omitempty"`
|
||||||
ExpiresOn *time.Time `json:",omitempty"`
|
ExpiresOn *time.Time `json:",omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -34,6 +36,14 @@ func (s *SessionState) IsExpired() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Age returns the age of a session
|
||||||
|
func (s *SessionState) Age() time.Duration {
|
||||||
|
if !s.CreatedAt.IsZero() {
|
||||||
|
return time.Now().Truncate(time.Second).Sub(s.CreatedAt)
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
// String constructs a summary of the session state
|
// String constructs a summary of the session state
|
||||||
func (s *SessionState) String() string {
|
func (s *SessionState) String() string {
|
||||||
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
|
o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
|
||||||
@ -43,6 +53,9 @@ func (s *SessionState) String() string {
|
|||||||
if s.IDToken != "" {
|
if s.IDToken != "" {
|
||||||
o += " id_token:true"
|
o += " id_token:true"
|
||||||
}
|
}
|
||||||
|
if !s.CreatedAt.IsZero() {
|
||||||
|
o += fmt.Sprintf(" created:%s", s.CreatedAt)
|
||||||
|
}
|
||||||
if !s.ExpiresOn.IsZero() {
|
if !s.ExpiresOn.IsZero() {
|
||||||
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
||||||
}
|
}
|
||||||
@ -95,6 +108,9 @@ func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
|||||||
}
|
}
|
||||||
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
||||||
ssj := &SessionStateJSON{SessionState: &ss}
|
ssj := &SessionStateJSON{SessionState: &ss}
|
||||||
|
if !ss.CreatedAt.IsZero() {
|
||||||
|
ssj.CreatedAt = &ss.CreatedAt
|
||||||
|
}
|
||||||
if !ss.ExpiresOn.IsZero() {
|
if !ss.ExpiresOn.IsZero() {
|
||||||
ssj.ExpiresOn = &ss.ExpiresOn
|
ssj.ExpiresOn = &ss.ExpiresOn
|
||||||
}
|
}
|
||||||
@ -165,8 +181,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
|
|||||||
var ss *SessionState
|
var ss *SessionState
|
||||||
err := json.Unmarshal([]byte(v), &ssj)
|
err := json.Unmarshal([]byte(v), &ssj)
|
||||||
if err == nil && ssj.SessionState != nil {
|
if err == nil && ssj.SessionState != nil {
|
||||||
// Extract SessionState and ExpiresOn value from SessionStateJSON
|
// Extract SessionState and CreatedAt,ExpiresOn value from SessionStateJSON
|
||||||
ss = ssj.SessionState
|
ss = ssj.SessionState
|
||||||
|
if ssj.CreatedAt != nil {
|
||||||
|
ss.CreatedAt = *ssj.CreatedAt
|
||||||
|
}
|
||||||
if ssj.ExpiresOn != nil {
|
if ssj.ExpiresOn != nil {
|
||||||
ss.ExpiresOn = *ssj.ExpiresOn
|
ss.ExpiresOn = *ssj.ExpiresOn
|
||||||
}
|
}
|
||||||
|
@ -22,6 +22,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
IDToken: "rawtoken1234",
|
IDToken: "rawtoken1234",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
}
|
}
|
||||||
@ -35,6 +36,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||||
assert.Equal(t, s.IDToken, ss.IDToken)
|
assert.Equal(t, s.IDToken, ss.IDToken)
|
||||||
|
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
@ -44,6 +46,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||||
assert.NotEqual(t, s.Email, ss.Email)
|
assert.NotEqual(t, s.Email, ss.Email)
|
||||||
|
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||||
assert.NotEqual(t, s.IDToken, ss.IDToken)
|
assert.NotEqual(t, s.IDToken, ss.IDToken)
|
||||||
@ -59,6 +62,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
}
|
}
|
||||||
@ -71,6 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||||
|
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
@ -80,6 +85,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, s.User, ss.User)
|
assert.NotEqual(t, s.User, ss.User)
|
||||||
assert.NotEqual(t, s.Email, ss.Email)
|
assert.NotEqual(t, s.Email, ss.Email)
|
||||||
|
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||||
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
||||||
@ -89,6 +95,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
s := &sessions.SessionState{
|
s := &sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
}
|
}
|
||||||
@ -109,6 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
}
|
}
|
||||||
@ -147,6 +155,7 @@ type testCase struct {
|
|||||||
// Currently only tests without cipher here because we have no way to mock
|
// Currently only tests without cipher here because we have no way to mock
|
||||||
// the random generator used in EncodeSessionState.
|
// the random generator used in EncodeSessionState.
|
||||||
func TestEncodeSessionState(t *testing.T) {
|
func TestEncodeSessionState(t *testing.T) {
|
||||||
|
c := time.Now()
|
||||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
@ -163,6 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
IDToken: "rawtoken1234",
|
IDToken: "rawtoken1234",
|
||||||
|
CreatedAt: c,
|
||||||
ExpiresOn: e,
|
ExpiresOn: e,
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
},
|
},
|
||||||
@ -185,6 +195,9 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||||
func TestDecodeSessionState(t *testing.T) {
|
func TestDecodeSessionState(t *testing.T) {
|
||||||
|
created := time.Now()
|
||||||
|
createdJSON, _ := created.MarshalJSON()
|
||||||
|
createdString := string(createdJSON)
|
||||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
eJSON, _ := e.MarshalJSON()
|
eJSON, _ := e.MarshalJSON()
|
||||||
eString := string(eJSON)
|
eString := string(eJSON)
|
||||||
@ -219,7 +232,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: sessions.SessionState{
|
SessionState: sessions.SessionState{
|
||||||
@ -227,10 +240,11 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
IDToken: "rawtoken1234",
|
IDToken: "rawtoken1234",
|
||||||
|
CreatedAt: created,
|
||||||
ExpiresOn: e,
|
ExpiresOn: e,
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
},
|
},
|
||||||
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
Encoded: fmt.Sprintf(`{"Email":"FsKKYrTWZWrxSOAqA/fTNAUZS5QWCqOBjuAbBlbVOw==","User":"rT6JP3dxQhxUhkWrrd7yt6c1mDVyQCVVxw==","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","CreatedAt":%s,"ExpiresOn":%s}`, createdString, eString),
|
||||||
Cipher: c,
|
Cipher: c,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -316,3 +330,14 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSessionStateAge(t *testing.T) {
|
||||||
|
ss := &sessions.SessionState{}
|
||||||
|
|
||||||
|
// Created at unset so should be 0
|
||||||
|
assert.Equal(t, time.Duration(0), ss.Age())
|
||||||
|
|
||||||
|
// Set CreatedAt to 1 hour ago
|
||||||
|
ss.CreatedAt = time.Now().Add(-1 * time.Hour)
|
||||||
|
assert.Equal(t, time.Hour, ss.Age().Round(time.Minute))
|
||||||
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MakeCookie constructs a cookie from the given parameters,
|
// MakeCookie constructs a cookie from the given parameters,
|
||||||
@ -32,3 +33,9 @@ func MakeCookie(req *http.Request, name string, value string, path string, domai
|
|||||||
Expires: now.Add(expiration),
|
Expires: now.Add(expiration),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MakeCookieFromOptions constructs a cookie based on the givemn *options.CookieOptions,
|
||||||
|
// value and creation time
|
||||||
|
func MakeCookieFromOptions(req *http.Request, name string, value string, opts *options.CookieOptions, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
return MakeCookie(req, name, value, opts.CookiePath, opts.CookieDomain, opts.CookieHTTPOnly, opts.CookieSecure, expiration, now)
|
||||||
|
}
|
||||||
|
@ -27,36 +27,33 @@ var _ sessions.SessionStore = &SessionStore{}
|
|||||||
// SessionStore is an implementation of the sessions.SessionStore
|
// SessionStore is an implementation of the sessions.SessionStore
|
||||||
// interface that stores sessions in client side cookies
|
// interface that stores sessions in client side cookies
|
||||||
type SessionStore struct {
|
type SessionStore struct {
|
||||||
|
CookieOptions *options.CookieOptions
|
||||||
CookieCipher *cookie.Cipher
|
CookieCipher *cookie.Cipher
|
||||||
CookieDomain string
|
|
||||||
CookieExpire time.Duration
|
|
||||||
CookieHTTPOnly bool
|
|
||||||
CookieName string
|
|
||||||
CookiePath string
|
|
||||||
CookieSecret string
|
|
||||||
CookieSecure bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save takes a sessions.SessionState and stores the information from it
|
// Save takes a sessions.SessionState and stores the information from it
|
||||||
// within Cookies set on the HTTP response writer
|
// within Cookies set on the HTTP response writer
|
||||||
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
||||||
|
if ss.CreatedAt.IsZero() {
|
||||||
|
ss.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
value, err := utils.CookieForSession(ss, s.CookieCipher)
|
value, err := utils.CookieForSession(ss, s.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.setSessionCookie(rw, req, value)
|
s.setSessionCookie(rw, req, value, ss.CreatedAt)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load reads sessions.SessionState information from Cookies within the
|
// Load reads sessions.SessionState information from Cookies within the
|
||||||
// HTTP request object
|
// HTTP request object
|
||||||
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||||
c, err := loadCookie(req, s.CookieName)
|
c, err := loadCookie(req, s.CookieOptions.CookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// always http.ErrNoCookie
|
// always http.ErrNoCookie
|
||||||
return nil, fmt.Errorf("Cookie %q not present", s.CookieName)
|
return nil, fmt.Errorf("Cookie %q not present", s.CookieOptions.CookieName)
|
||||||
}
|
}
|
||||||
val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire)
|
val, _, ok := cookie.Validate(c, s.CookieOptions.CookieSecret, s.CookieOptions.CookieExpire)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("Cookie Signature not valid")
|
return nil, errors.New("Cookie Signature not valid")
|
||||||
}
|
}
|
||||||
@ -74,11 +71,11 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
|||||||
var cookies []*http.Cookie
|
var cookies []*http.Cookie
|
||||||
|
|
||||||
// matches CookieName, CookieName_<number>
|
// matches CookieName, CookieName_<number>
|
||||||
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName))
|
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieOptions.CookieName))
|
||||||
|
|
||||||
for _, c := range req.Cookies() {
|
for _, c := range req.Cookies() {
|
||||||
if cookieNameRegex.MatchString(c.Name) {
|
if cookieNameRegex.MatchString(c.Name) {
|
||||||
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
|
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
|
||||||
|
|
||||||
http.SetCookie(rw, clearCookie)
|
http.SetCookie(rw, clearCookie)
|
||||||
cookies = append(cookies, clearCookie)
|
cookies = append(cookies, clearCookie)
|
||||||
@ -89,60 +86,42 @@ func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setSessionCookie adds the user's session cookie to the response
|
// setSessionCookie adds the user's session cookie to the response
|
||||||
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string, created time.Time) {
|
||||||
for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) {
|
for _, c := range s.makeSessionCookie(req, val, created) {
|
||||||
http.SetCookie(rw, c)
|
http.SetCookie(rw, c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeSessionCookie creates an http.Cookie containing the authenticated user's
|
// makeSessionCookie creates an http.Cookie containing the authenticated user's
|
||||||
// authentication details
|
// authentication details
|
||||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, now time.Time) []*http.Cookie {
|
||||||
if value != "" {
|
if value != "" {
|
||||||
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
|
value = cookie.SignedValue(s.CookieOptions.CookieSecret, s.CookieOptions.CookieName, value, now)
|
||||||
}
|
}
|
||||||
c := s.makeCookie(req, s.CookieName, value, expiration)
|
c := s.makeCookie(req, s.CookieOptions.CookieName, value, s.CookieOptions.CookieExpire, now)
|
||||||
if len(c.Value) > 4096-len(s.CookieName) {
|
if len(c.Value) > 4096-len(s.CookieOptions.CookieName) {
|
||||||
return splitCookie(c)
|
return splitCookie(c)
|
||||||
}
|
}
|
||||||
return []*http.Cookie{c}
|
return []*http.Cookie{c}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
|
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
return cookies.MakeCookie(
|
return cookies.MakeCookieFromOptions(
|
||||||
req,
|
req,
|
||||||
name,
|
name,
|
||||||
value,
|
value,
|
||||||
s.CookiePath,
|
s.CookieOptions,
|
||||||
s.CookieDomain,
|
|
||||||
s.CookieHTTPOnly,
|
|
||||||
s.CookieSecure,
|
|
||||||
expiration,
|
expiration,
|
||||||
time.Now(),
|
now,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCookieSessionStore initialises a new instance of the SessionStore from
|
// NewCookieSessionStore initialises a new instance of the SessionStore from
|
||||||
// the configuration given
|
// the configuration given
|
||||||
func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||||
var cipher *cookie.Cipher
|
|
||||||
if len(cookieOpts.CookieSecret) > 0 {
|
|
||||||
var err error
|
|
||||||
cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to create cipher: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &SessionStore{
|
return &SessionStore{
|
||||||
CookieCipher: cipher,
|
CookieCipher: opts.Cipher,
|
||||||
CookieDomain: cookieOpts.CookieDomain,
|
CookieOptions: cookieOpts,
|
||||||
CookieExpire: cookieOpts.CookieExpire,
|
|
||||||
CookieHTTPOnly: cookieOpts.CookieHTTPOnly,
|
|
||||||
CookieName: cookieOpts.CookieName,
|
|
||||||
CookiePath: cookieOpts.CookiePath,
|
|
||||||
CookieSecret: cookieOpts.CookieSecret,
|
|
||||||
CookieSecure: cookieOpts.CookieSecure,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||||
switch opts.Type {
|
switch opts.Type {
|
||||||
case options.CookieSessionStoreType:
|
case options.CookieSessionStoreType:
|
||||||
return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts)
|
return cookie.NewCookieSessionStore(opts, cookieOpts)
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown session store type '%s'", opts.Type)
|
return nil, fmt.Errorf("unknown session store type '%s'", opts.Type)
|
||||||
}
|
}
|
||||||
|
@ -5,16 +5,20 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
. "github.com/onsi/ginkgo"
|
||||||
. "github.com/onsi/gomega"
|
. "github.com/onsi/gomega"
|
||||||
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
sessionscookie "github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSessionStore(t *testing.T) {
|
func TestSessionStore(t *testing.T) {
|
||||||
@ -72,6 +76,16 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("have a signature timestamp matching session.CreatedAt", func() {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
if cookie.Value != "" {
|
||||||
|
parts := strings.Split(cookie.Value, "|")
|
||||||
|
Expect(parts).To(HaveLen(3))
|
||||||
|
Expect(parts[1]).To(Equal(strconv.Itoa(int(session.CreatedAt.Unix()))))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -86,6 +100,10 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
It("Ensures the session CreatedAt is not zero", func() {
|
||||||
|
Expect(session.CreatedAt.IsZero()).To(BeFalse())
|
||||||
|
})
|
||||||
|
|
||||||
CheckCookieOptions()
|
CheckCookieOptions()
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -138,12 +156,15 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
|
|
||||||
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
|
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
|
||||||
l := *loadedSession
|
l := *loadedSession
|
||||||
|
l.CreatedAt = time.Time{}
|
||||||
l.ExpiresOn = time.Time{}
|
l.ExpiresOn = time.Time{}
|
||||||
s := *session
|
s := *session
|
||||||
|
s.CreatedAt = time.Time{}
|
||||||
s.ExpiresOn = time.Time{}
|
s.ExpiresOn = time.Time{}
|
||||||
Expect(l).To(Equal(s))
|
Expect(l).To(Equal(s))
|
||||||
|
|
||||||
// Compare time.Time separately
|
// Compare time.Time separately
|
||||||
|
Expect(loadedSession.CreatedAt.Equal(session.CreatedAt)).To(BeTrue())
|
||||||
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
|
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -181,12 +202,16 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
SessionStoreInterfaceTests()
|
SessionStoreInterfaceTests()
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("with a cookie-secret set", func() {
|
Context("with a cipher", func() {
|
||||||
BeforeEach(func() {
|
BeforeEach(func() {
|
||||||
secret := make([]byte, 32)
|
secret := make([]byte, 32)
|
||||||
_, err := rand.Read(secret)
|
_, err := rand.Read(secret)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret)
|
cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret)
|
||||||
|
cipher, err := cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(cipher).ToNot(BeNil())
|
||||||
|
opts.Cipher = cipher
|
||||||
|
|
||||||
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||||
Expect(err).ToNot(HaveOccurred())
|
Expect(err).ToNot(HaveOccurred())
|
||||||
@ -231,7 +256,7 @@ var _ = Describe("NewSessionStore", func() {
|
|||||||
It("creates a cookie.SessionStore", func() {
|
It("creates a cookie.SessionStore", func() {
|
||||||
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||||
Expect(err).NotTo(HaveOccurred())
|
Expect(err).NotTo(HaveOccurred())
|
||||||
Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{}))
|
Expect(ss).To(BeAssignableToTypeOf(&sessionscookie.SessionStore{}))
|
||||||
})
|
})
|
||||||
|
|
||||||
Context("the cookie.SessionStore", func() {
|
Context("the cookie.SessionStore", func() {
|
||||||
|
@ -149,6 +149,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionSt
|
|||||||
s = &sessions.SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
Email: c.Email,
|
Email: c.Email,
|
||||||
|
@ -252,6 +252,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.Session
|
|||||||
s = &sessions.SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
Email: email,
|
Email: email,
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error)
|
|||||||
s.AccessToken = newSession.AccessToken
|
s.AccessToken = newSession.AccessToken
|
||||||
s.IDToken = newSession.IDToken
|
s.IDToken = newSession.IDToken
|
||||||
s.RefreshToken = newSession.RefreshToken
|
s.RefreshToken = newSession.RefreshToken
|
||||||
|
s.CreatedAt = newSession.CreatedAt
|
||||||
s.ExpiresOn = newSession.ExpiresOn
|
s.ExpiresOn = newSession.ExpiresOn
|
||||||
s.Email = newSession.Email
|
s.Email = newSession.Email
|
||||||
return
|
return
|
||||||
@ -126,6 +127,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
|||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
IDToken: rawIDToken,
|
IDToken: rawIDToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
ExpiresOn: token.Expiry,
|
ExpiresOn: token.Expiry,
|
||||||
Email: claims.Email,
|
Email: claims.Email,
|
||||||
User: claims.Subject,
|
User: claims.Subject,
|
||||||
|
@ -8,6 +8,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
@ -72,7 +73,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if a := v.Get("access_token"); a != "" {
|
if a := v.Get("access_token"); a != "" {
|
||||||
s = &sessions.SessionState{AccessToken: a}
|
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("no access token found %s", body)
|
err = fmt.Errorf("no access token found %s", body)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user