diff --git a/oauthproxy.go b/oauthproxy.go index 02d9ac1..f7633ac 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -16,7 +16,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/logger" - "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/providers" "github.com/yhat/wsutil" ) @@ -75,6 +75,7 @@ type OAuthProxy struct { redirectURL *url.URL // the url to receive requests at whitelistDomains []string provider providers.Provider + sessionStore sessionsapi.SessionStore ProxyPrefix string SignInMessage string HtpasswdFile *HtpasswdFile @@ -249,6 +250,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { ProxyPrefix: opts.ProxyPrefix, provider: opts.provider, + sessionStore: opts.sessionStore, serveMux: serveMux, redirectURL: redirectURL, whitelistDomains: opts.WhitelistDomains, @@ -293,7 +295,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { +func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } @@ -485,7 +487,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, } // LoadCookiedSession reads the user's authentication details from the request -func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, time.Duration, error) { var age time.Duration c, err := loadCookie(req, p.CookieName) if err != nil { @@ -507,7 +509,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionSta } // SaveSession creates a new session cookie value and sets this on the response -func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { +func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error { value, err := p.provider.CookieForSession(s, p.CookieCipher) if err != nil { return err @@ -694,7 +696,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { user, ok := p.ManualSignIn(rw, req) if ok { - session := &sessions.SessionState{User: user} + session := &sessionsapi.SessionState{User: user} p.SaveSession(rw, req, session) http.Redirect(rw, req, redirect, 302) } else { @@ -945,7 +947,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int // CheckBasicAuth checks the requests Authorization header for basic auth // credentials and authenticates these against the proxies HtpasswdFile -func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { +func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessionsapi.SessionState, error) { if p.HtpasswdFile == nil { return nil, nil } @@ -967,7 +969,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, } if p.HtpasswdFile.Validate(pair[0], pair[1]) { logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") - return &sessions.SessionState{User: pair[0]}, nil + return &sessionsapi.SessionState{User: pair[0]}, nil } logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") return nil, nil diff --git a/options.go b/options.go index 3639134..50587c1 100644 --- a/options.go +++ b/options.go @@ -19,6 +19,8 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/pkg/apis/options" + sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions" "github.com/pusher/oauth2_proxy/providers" "gopkg.in/natefinch/lumberjack.v2" ) @@ -111,6 +113,7 @@ type Options struct { proxyURLs []*url.URL CompiledRegex []*regexp.Regexp provider providers.Provider + sessionStore sessionsapi.SessionStore signatureData *SignatureData oidcVerifier *oidc.IDTokenVerifier } @@ -136,6 +139,9 @@ func NewOptions() *Options { CookieExpire: time.Duration(168) * time.Hour, CookieRefresh: time.Duration(0), }, + SessionOptions: options.SessionOptions{ + Type: "cookie", + }, SetXAuthRequest: false, SkipAuthPreflight: false, PassBasicAuth: true, @@ -283,9 +289,19 @@ func (o *Options) Validate() error { "pass_access_token == true or "+ "cookie_refresh != 0, but is %d bytes.%s", len(secretBytes(o.CookieSecret)), suffix)) + } else { + // Enable encryption in the session store + o.EnableCipher = true } } + sessionStore, err := sessions.NewSessionStore(&o.SessionOptions, &o.CookieOptions) + if err != nil { + msgs = append(msgs, fmt.Sprintf("error initialising session storage: %v", err)) + } else { + o.sessionStore = sessionStore + } + if o.CookieRefresh >= o.CookieExpire { msgs = append(msgs, fmt.Sprintf( "cookie_refresh (%s) must be less than "+ diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go index 56fd27a..2429b7b 100644 --- a/pkg/apis/options/sessions.go +++ b/pkg/apis/options/sessions.go @@ -2,7 +2,8 @@ package options // SessionOptions contains configuration options for the SessionStore providers. type SessionOptions struct { - Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` + Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` + EnableCipher bool // Allow the user to choose encryption or not CookieStoreOptions } @@ -11,4 +12,6 @@ type SessionOptions struct { var CookieSessionStoreType = "cookie" // CookieStoreOptions contains configuration options for the CookieSessionStore. -type CookieStoreOptions struct{} +type CookieStoreOptions struct { + EnableCipher bool // Allow the user to choose encryption or not +} diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go index 14c0b71..0fc4a64 100644 --- a/pkg/sessions/cookie/session_store.go +++ b/pkg/sessions/cookie/session_store.go @@ -126,7 +126,7 @@ func (s *SessionStore) makeCookie(req *http.Request, name string, value string, // the configuration given func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { var cipher *cookie.Cipher - if len(cookieOpts.CookieSecret) > 0 { + if opts.EnableCipher { var err error cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) if err != nil { diff --git a/pkg/sessions/session_store.go b/pkg/sessions/session_store.go index cc074c7..3a81eaa 100644 --- a/pkg/sessions/session_store.go +++ b/pkg/sessions/session_store.go @@ -12,6 +12,8 @@ import ( func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { switch opts.Type { case options.CookieSessionStoreType: + // Ensure EnableCipher is propogated from the parent option + opts.CookieStoreOptions.EnableCipher = opts.EnableCipher return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) default: return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go index 0ceea66..4857912 100644 --- a/pkg/sessions/session_store_test.go +++ b/pkg/sessions/session_store_test.go @@ -181,12 +181,13 @@ var _ = Describe("NewSessionStore", func() { SessionStoreInterfaceTests() }) - Context("with a cookie-secret set", func() { + Context("with encryption enabled", func() { BeforeEach(func() { secret := make([]byte, 32) _, err := rand.Read(secret) Expect(err).ToNot(HaveOccurred()) cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) + opts.EnableCipher = true ss, err = sessions.NewSessionStore(opts, cookieOpts) Expect(err).ToNot(HaveOccurred()) @@ -194,6 +195,19 @@ var _ = Describe("NewSessionStore", func() { SessionStoreInterfaceTests() }) + + Context("with encryption enabled, but no secret", func() { + BeforeEach(func() { + opts.EnableCipher = true + }) + + It("returns an error", func() { + ss, err := sessions.NewSessionStore(opts, cookieOpts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("unable to create cipher: crypto/aes: invalid key size 0")) + Expect(ss).To(BeNil()) + }) + }) } BeforeEach(func() {