diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f04869..eb798f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,11 @@ ## Changes since v3.2.0 -- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings) +- [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed) + - Allows for multiple different session storage implementations including client and server side + - Adds tests suite for interface to ensure consistency across implementations + - Refactor some configuration options (around cookies) into packages +- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings) - [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath) - [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes) - [#142](https://github.com/pusher/oauth2_proxy/pull/142) ARM Docker USER fix (@kskewes) diff --git a/Gopkg.lock b/Gopkg.lock index 2a69229..01af4d2 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -57,6 +57,20 @@ pruneopts = "" revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845" +[[projects]] + digest = "1:b3c5b95e56c06f5aa72cb2500e6ee5f44fcd122872d4fec2023a488e561218bc" + name = "github.com/hpcloud/tail" + packages = [ + ".", + "ratelimiter", + "util", + "watch", + "winfile", + ] + pruneopts = "" + revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5" + version = "v1.0.0" + [[projects]] digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f" name = "github.com/mbland/hmacauth" @@ -73,6 +87,54 @@ pruneopts = "" revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95" +[[projects]] + digest = "1:a3735b0978a8b53fc2ac97a6f46ca1189f0712a00df86d6ec4cf26c1a25e6d77" + name = "github.com/onsi/ginkgo" + packages = [ + ".", + "config", + "internal/codelocation", + "internal/containernode", + "internal/failer", + "internal/leafnodes", + "internal/remote", + "internal/spec", + "internal/spec_iterator", + "internal/specrunner", + "internal/suite", + "internal/testingtproxy", + "internal/writer", + "reporters", + "reporters/stenographer", + "reporters/stenographer/support/go-colorable", + "reporters/stenographer/support/go-isatty", + "types", + ] + pruneopts = "" + revision = "eea6ad008b96acdaa524f5b409513bf062b500ad" + version = "v1.8.0" + +[[projects]] + digest = "1:dbafce2fddb1ca331646fe2ac9c9413980368b19a60a4406a6e5861680bd73be" + name = "github.com/onsi/gomega" + packages = [ + ".", + "format", + "internal/assertion", + "internal/asyncassertion", + "internal/oraclematcher", + "internal/testingtsupport", + "matchers", + "matchers/support/goraph/bipartitegraph", + "matchers/support/goraph/edge", + "matchers/support/goraph/node", + "matchers/support/goraph/util", + "types", + ] + pruneopts = "" + revision = "90e289841c1ed79b7a598a7cd9959750cb5e89e2" + version = "v1.5.0" + [[projects]] digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" name = "github.com/pmezard/go-difflib" @@ -131,6 +193,9 @@ packages = [ "context", "context/ctxhttp", + "html", + "html/atom", + "html/charset", "websocket", ] pruneopts = "" @@ -150,6 +215,42 @@ pruneopts = "" revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402" +[[projects]] + branch = "master" + digest = "1:67a6e61e60283fd7dce50eba228080bff8805d9d69b2f121d7ec2260d120c4a8" + name = "golang.org/x/sys" + packages = ["unix"] + pruneopts = "" + revision = "ca7f33d4116e3a1f9425755d6a44e7ed9b4c97df" + +[[projects]] + digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0" + name = "golang.org/x/text" + packages = [ + "encoding", + "encoding/charmap", + "encoding/htmlindex", + "encoding/internal", + "encoding/internal/identifier", + "encoding/japanese", + "encoding/korean", + "encoding/simplifiedchinese", + "encoding/traditionalchinese", + "encoding/unicode", + "internal/gen", + "internal/language", + "internal/language/compact", + "internal/tag", + "internal/utf8internal", + "language", + "runes", + "transform", + "unicode/cldr", + ] + pruneopts = "" + revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475" + version = "v0.3.2" + [[projects]] branch = "master" digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed" @@ -182,6 +283,15 @@ revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a" version = "v1.0.0" +[[projects]] + digest = "1:eb53021a8aa3f599d29c7102e65026242bdedce998a54837dc67f14b6a97c5fd" + name = "gopkg.in/fsnotify.v1" + packages = ["."] + pruneopts = "" + revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9" + source = "https://github.com/fsnotify/fsnotify.git" + version = "v1.4.7" + [[projects]] digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2" name = "gopkg.in/fsnotify/fsnotify.v1" @@ -210,6 +320,22 @@ revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1" version = "v2.1.3" +[[projects]] + branch = "v1" + digest = "1:a96d16bd088460f2e0685d46c39bcf1208ba46e0a977be2df49864ec7da447dd" + name = "gopkg.in/tomb.v1" + packages = ["."] + pruneopts = "" + revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8" + +[[projects]] + digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d" + name = "gopkg.in/yaml.v2" + packages = ["."] + pruneopts = "" + revision = "51d6538a90f86fe93ac480b35f37b2be17fef232" + version = "v2.2.2" + [solve-meta] analyzer-name = "dep" analyzer-version = 1 @@ -220,6 +346,8 @@ "github.com/dgrijalva/jwt-go", "github.com/mbland/hmacauth", "github.com/mreiferson/go-options", + "github.com/onsi/ginkgo", + "github.com/onsi/gomega", "github.com/stretchr/testify/assert", "github.com/stretchr/testify/require", "github.com/yhat/wsutil", @@ -231,6 +359,7 @@ "google.golang.org/api/googleapi", "gopkg.in/fsnotify/fsnotify.v1", "gopkg.in/natefinch/lumberjack.v2", + "gopkg.in/square/go-jose.v2", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index f15d952..732bbcb 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -35,6 +35,10 @@ name = "gopkg.in/fsnotify/fsnotify.v1" version = "~1.2.0" +[[override]] + name = "gopkg.in/fsnotify.v1" + source = "https://github.com/fsnotify/fsnotify.git" + [[constraint]] branch = "master" name = "golang.org/x/crypto" diff --git a/Makefile b/Makefile index 11e2ed3..0dca566 100644 --- a/Makefile +++ b/Makefile @@ -33,6 +33,7 @@ lint: $(GOMETALINTER) --enable=deadcode \ --enable=gofmt \ --enable=goimports \ + --deadline=120s \ --tests ./... .PHONY: dep diff --git a/docs/3_configuration.md b/docs/configuration/configuration.md similarity index 99% rename from docs/3_configuration.md rename to docs/configuration/configuration.md index b390944..fd33d37 100644 --- a/docs/3_configuration.md +++ b/docs/configuration/configuration.md @@ -1,7 +1,8 @@ --- layout: default title: Configuration -permalink: /configuration +permalink: /docs/configuration +has_children: true nav_order: 3 --- @@ -78,6 +79,7 @@ Usage of oauth2_proxy: -request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below) -resource string: The resource that is protected (Azure AD only) -scope string: OAuth scope specification + -session-store-type: Session data storage backend (default: cookie) -set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode) -set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode) -signature-key string: GAP-Signature request signature key (algorithm:secretkey) diff --git a/docs/configuration/sessions.md b/docs/configuration/sessions.md new file mode 100644 index 0000000..1896a5d --- /dev/null +++ b/docs/configuration/sessions.md @@ -0,0 +1,34 @@ +--- +layout: default +title: Sessions +permalink: /configuration +parent: Configuration +nav_order: 3 +--- + +## Sessions + +Sessions allow a user's authentication to be tracked between multiple HTTP +requests to a service. + +The OAuth2 Proxy uses a Cookie to track user sessions and will store the session +data in one of the available session storage backends. + +At present the available backends are (as passed to `--session-store-type`): +- [cookie](cookie-storage) (deafult) + +### Cookie Storage + +The Cookie storage backend is the default backend implementation and has +been used in the OAuth2 Proxy historically. + +With the Cookie storage backend, all session information is stored in client +side cookies and transferred with each and every request. + +The following should be known when using this implementation: +- Since all state is stored client side, this storage backend means that the OAuth2 Proxy is completely stateless +- Cookies are signed server side to prevent modification client-side +- It is recommended to set a `cookie-secret` which will ensure data is encrypted within the cookie data. +- Since multiple requests can be made concurrently to the OAuth2 Proxy, this session implementation +cannot lock sessions and while updating and refreshing sessions, there can be conflicts which force +users to re-authenticate diff --git a/env_options.go b/env_options.go index d54f5af..9db2d89 100644 --- a/env_options.go +++ b/env_options.go @@ -15,14 +15,27 @@ type EnvOptions map[string]interface{} // Fields in the options struct must have an `env` and `cfg` tag to be read // from the environment func (cfg EnvOptions) LoadEnvForStruct(options interface{}) { - val := reflect.ValueOf(options).Elem() - typ := val.Type() + val := reflect.ValueOf(options) + var typ reflect.Type + if val.Kind() == reflect.Ptr { + typ = val.Elem().Type() + } else { + typ = val.Type() + } + for i := 0; i < typ.NumField(); i++ { // pull out the struct tags: // flag - the name of the command line flag // deprecated - (optional) the name of the deprecated command line flag // cfg - (optional, defaults to underscored flag) the name of the config file option field := typ.Field(i) + fieldV := reflect.Indirect(val).Field(i) + + if field.Type.Kind() == reflect.Struct && field.Anonymous { + cfg.LoadEnvForStruct(fieldV.Interface()) + continue + } + flagName := field.Tag.Get("flag") envName := field.Tag.Get("env") cfgName := field.Tag.Get("cfg") diff --git a/env_options_test.go b/env_options_test.go index e9277f7..c1937e6 100644 --- a/env_options_test.go +++ b/env_options_test.go @@ -1,26 +1,46 @@ -package main +package main_test import ( "os" "testing" + proxy "github.com/pusher/oauth2_proxy" "github.com/stretchr/testify/assert" ) -type envTest struct { - testField string `cfg:"target_field" env:"TEST_ENV_FIELD"` +type EnvTest struct { + TestField string `cfg:"target_field" env:"TEST_ENV_FIELD"` + EnvTestEmbed +} + +type EnvTestEmbed struct { + TestFieldEmbed string `cfg:"target_field_embed" env:"TEST_ENV_FIELD_EMBED"` } func TestLoadEnvForStruct(t *testing.T) { - cfg := make(EnvOptions) - cfg.LoadEnvForStruct(&envTest{}) + cfg := make(proxy.EnvOptions) + cfg.LoadEnvForStruct(&EnvTest{}) _, ok := cfg["target_field"] assert.Equal(t, ok, false) os.Setenv("TEST_ENV_FIELD", "1234abcd") - cfg.LoadEnvForStruct(&envTest{}) + cfg.LoadEnvForStruct(&EnvTest{}) v := cfg["target_field"] assert.Equal(t, v, "1234abcd") } + +func TestLoadEnvForStructWithEmbeddedFields(t *testing.T) { + + cfg := make(proxy.EnvOptions) + cfg.LoadEnvForStruct(&EnvTest{}) + + _, ok := cfg["target_field_embed"] + assert.Equal(t, ok, false) + + os.Setenv("TEST_ENV_FIELD_EMBED", "1234abcd") + cfg.LoadEnvForStruct(&EnvTest{}) + v := cfg["target_field_embed"] + assert.Equal(t, v, "1234abcd") +} diff --git a/main.go b/main.go index b74540b..a74c245 100644 --- a/main.go +++ b/main.go @@ -75,6 +75,8 @@ func main() { flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") + flagSet.String("session-store-type", "cookie", "the session storage provider to use") + flagSet.String("logging-filename", "", "File to log requests to, empty for stdout") flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation") flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files") diff --git a/oauthproxy.go b/oauthproxy.go index 52f7c79..02d9ac1 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -16,6 +16,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/providers" "github.com/yhat/wsutil" ) @@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { +func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } @@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, } // LoadCookiedSession reads the user's authentication details from the request -func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { var age time.Duration c, err := loadCookie(req, p.CookieName) if err != nil { @@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt } // SaveSession creates a new session cookie value and sets this on the response -func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { +func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { value, err := p.provider.CookieForSession(s, p.CookieCipher) if err != nil { return err @@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { user, ok := p.ManualSignIn(rw, req) if ok { - session := &providers.SessionState{User: user} + session := &sessions.SessionState{User: user} p.SaveSession(rw, req, session) http.Redirect(rw, req, redirect, 302) } else { @@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int // CheckBasicAuth checks the requests Authorization header for basic auth // credentials and authenticates these against the proxies HtpasswdFile -func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { +func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { if p.HtpasswdFile == nil { return nil, nil } @@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, } if p.HtpasswdFile.Validate(pair[0], pair[1]) { logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") - return &providers.SessionState{User: pair[0]}, nil + return &sessions.SessionState{User: pair[0]}, nil } logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") return nil, nil diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 65a8fe1..914e99f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -16,6 +16,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { } } -func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { +func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { return tp.EmailAddress, nil } -func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { +func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { return tp.ValidToken } @@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) } -func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { +func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) if err != nil { return err @@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time return nil } -func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { +func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { return p.proxy.LoadCookiedSession(p.req) } func TestLoadCookiedSession(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, time.Now()) session, _, err := pcTest.LoadCookiedSession() @@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour reference := time.Now().Add(time.Duration(-2) * time.Hour) - startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, reference) session, age, err := pcTest.LoadCookiedSession() @@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, reference) session, _, err := pcTest.LoadCookiedSession() @@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, reference) pcTest.proxy.CookieRefresh = time.Hour @@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { func TestAuthOnlyEndpointAccepted(t *testing.T) { test := NewAuthOnlyEndpointTest() - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) @@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { test := NewAuthOnlyEndpointTest() test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, reference) @@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test := NewAuthOnlyEndpointTest() - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) test.validateUser = false @@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} pcTest.SaveSession(startSession, time.Now()) @@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { req := httptest.NewRequest(method, "/foo/bar", bodyBuf) req.Header = st.header - state := &providers.SessionState{ + state := &sessions.SessionState{ Email: "mbland@acm.org", AccessToken: "my_access_token"} value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) if err != nil { diff --git a/options.go b/options.go index fc20c51..3639134 100644 --- a/options.go +++ b/options.go @@ -18,6 +18,7 @@ import ( "github.com/dgrijalva/jwt-go" "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/options" "github.com/pusher/oauth2_proxy/providers" "gopkg.in/natefinch/lumberjack.v2" ) @@ -49,14 +50,11 @@ type Options struct { CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"` Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"` - CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` - CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` - CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` - CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` - CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` - CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` - CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` - CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` + // Embed CookieOptions + options.CookieOptions + + // Embed SessionOptions + options.SessionOptions Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"` @@ -126,16 +124,18 @@ type SignatureData struct { // NewOptions constructs a new Options with defaulted values func NewOptions() *Options { return &Options{ - ProxyPrefix: "/oauth2", - ProxyWebSockets: true, - HTTPAddress: "127.0.0.1:4180", - HTTPSAddress: ":443", - DisplayHtpasswdForm: true, - CookieName: "_oauth2_proxy", - CookieSecure: true, - CookieHTTPOnly: true, - CookieExpire: time.Duration(168) * time.Hour, - CookieRefresh: time.Duration(0), + ProxyPrefix: "/oauth2", + ProxyWebSockets: true, + HTTPAddress: "127.0.0.1:4180", + HTTPSAddress: ":443", + DisplayHtpasswdForm: true, + CookieOptions: options.CookieOptions{ + CookieName: "_oauth2_proxy", + CookieSecure: true, + CookieHTTPOnly: true, + CookieExpire: time.Duration(168) * time.Hour, + CookieRefresh: time.Duration(0), + }, SetXAuthRequest: false, SkipAuthPreflight: false, PassBasicAuth: true, diff --git a/pkg/apis/options/cookie.go b/pkg/apis/options/cookie.go new file mode 100644 index 0000000..80ecf57 --- /dev/null +++ b/pkg/apis/options/cookie.go @@ -0,0 +1,15 @@ +package options + +import "time" + +// CookieOptions contains configuration options relating to Cookie configuration +type CookieOptions struct { + CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` + CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` + CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` + CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"` + CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` + CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` + CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` + CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` +} diff --git a/pkg/apis/options/sessions.go b/pkg/apis/options/sessions.go new file mode 100644 index 0000000..56fd27a --- /dev/null +++ b/pkg/apis/options/sessions.go @@ -0,0 +1,14 @@ +package options + +// SessionOptions contains configuration options for the SessionStore providers. +type SessionOptions struct { + Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"` + CookieStoreOptions +} + +// CookieSessionStoreType is used to indicate the CookieSessionStore should be +// used for storing sessions. +var CookieSessionStoreType = "cookie" + +// CookieStoreOptions contains configuration options for the CookieSessionStore. +type CookieStoreOptions struct{} diff --git a/pkg/apis/sessions/interfaces.go b/pkg/apis/sessions/interfaces.go new file mode 100644 index 0000000..34d945f --- /dev/null +++ b/pkg/apis/sessions/interfaces.go @@ -0,0 +1,12 @@ +package sessions + +import ( + "net/http" +) + +// SessionStore is an interface to storing user sessions in the proxy +type SessionStore interface { + Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error + Load(req *http.Request) (*SessionState, error) + Clear(rw http.ResponseWriter, req *http.Request) error +} diff --git a/providers/session_state.go b/pkg/apis/sessions/session_state.go similarity index 99% rename from providers/session_state.go rename to pkg/apis/sessions/session_state.go index c3402ac..f6efefd 100644 --- a/providers/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -1,4 +1,4 @@ -package providers +package sessions import ( "encoding/json" diff --git a/providers/session_state_test.go b/pkg/apis/sessions/session_state_test.go similarity index 84% rename from providers/session_state_test.go rename to pkg/apis/sessions/session_state_test.go index 78957c6..83b21a4 100644 --- a/providers/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -1,4 +1,4 @@ -package providers +package sessions_test import ( "fmt" @@ -6,6 +6,7 @@ import ( "time" "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, nil, err) c2, err := cookie.NewCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &SessionState{ + s := &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", IDToken: "rawtoken1234", @@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := DecodeSessionState(encoded, c) + ss, err := sessions.DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, "user@domain.com", ss.User) @@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) + ss, err = sessions.DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.NotEqual(t, "user@domain.com", ss.User) @@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, nil, err) c2, err := cookie.NewCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &SessionState{ + s := &sessions.SessionState{ User: "just-user", Email: "user@domain.com", AccessToken: "token1234", @@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := DecodeSessionState(encoded, c) + ss, err := sessions.DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) @@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) + ss, err = sessions.DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.NotEqual(t, s.User, ss.User) @@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { } func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &SessionState{ + s := &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), @@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) + ss, err := sessions.DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, "user@domain.com", ss.User) assert.Equal(t, s.Email, ss.Email) @@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { - s := &SessionState{ + s := &sessions.SessionState{ User: "just-user", Email: "user@domain.com", AccessToken: "token1234", @@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) + ss, err := sessions.DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) @@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } func TestExpired(t *testing.T) { - s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} + s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} assert.Equal(t, true, s.IsExpired()) - s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} + s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} assert.Equal(t, false, s.IsExpired()) - s = &SessionState{} + s = &sessions.SessionState{} assert.Equal(t, false, s.IsExpired()) } type testCase struct { - SessionState + sessions.SessionState Encoded string Cipher *cookie.Cipher Error bool @@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) { for i, tc := range testCases { encoded, err := tc.EncodeSessionState(tc.Cipher) - t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) + t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) if tc.Error { assert.Error(t, err) assert.Empty(t, encoded) @@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) { } } -// TestDecodeSessionState tests DecodeSessionState with the test vector +// TestDecodeSessionState testssessions.DecodeSessionState with the test vector func TestDecodeSessionState(t *testing.T) { e := time.Now().Add(time.Duration(1) * time.Hour) eJSON, _ := e.MarshalJSON() @@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "user@domain.com", }, Encoded: `{"Email":"user@domain.com"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ User: "just-user", }, Encoded: `{"User":"just-user"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) { Cipher: c, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, @@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { Error: true, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ User: "just-user", Email: "user@domain.com", }, @@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) { Error: true, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) { Cipher: c, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) { } for i, tc := range testCases { - ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) - t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) + ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) + t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) if tc.Error { assert.Error(t, err) assert.Nil(t, ss) diff --git a/pkg/cookies/cookies.go b/pkg/cookies/cookies.go new file mode 100644 index 0000000..936f08e --- /dev/null +++ b/pkg/cookies/cookies.go @@ -0,0 +1,34 @@ +package cookies + +import ( + "net" + "net/http" + "strings" + "time" + + "github.com/pusher/oauth2_proxy/logger" +) + +// MakeCookie constructs a cookie from the given parameters, +// discovering the domain from the request if not specified. +func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time) *http.Cookie { + if domain != "" { + host := req.Host + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + if !strings.HasSuffix(host, domain) { + logger.Printf("Warning: request host is %q but using configured cookie domain of %q", host, domain) + } + } + + return &http.Cookie{ + Name: name, + Value: value, + Path: path, + Domain: domain, + HttpOnly: httpOnly, + Secure: secure, + Expires: now.Add(expiration), + } +} diff --git a/pkg/sessions/cookie/session_store.go b/pkg/sessions/cookie/session_store.go new file mode 100644 index 0000000..14c0b71 --- /dev/null +++ b/pkg/sessions/cookie/session_store.go @@ -0,0 +1,232 @@ +package cookie + +import ( + "errors" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/options" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/cookies" + "github.com/pusher/oauth2_proxy/pkg/sessions/utils" +) + +const ( + // Cookies are limited to 4kb including the length of the cookie name, + // the cookie name can be up to 256 bytes + maxCookieLength = 3840 +) + +// Ensure CookieSessionStore implements the interface +var _ sessions.SessionStore = &SessionStore{} + +// SessionStore is an implementation of the sessions.SessionStore +// interface that stores sessions in client side cookies +type SessionStore struct { + CookieCipher *cookie.Cipher + CookieDomain string + CookieExpire time.Duration + CookieHTTPOnly bool + CookieName string + CookiePath string + CookieSecret string + CookieSecure bool +} + +// Save takes a sessions.SessionState and stores the information from it +// within Cookies set on the HTTP response writer +func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error { + value, err := utils.CookieForSession(ss, s.CookieCipher) + if err != nil { + return err + } + s.setSessionCookie(rw, req, value) + return nil +} + +// Load reads sessions.SessionState information from Cookies within the +// HTTP request object +func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) { + c, err := loadCookie(req, s.CookieName) + if err != nil { + // always http.ErrNoCookie + return nil, fmt.Errorf("Cookie %q not present", s.CookieName) + } + val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire) + if !ok { + return nil, errors.New("Cookie Signature not valid") + } + + session, err := utils.SessionFromCookie(val, s.CookieCipher) + if err != nil { + return nil, err + } + return session, nil +} + +// Clear clears any saved session information by writing a cookie to +// clear the session +func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error { + var cookies []*http.Cookie + + // matches CookieName, CookieName_ + var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName)) + + for _, c := range req.Cookies() { + if cookieNameRegex.MatchString(c.Name) { + clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1) + + http.SetCookie(rw, clearCookie) + cookies = append(cookies, clearCookie) + } + } + + return nil +} + +// setSessionCookie adds the user's session cookie to the response +func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { + for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) { + http.SetCookie(rw, c) + } +} + +// makeSessionCookie creates an http.Cookie containing the authenticated user's +// authentication details +func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { + if value != "" { + value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now) + } + c := s.makeCookie(req, s.CookieName, value, expiration) + if len(c.Value) > 4096-len(s.CookieName) { + return splitCookie(c) + } + return []*http.Cookie{c} +} + +func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie { + return cookies.MakeCookie( + req, + name, + value, + s.CookiePath, + s.CookieDomain, + s.CookieHTTPOnly, + s.CookieSecure, + expiration, + time.Now(), + ) +} + +// NewCookieSessionStore initialises a new instance of the SessionStore from +// the configuration given +func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { + var cipher *cookie.Cipher + if len(cookieOpts.CookieSecret) > 0 { + var err error + cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret)) + if err != nil { + return nil, fmt.Errorf("unable to create cipher: %v", err) + } + } + + return &SessionStore{ + CookieCipher: cipher, + CookieDomain: cookieOpts.CookieDomain, + CookieExpire: cookieOpts.CookieExpire, + CookieHTTPOnly: cookieOpts.CookieHTTPOnly, + CookieName: cookieOpts.CookieName, + CookiePath: cookieOpts.CookiePath, + CookieSecret: cookieOpts.CookieSecret, + CookieSecure: cookieOpts.CookieSecure, + }, nil +} + +// splitCookie reads the full cookie generated to store the session and splits +// it into a slice of cookies which fit within the 4kb cookie limit indexing +// the cookies from 0 +func splitCookie(c *http.Cookie) []*http.Cookie { + if len(c.Value) < maxCookieLength { + return []*http.Cookie{c} + } + cookies := []*http.Cookie{} + valueBytes := []byte(c.Value) + count := 0 + for len(valueBytes) > 0 { + new := copyCookie(c) + new.Name = fmt.Sprintf("%s_%d", c.Name, count) + count++ + if len(valueBytes) < maxCookieLength { + new.Value = string(valueBytes) + valueBytes = []byte{} + } else { + newValue := valueBytes[:maxCookieLength] + valueBytes = valueBytes[maxCookieLength:] + new.Value = string(newValue) + } + cookies = append(cookies, new) + } + return cookies +} + +// loadCookie retreieves the sessions state cookie from the http request. +// If a single cookie is present this will be returned, otherwise it attempts +// to reconstruct a cookie split up by splitCookie +func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { + c, err := req.Cookie(cookieName) + if err == nil { + return c, nil + } + cookies := []*http.Cookie{} + err = nil + count := 0 + for err == nil { + var c *http.Cookie + c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count)) + if err == nil { + cookies = append(cookies, c) + count++ + } + } + if len(cookies) == 0 { + return nil, fmt.Errorf("Could not find cookie %s", cookieName) + } + return joinCookies(cookies) +} + +// joinCookies takes a slice of cookies from the request and reconstructs the +// full session cookie +func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { + if len(cookies) == 0 { + return nil, fmt.Errorf("list of cookies must be > 0") + } + if len(cookies) == 1 { + return cookies[0], nil + } + c := copyCookie(cookies[0]) + for i := 1; i < len(cookies); i++ { + c.Value += cookies[i].Value + } + c.Name = strings.TrimRight(c.Name, "_0") + return c, nil +} + +func copyCookie(c *http.Cookie) *http.Cookie { + return &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: c.Expires, + RawExpires: c.RawExpires, + MaxAge: c.MaxAge, + Secure: c.Secure, + HttpOnly: c.HttpOnly, + Raw: c.Raw, + Unparsed: c.Unparsed, + } +} diff --git a/pkg/sessions/session_store.go b/pkg/sessions/session_store.go new file mode 100644 index 0000000..cc074c7 --- /dev/null +++ b/pkg/sessions/session_store.go @@ -0,0 +1,19 @@ +package sessions + +import ( + "fmt" + + "github.com/pusher/oauth2_proxy/pkg/apis/options" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" +) + +// NewSessionStore creates a SessionStore from the provided configuration +func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) { + switch opts.Type { + case options.CookieSessionStoreType: + return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts) + default: + return nil, fmt.Errorf("unknown session store type '%s'", opts.Type) + } +} diff --git a/pkg/sessions/session_store_test.go b/pkg/sessions/session_store_test.go new file mode 100644 index 0000000..0ceea66 --- /dev/null +++ b/pkg/sessions/session_store_test.go @@ -0,0 +1,254 @@ +package sessions_test + +import ( + "crypto/rand" + "encoding/base64" + "net/http" + "net/http/httptest" + "testing" + "time" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/pusher/oauth2_proxy/pkg/apis/options" + sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/cookies" + "github.com/pusher/oauth2_proxy/pkg/sessions" + "github.com/pusher/oauth2_proxy/pkg/sessions/cookie" +) + +func TestSessionStore(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "SessionStore") +} + +var _ = Describe("NewSessionStore", func() { + var opts *options.SessionOptions + var cookieOpts *options.CookieOptions + + var request *http.Request + var response *httptest.ResponseRecorder + var session *sessionsapi.SessionState + var ss sessionsapi.SessionStore + + CheckCookieOptions := func() { + Context("the cookies returned", func() { + var cookies []*http.Cookie + BeforeEach(func() { + cookies = response.Result().Cookies() + }) + + It("have the correct name set", func() { + if len(cookies) == 1 { + Expect(cookies[0].Name).To(Equal(cookieOpts.CookieName)) + } else { + for _, cookie := range cookies { + Expect(cookie.Name).To(ContainSubstring(cookieOpts.CookieName)) + } + } + }) + + It("have the correct path set", func() { + for _, cookie := range cookies { + Expect(cookie.Path).To(Equal(cookieOpts.CookiePath)) + } + }) + + It("have the correct domain set", func() { + for _, cookie := range cookies { + Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain)) + } + }) + + It("have the correct HTTPOnly set", func() { + for _, cookie := range cookies { + Expect(cookie.HttpOnly).To(Equal(cookieOpts.CookieHTTPOnly)) + } + }) + + It("have the correct secure set", func() { + for _, cookie := range cookies { + Expect(cookie.Secure).To(Equal(cookieOpts.CookieSecure)) + } + }) + + }) + } + + SessionStoreInterfaceTests := func() { + Context("when Save is called", func() { + BeforeEach(func() { + err := ss.Save(response, request, session) + Expect(err).ToNot(HaveOccurred()) + }) + + It("sets a `set-cookie` header in the response", func() { + Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty()) + }) + + CheckCookieOptions() + }) + + Context("when Clear is called", func() { + BeforeEach(func() { + cookie := cookies.MakeCookie(request, + cookieOpts.CookieName, + "foo", + cookieOpts.CookiePath, + cookieOpts.CookieDomain, + cookieOpts.CookieHTTPOnly, + cookieOpts.CookieSecure, + cookieOpts.CookieExpire, + time.Now(), + ) + request.AddCookie(cookie) + err := ss.Clear(response, request) + Expect(err).ToNot(HaveOccurred()) + }) + + It("sets a `set-cookie` header in the response", func() { + Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty()) + }) + + CheckCookieOptions() + }) + + Context("when Load is called", func() { + var loadedSession *sessionsapi.SessionState + BeforeEach(func() { + req := httptest.NewRequest("GET", "http://example.com/", nil) + resp := httptest.NewRecorder() + err := ss.Save(resp, req, session) + Expect(err).ToNot(HaveOccurred()) + + for _, cookie := range resp.Result().Cookies() { + request.AddCookie(cookie) + } + loadedSession, err = ss.Load(request) + Expect(err).ToNot(HaveOccurred()) + }) + + It("loads a session equal to the original session", func() { + if cookieOpts.CookieSecret == "" { + // Only Email and User stored in session when encrypted + Expect(loadedSession.Email).To(Equal(session.Email)) + Expect(loadedSession.User).To(Equal(session.User)) + } else { + // All fields stored in session if encrypted + + // Can't compare time.Time using Equal() so remove ExpiresOn from sessions + l := *loadedSession + l.ExpiresOn = time.Time{} + s := *session + s.ExpiresOn = time.Time{} + Expect(l).To(Equal(s)) + + // Compare time.Time separately + Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue()) + } + }) + }) + } + + RunSessionTests := func() { + Context("with default options", func() { + BeforeEach(func() { + var err error + ss, err = sessions.NewSessionStore(opts, cookieOpts) + Expect(err).ToNot(HaveOccurred()) + }) + + SessionStoreInterfaceTests() + }) + + Context("with non-default options", func() { + BeforeEach(func() { + cookieOpts = &options.CookieOptions{ + CookieName: "_cookie_name", + CookiePath: "/path", + CookieExpire: time.Duration(72) * time.Hour, + CookieRefresh: time.Duration(3600), + CookieSecure: false, + CookieHTTPOnly: false, + CookieDomain: "example.com", + } + + var err error + ss, err = sessions.NewSessionStore(opts, cookieOpts) + Expect(err).ToNot(HaveOccurred()) + }) + + SessionStoreInterfaceTests() + }) + + Context("with a cookie-secret set", func() { + BeforeEach(func() { + secret := make([]byte, 32) + _, err := rand.Read(secret) + Expect(err).ToNot(HaveOccurred()) + cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret) + + ss, err = sessions.NewSessionStore(opts, cookieOpts) + Expect(err).ToNot(HaveOccurred()) + }) + + SessionStoreInterfaceTests() + }) + } + + BeforeEach(func() { + ss = nil + opts = &options.SessionOptions{} + + // Set default options in CookieOptions + cookieOpts = &options.CookieOptions{ + CookieName: "_oauth2_proxy", + CookiePath: "/", + CookieExpire: time.Duration(168) * time.Hour, + CookieRefresh: time.Duration(0), + CookieSecure: true, + CookieHTTPOnly: true, + } + + session = &sessionsapi.SessionState{ + AccessToken: "AccessToken", + IDToken: "IDToken", + ExpiresOn: time.Now().Add(1 * time.Hour), + RefreshToken: "RefreshToken", + Email: "john.doe@example.com", + User: "john.doe", + } + + request = httptest.NewRequest("GET", "http://example.com/", nil) + response = httptest.NewRecorder() + }) + + Context("with type 'cookie'", func() { + BeforeEach(func() { + opts.Type = options.CookieSessionStoreType + }) + + It("creates a cookie.SessionStore", func() { + ss, err := sessions.NewSessionStore(opts, cookieOpts) + Expect(err).NotTo(HaveOccurred()) + Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{})) + }) + + Context("the cookie.SessionStore", func() { + RunSessionTests() + }) + }) + + Context("with an invalid type", func() { + BeforeEach(func() { + opts.Type = "invalid-type" + }) + + It("returns an error", func() { + ss, err := sessions.NewSessionStore(opts, cookieOpts) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(Equal("unknown session store type 'invalid-type'")) + Expect(ss).To(BeNil()) + }) + }) +}) diff --git a/pkg/sessions/utils/utils.go b/pkg/sessions/utils/utils.go new file mode 100644 index 0000000..051e9cc --- /dev/null +++ b/pkg/sessions/utils/utils.go @@ -0,0 +1,41 @@ +package utils + +import ( + "encoding/base64" + + "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" +) + +// CookieForSession serializes a session state for storage in a cookie +func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { + return s.EncodeSessionState(c) +} + +// SessionFromCookie deserializes a session from a cookie value +func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { + return sessions.DecodeSessionState(v, c) +} + +// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary +func SecretBytes(secret string) []byte { + b, err := base64.URLEncoding.DecodeString(addPadding(secret)) + if err == nil { + return []byte(addPadding(string(b))) + } + return []byte(secret) +} + +func addPadding(secret string) string { + padding := len(secret) % 4 + switch padding { + case 1: + return secret + "===" + case 2: + return secret + "==" + case 3: + return secret + "=" + default: + return secret + } +} diff --git a/providers/azure.go b/providers/azure.go index baae38f..a7961d2 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -9,6 +9,7 @@ import ( "github.com/bitly/go-simplejson" "github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // AzureProvider represents an Azure based Identity Provider @@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { } // GetEmailAddress returns the Account email address -func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { var email string var err error diff --git a/providers/azure_test.go b/providers/azure_test.go index 469f2d1..8d34bdc 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) @@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) @@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) @@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) @@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "", email) @@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) diff --git a/providers/facebook.go b/providers/facebook.go index 6f81f15..9897a1b 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -7,6 +7,7 @@ import ( "net/url" "github.com/pusher/oauth2_proxy/api" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // FacebookProvider represents an Facebook based Identity Provider @@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } @@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { } // ValidateSessionState validates the AccessToken -func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { +func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) } diff --git a/providers/github.go b/providers/github.go index f00fc19..b60ffe1 100644 --- a/providers/github.go +++ b/providers/github.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // GitHubProvider represents an GitHub based Identity Provider @@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { } // GetEmailAddress returns the Account email address -func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { var emails []struct { Email string `json:"email"` @@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { } // GetUserName returns the Account user name -func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { +func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { var user struct { Login string `json:"login"` Email string `json:"email"` diff --git a/providers/github_test.go b/providers/github_test.go index 4b093ca..2d45b84 100644 --- a/providers/github_test.go +++ b/providers/github_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) @@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Empty(t, "", email) @@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { p := testGitHubProvider(bURL.Host) p.Org = "testorg1" - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) @@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - session := &SessionState{AccessToken: "unexpected_access_token"} + session := &sessions.SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetUserName(session) assert.Equal(t, nil, err) assert.Equal(t, "mbland", email) diff --git a/providers/gitlab.go b/providers/gitlab.go index 1962552..af956c4 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -6,6 +6,7 @@ import ( "github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // GitLabProvider represents an GitLab based Identity Provider @@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { } // GetEmailAddress returns the Account email address -func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { req, err := http.NewRequest("GET", p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index 19038e1..112eb89 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitLabProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) @@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - session := &SessionState{AccessToken: "unexpected_access_token"} + session := &sessions.SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitLabProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) diff --git a/providers/google.go b/providers/google.go index e3cb380..f79a131 100644 --- a/providers/google.go +++ b/providers/google.go @@ -14,6 +14,7 @@ import ( "time" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "golang.org/x/oauth2" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" @@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err if err != nil { return } - s = &SessionState{ + s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), @@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { +func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 1a03fc5..ba6d470 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -7,6 +7,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct { *ProviderData } -func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { +func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // Note that we're testing the internal validateToken() used to implement // several Provider's ValidateSessionState() implementations -func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { +func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool { return false } diff --git a/providers/linkedin.go b/providers/linkedin.go index 8c392f8..a31b4a1 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -7,6 +7,7 @@ import ( "net/url" "github.com/pusher/oauth2_proxy/api" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // LinkedInProvider represents an LinkedIn based Identity Provider @@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } @@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { } // ValidateSessionState validates the AccessToken -func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { +func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) } diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index 7911522..9910a71 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testLinkedInProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@linkedin.com", email) @@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - session := &SessionState{AccessToken: "unexpected_access_token"} + session := &sessions.SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testLinkedInProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) diff --git a/providers/logingov.go b/providers/logingov.go index 09bd3be..60f4260 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -13,6 +13,7 @@ import ( "time" "github.com/dgrijalva/jwt-go" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "gopkg.in/square/go-jose.v2" ) @@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er } // Store the data that we found in the session state - s = &SessionState{ + s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), diff --git a/providers/oidc.go b/providers/oidc.go index d751be5..bacabdf 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -5,9 +5,9 @@ import ( "fmt" "time" - "golang.org/x/oauth2" - oidc "github.com/coreos/go-oidc" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "golang.org/x/oauth2" ) // OIDCProvider represents an OIDC based Identity Provider @@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, @@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { +func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } @@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { return true, nil } -func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { +func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { c := oauth2.Config{ ClientID: p.ClientID, ClientSecret: p.ClientSecret, @@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { return } -func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { +func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("token response did not contain an id_token") @@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - return &SessionState{ + return &sessions.SessionState{ AccessToken: token.AccessToken, IDToken: rawIDToken, RefreshToken: token.RefreshToken, @@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok } // ValidateSessionState checks that the session's IDToken is still valid -func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { +func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { ctx := context.Background() _, err := p.Verifier.Verify(ctx, s.IDToken) if err != nil { diff --git a/providers/provider_default.go b/providers/provider_default.go index f8f59ab..cd78251 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -10,10 +10,11 @@ import ( "net/url" "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // Redeem provides a default implementation of the OAuth2 token redemption process -func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er } err = json.Unmarshal(body, &jsonResponse) if err == nil { - s = &SessionState{ + s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, } return @@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er return } if a := v.Get("access_token"); a != "" { - s = &SessionState{AccessToken: a} + s = &sessions.SessionState{AccessToken: a} } else { err = fmt.Errorf("no access token found %s", body) } @@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string { } // CookieForSession serializes a session state for storage in a cookie -func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { +func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { return s.EncodeSessionState(c) } // SessionFromCookie deserializes a session from a cookie value -func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { - return DecodeSessionState(v, c) +func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { + return sessions.DecodeSessionState(v, c) } // GetEmailAddress returns the Account email address -func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { +func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // GetUserName returns the Account username -func (p *ProviderData) GetUserName(s *SessionState) (string, error) { +func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } @@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool { } // ValidateSessionState validates the AccessToken -func (p *ProviderData) ValidateSessionState(s *SessionState) bool { +func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, nil) } // RefreshSessionIfNeeded should refresh the user's session if required and // do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { +func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { return false, nil } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index abff0a9..ffe4aa7 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -4,12 +4,13 @@ import ( "testing" "time" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) func TestRefresh(t *testing.T) { p := &ProviderData{} - refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ + refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{ ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), }) assert.Equal(t, false, refreshed) diff --git a/providers/providers.go b/providers/providers.go index 4616153..57ace41 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -2,20 +2,21 @@ package providers import ( "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData - GetEmailAddress(*SessionState) (string, error) - GetUserName(*SessionState) (string, error) - Redeem(string, string) (*SessionState, error) + GetEmailAddress(*sessions.SessionState) (string, error) + GetUserName(*sessions.SessionState) (string, error) + Redeem(string, string) (*sessions.SessionState, error) ValidateGroup(string) bool - ValidateSessionState(*SessionState) bool + ValidateSessionState(*sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string - RefreshSessionIfNeeded(*SessionState) (bool, error) - SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) - CookieForSession(*SessionState, *cookie.Cipher) (string, error) + RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) + SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error) + CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error) } // New provides a new Provider based on the configured provider string