Merge pull request #147 from pusher/session-store
Add initial session-store interface and implementation
This commit is contained in:
commit
17e97ab884
@ -10,7 +10,11 @@
|
||||
|
||||
## Changes since v3.2.0
|
||||
|
||||
- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings)
|
||||
- [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed)
|
||||
- Allows for multiple different session storage implementations including client and server side
|
||||
- Adds tests suite for interface to ensure consistency across implementations
|
||||
- Refactor some configuration options (around cookies) into packages
|
||||
- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings)
|
||||
- [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath)
|
||||
- [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes)
|
||||
- [#142](https://github.com/pusher/oauth2_proxy/pull/142) ARM Docker USER fix (@kskewes)
|
||||
|
129
Gopkg.lock
generated
129
Gopkg.lock
generated
@ -57,6 +57,20 @@
|
||||
pruneopts = ""
|
||||
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:b3c5b95e56c06f5aa72cb2500e6ee5f44fcd122872d4fec2023a488e561218bc"
|
||||
name = "github.com/hpcloud/tail"
|
||||
packages = [
|
||||
".",
|
||||
"ratelimiter",
|
||||
"util",
|
||||
"watch",
|
||||
"winfile",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5"
|
||||
version = "v1.0.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f"
|
||||
name = "github.com/mbland/hmacauth"
|
||||
@ -73,6 +87,54 @@
|
||||
pruneopts = ""
|
||||
revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:a3735b0978a8b53fc2ac97a6f46ca1189f0712a00df86d6ec4cf26c1a25e6d77"
|
||||
name = "github.com/onsi/ginkgo"
|
||||
packages = [
|
||||
".",
|
||||
"config",
|
||||
"internal/codelocation",
|
||||
"internal/containernode",
|
||||
"internal/failer",
|
||||
"internal/leafnodes",
|
||||
"internal/remote",
|
||||
"internal/spec",
|
||||
"internal/spec_iterator",
|
||||
"internal/specrunner",
|
||||
"internal/suite",
|
||||
"internal/testingtproxy",
|
||||
"internal/writer",
|
||||
"reporters",
|
||||
"reporters/stenographer",
|
||||
"reporters/stenographer/support/go-colorable",
|
||||
"reporters/stenographer/support/go-isatty",
|
||||
"types",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "eea6ad008b96acdaa524f5b409513bf062b500ad"
|
||||
version = "v1.8.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:dbafce2fddb1ca331646fe2ac9c9413980368b19a60a4406a6e5861680bd73be"
|
||||
name = "github.com/onsi/gomega"
|
||||
packages = [
|
||||
".",
|
||||
"format",
|
||||
"internal/assertion",
|
||||
"internal/asyncassertion",
|
||||
"internal/oraclematcher",
|
||||
"internal/testingtsupport",
|
||||
"matchers",
|
||||
"matchers/support/goraph/bipartitegraph",
|
||||
"matchers/support/goraph/edge",
|
||||
"matchers/support/goraph/node",
|
||||
"matchers/support/goraph/util",
|
||||
"types",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "90e289841c1ed79b7a598a7cd9959750cb5e89e2"
|
||||
version = "v1.5.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
|
||||
name = "github.com/pmezard/go-difflib"
|
||||
@ -131,6 +193,9 @@
|
||||
packages = [
|
||||
"context",
|
||||
"context/ctxhttp",
|
||||
"html",
|
||||
"html/atom",
|
||||
"html/charset",
|
||||
"websocket",
|
||||
]
|
||||
pruneopts = ""
|
||||
@ -150,6 +215,42 @@
|
||||
pruneopts = ""
|
||||
revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:67a6e61e60283fd7dce50eba228080bff8805d9d69b2f121d7ec2260d120c4a8"
|
||||
name = "golang.org/x/sys"
|
||||
packages = ["unix"]
|
||||
pruneopts = ""
|
||||
revision = "ca7f33d4116e3a1f9425755d6a44e7ed9b4c97df"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0"
|
||||
name = "golang.org/x/text"
|
||||
packages = [
|
||||
"encoding",
|
||||
"encoding/charmap",
|
||||
"encoding/htmlindex",
|
||||
"encoding/internal",
|
||||
"encoding/internal/identifier",
|
||||
"encoding/japanese",
|
||||
"encoding/korean",
|
||||
"encoding/simplifiedchinese",
|
||||
"encoding/traditionalchinese",
|
||||
"encoding/unicode",
|
||||
"internal/gen",
|
||||
"internal/language",
|
||||
"internal/language/compact",
|
||||
"internal/tag",
|
||||
"internal/utf8internal",
|
||||
"language",
|
||||
"runes",
|
||||
"transform",
|
||||
"unicode/cldr",
|
||||
]
|
||||
pruneopts = ""
|
||||
revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475"
|
||||
version = "v0.3.2"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed"
|
||||
@ -182,6 +283,15 @@
|
||||
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
|
||||
version = "v1.0.0"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:eb53021a8aa3f599d29c7102e65026242bdedce998a54837dc67f14b6a97c5fd"
|
||||
name = "gopkg.in/fsnotify.v1"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9"
|
||||
source = "https://github.com/fsnotify/fsnotify.git"
|
||||
version = "v1.4.7"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2"
|
||||
name = "gopkg.in/fsnotify/fsnotify.v1"
|
||||
@ -210,6 +320,22 @@
|
||||
revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1"
|
||||
version = "v2.1.3"
|
||||
|
||||
[[projects]]
|
||||
branch = "v1"
|
||||
digest = "1:a96d16bd088460f2e0685d46c39bcf1208ba46e0a977be2df49864ec7da447dd"
|
||||
name = "gopkg.in/tomb.v1"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8"
|
||||
|
||||
[[projects]]
|
||||
digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d"
|
||||
name = "gopkg.in/yaml.v2"
|
||||
packages = ["."]
|
||||
pruneopts = ""
|
||||
revision = "51d6538a90f86fe93ac480b35f37b2be17fef232"
|
||||
version = "v2.2.2"
|
||||
|
||||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
@ -220,6 +346,8 @@
|
||||
"github.com/dgrijalva/jwt-go",
|
||||
"github.com/mbland/hmacauth",
|
||||
"github.com/mreiferson/go-options",
|
||||
"github.com/onsi/ginkgo",
|
||||
"github.com/onsi/gomega",
|
||||
"github.com/stretchr/testify/assert",
|
||||
"github.com/stretchr/testify/require",
|
||||
"github.com/yhat/wsutil",
|
||||
@ -231,6 +359,7 @@
|
||||
"google.golang.org/api/googleapi",
|
||||
"gopkg.in/fsnotify/fsnotify.v1",
|
||||
"gopkg.in/natefinch/lumberjack.v2",
|
||||
"gopkg.in/square/go-jose.v2",
|
||||
]
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
@ -35,6 +35,10 @@
|
||||
name = "gopkg.in/fsnotify/fsnotify.v1"
|
||||
version = "~1.2.0"
|
||||
|
||||
[[override]]
|
||||
name = "gopkg.in/fsnotify.v1"
|
||||
source = "https://github.com/fsnotify/fsnotify.git"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/crypto"
|
||||
|
1
Makefile
1
Makefile
@ -33,6 +33,7 @@ lint: $(GOMETALINTER)
|
||||
--enable=deadcode \
|
||||
--enable=gofmt \
|
||||
--enable=goimports \
|
||||
--deadline=120s \
|
||||
--tests ./...
|
||||
|
||||
.PHONY: dep
|
||||
|
@ -1,7 +1,8 @@
|
||||
---
|
||||
layout: default
|
||||
title: Configuration
|
||||
permalink: /configuration
|
||||
permalink: /docs/configuration
|
||||
has_children: true
|
||||
nav_order: 3
|
||||
---
|
||||
|
||||
@ -78,6 +79,7 @@ Usage of oauth2_proxy:
|
||||
-request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below)
|
||||
-resource string: The resource that is protected (Azure AD only)
|
||||
-scope string: OAuth scope specification
|
||||
-session-store-type: Session data storage backend (default: cookie)
|
||||
-set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)
|
||||
-set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode)
|
||||
-signature-key string: GAP-Signature request signature key (algorithm:secretkey)
|
34
docs/configuration/sessions.md
Normal file
34
docs/configuration/sessions.md
Normal file
@ -0,0 +1,34 @@
|
||||
---
|
||||
layout: default
|
||||
title: Sessions
|
||||
permalink: /configuration
|
||||
parent: Configuration
|
||||
nav_order: 3
|
||||
---
|
||||
|
||||
## Sessions
|
||||
|
||||
Sessions allow a user's authentication to be tracked between multiple HTTP
|
||||
requests to a service.
|
||||
|
||||
The OAuth2 Proxy uses a Cookie to track user sessions and will store the session
|
||||
data in one of the available session storage backends.
|
||||
|
||||
At present the available backends are (as passed to `--session-store-type`):
|
||||
- [cookie](cookie-storage) (deafult)
|
||||
|
||||
### Cookie Storage
|
||||
|
||||
The Cookie storage backend is the default backend implementation and has
|
||||
been used in the OAuth2 Proxy historically.
|
||||
|
||||
With the Cookie storage backend, all session information is stored in client
|
||||
side cookies and transferred with each and every request.
|
||||
|
||||
The following should be known when using this implementation:
|
||||
- Since all state is stored client side, this storage backend means that the OAuth2 Proxy is completely stateless
|
||||
- Cookies are signed server side to prevent modification client-side
|
||||
- It is recommended to set a `cookie-secret` which will ensure data is encrypted within the cookie data.
|
||||
- Since multiple requests can be made concurrently to the OAuth2 Proxy, this session implementation
|
||||
cannot lock sessions and while updating and refreshing sessions, there can be conflicts which force
|
||||
users to re-authenticate
|
@ -15,14 +15,27 @@ type EnvOptions map[string]interface{}
|
||||
// Fields in the options struct must have an `env` and `cfg` tag to be read
|
||||
// from the environment
|
||||
func (cfg EnvOptions) LoadEnvForStruct(options interface{}) {
|
||||
val := reflect.ValueOf(options).Elem()
|
||||
typ := val.Type()
|
||||
val := reflect.ValueOf(options)
|
||||
var typ reflect.Type
|
||||
if val.Kind() == reflect.Ptr {
|
||||
typ = val.Elem().Type()
|
||||
} else {
|
||||
typ = val.Type()
|
||||
}
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
// pull out the struct tags:
|
||||
// flag - the name of the command line flag
|
||||
// deprecated - (optional) the name of the deprecated command line flag
|
||||
// cfg - (optional, defaults to underscored flag) the name of the config file option
|
||||
field := typ.Field(i)
|
||||
fieldV := reflect.Indirect(val).Field(i)
|
||||
|
||||
if field.Type.Kind() == reflect.Struct && field.Anonymous {
|
||||
cfg.LoadEnvForStruct(fieldV.Interface())
|
||||
continue
|
||||
}
|
||||
|
||||
flagName := field.Tag.Get("flag")
|
||||
envName := field.Tag.Get("env")
|
||||
cfgName := field.Tag.Get("cfg")
|
||||
|
@ -1,26 +1,46 @@
|
||||
package main
|
||||
package main_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
proxy "github.com/pusher/oauth2_proxy"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type envTest struct {
|
||||
testField string `cfg:"target_field" env:"TEST_ENV_FIELD"`
|
||||
type EnvTest struct {
|
||||
TestField string `cfg:"target_field" env:"TEST_ENV_FIELD"`
|
||||
EnvTestEmbed
|
||||
}
|
||||
|
||||
type EnvTestEmbed struct {
|
||||
TestFieldEmbed string `cfg:"target_field_embed" env:"TEST_ENV_FIELD_EMBED"`
|
||||
}
|
||||
|
||||
func TestLoadEnvForStruct(t *testing.T) {
|
||||
|
||||
cfg := make(EnvOptions)
|
||||
cfg.LoadEnvForStruct(&envTest{})
|
||||
cfg := make(proxy.EnvOptions)
|
||||
cfg.LoadEnvForStruct(&EnvTest{})
|
||||
|
||||
_, ok := cfg["target_field"]
|
||||
assert.Equal(t, ok, false)
|
||||
|
||||
os.Setenv("TEST_ENV_FIELD", "1234abcd")
|
||||
cfg.LoadEnvForStruct(&envTest{})
|
||||
cfg.LoadEnvForStruct(&EnvTest{})
|
||||
v := cfg["target_field"]
|
||||
assert.Equal(t, v, "1234abcd")
|
||||
}
|
||||
|
||||
func TestLoadEnvForStructWithEmbeddedFields(t *testing.T) {
|
||||
|
||||
cfg := make(proxy.EnvOptions)
|
||||
cfg.LoadEnvForStruct(&EnvTest{})
|
||||
|
||||
_, ok := cfg["target_field_embed"]
|
||||
assert.Equal(t, ok, false)
|
||||
|
||||
os.Setenv("TEST_ENV_FIELD_EMBED", "1234abcd")
|
||||
cfg.LoadEnvForStruct(&EnvTest{})
|
||||
v := cfg["target_field_embed"]
|
||||
assert.Equal(t, v, "1234abcd")
|
||||
}
|
||||
|
2
main.go
2
main.go
@ -75,6 +75,8 @@ func main() {
|
||||
flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag")
|
||||
flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag")
|
||||
|
||||
flagSet.String("session-store-type", "cookie", "the session storage provider to use")
|
||||
|
||||
flagSet.String("logging-filename", "", "File to log requests to, empty for stdout")
|
||||
flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation")
|
||||
flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files")
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/yhat/wsutil"
|
||||
)
|
||||
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
|
||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
|
||||
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
|
||||
}
|
||||
|
||||
// LoadCookiedSession reads the user's authentication details from the request
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
|
||||
var age time.Duration
|
||||
c, err := loadCookie(req, p.CookieName)
|
||||
if err != nil {
|
||||
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
|
||||
}
|
||||
|
||||
// SaveSession creates a new session cookie value and sets this on the response
|
||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
|
||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
user, ok := p.ManualSignIn(rw, req)
|
||||
if ok {
|
||||
session := &providers.SessionState{User: user}
|
||||
session := &sessions.SessionState{User: user}
|
||||
p.SaveSession(rw, req, session)
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
} else {
|
||||
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
||||
|
||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||
// credentials and authenticates these against the proxies HtpasswdFile
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
|
||||
if p.HtpasswdFile == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
|
||||
}
|
||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||
return &providers.SessionState{User: pair[0]}, nil
|
||||
return &sessions.SessionState{User: pair[0]}, nil
|
||||
}
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||
return nil, nil
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
|
||||
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
|
||||
return tp.EmailAddress, nil
|
||||
}
|
||||
|
||||
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
|
||||
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
||||
return tp.ValidToken
|
||||
}
|
||||
|
||||
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
|
||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
||||
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
|
||||
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
|
||||
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
|
||||
return p.proxy.LoadCookiedSession(p.req)
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
|
||||
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
session, age, err := pcTest.LoadCookiedSession()
|
||||
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
pcTest.proxy.CookieRefresh = time.Hour
|
||||
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
||||
|
||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
|
||||
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, reference)
|
||||
|
||||
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
|
||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
test.validateUser = false
|
||||
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
pcTest.req, _ = http.NewRequest("GET",
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
|
||||
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
||||
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
||||
req.Header = st.header
|
||||
|
||||
state := &providers.SessionState{
|
||||
state := &sessions.SessionState{
|
||||
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
||||
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
||||
if err != nil {
|
||||
|
36
options.go
36
options.go
@ -18,6 +18,7 @@ import (
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
@ -49,14 +50,11 @@ type Options struct {
|
||||
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"`
|
||||
Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"`
|
||||
|
||||
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"`
|
||||
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"`
|
||||
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
|
||||
CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"`
|
||||
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
|
||||
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
|
||||
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
|
||||
// Embed CookieOptions
|
||||
options.CookieOptions
|
||||
|
||||
// Embed SessionOptions
|
||||
options.SessionOptions
|
||||
|
||||
Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"`
|
||||
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"`
|
||||
@ -126,16 +124,18 @@ type SignatureData struct {
|
||||
// NewOptions constructs a new Options with defaulted values
|
||||
func NewOptions() *Options {
|
||||
return &Options{
|
||||
ProxyPrefix: "/oauth2",
|
||||
ProxyWebSockets: true,
|
||||
HTTPAddress: "127.0.0.1:4180",
|
||||
HTTPSAddress: ":443",
|
||||
DisplayHtpasswdForm: true,
|
||||
CookieName: "_oauth2_proxy",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(0),
|
||||
ProxyPrefix: "/oauth2",
|
||||
ProxyWebSockets: true,
|
||||
HTTPAddress: "127.0.0.1:4180",
|
||||
HTTPSAddress: ":443",
|
||||
DisplayHtpasswdForm: true,
|
||||
CookieOptions: options.CookieOptions{
|
||||
CookieName: "_oauth2_proxy",
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(0),
|
||||
},
|
||||
SetXAuthRequest: false,
|
||||
SkipAuthPreflight: false,
|
||||
PassBasicAuth: true,
|
||||
|
15
pkg/apis/options/cookie.go
Normal file
15
pkg/apis/options/cookie.go
Normal file
@ -0,0 +1,15 @@
|
||||
package options
|
||||
|
||||
import "time"
|
||||
|
||||
// CookieOptions contains configuration options relating to Cookie configuration
|
||||
type CookieOptions struct {
|
||||
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"`
|
||||
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"`
|
||||
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
|
||||
CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"`
|
||||
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
|
||||
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
|
||||
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
|
||||
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
|
||||
}
|
14
pkg/apis/options/sessions.go
Normal file
14
pkg/apis/options/sessions.go
Normal file
@ -0,0 +1,14 @@
|
||||
package options
|
||||
|
||||
// SessionOptions contains configuration options for the SessionStore providers.
|
||||
type SessionOptions struct {
|
||||
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
|
||||
CookieStoreOptions
|
||||
}
|
||||
|
||||
// CookieSessionStoreType is used to indicate the CookieSessionStore should be
|
||||
// used for storing sessions.
|
||||
var CookieSessionStoreType = "cookie"
|
||||
|
||||
// CookieStoreOptions contains configuration options for the CookieSessionStore.
|
||||
type CookieStoreOptions struct{}
|
12
pkg/apis/sessions/interfaces.go
Normal file
12
pkg/apis/sessions/interfaces.go
Normal file
@ -0,0 +1,12 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SessionStore is an interface to storing user sessions in the proxy
|
||||
type SessionStore interface {
|
||||
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
|
||||
Load(req *http.Request) (*SessionState, error)
|
||||
Clear(rw http.ResponseWriter, req *http.Request) error
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package sessions_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@domain.com", ss.User)
|
||||
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, s.User, ss.User)
|
||||
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@domain.com", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||
assert.Equal(t, true, s.IsExpired())
|
||||
|
||||
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
||||
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
|
||||
s = &SessionState{}
|
||||
s = &sessions.SessionState{}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
SessionState
|
||||
sessions.SessionState
|
||||
Encoded string
|
||||
Cipher *cookie.Cipher
|
||||
Error bool
|
||||
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
for i, tc := range testCases {
|
||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, encoded)
|
||||
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeSessionState tests DecodeSessionState with the test vector
|
||||
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||
func TestDecodeSessionState(t *testing.T) {
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "user@domain.com",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
},
|
||||
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
34
pkg/cookies/cookies.go
Normal file
34
pkg/cookies/cookies.go
Normal file
@ -0,0 +1,34 @@
|
||||
package cookies
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
)
|
||||
|
||||
// MakeCookie constructs a cookie from the given parameters,
|
||||
// discovering the domain from the request if not specified.
|
||||
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
if domain != "" {
|
||||
host := req.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
if !strings.HasSuffix(host, domain) {
|
||||
logger.Printf("Warning: request host is %q but using configured cookie domain of %q", host, domain)
|
||||
}
|
||||
}
|
||||
|
||||
return &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: path,
|
||||
Domain: domain,
|
||||
HttpOnly: httpOnly,
|
||||
Secure: secure,
|
||||
Expires: now.Add(expiration),
|
||||
}
|
||||
}
|
232
pkg/sessions/cookie/session_store.go
Normal file
232
pkg/sessions/cookie/session_store.go
Normal file
@ -0,0 +1,232 @@
|
||||
package cookie
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
// Cookies are limited to 4kb including the length of the cookie name,
|
||||
// the cookie name can be up to 256 bytes
|
||||
maxCookieLength = 3840
|
||||
)
|
||||
|
||||
// Ensure CookieSessionStore implements the interface
|
||||
var _ sessions.SessionStore = &SessionStore{}
|
||||
|
||||
// SessionStore is an implementation of the sessions.SessionStore
|
||||
// interface that stores sessions in client side cookies
|
||||
type SessionStore struct {
|
||||
CookieCipher *cookie.Cipher
|
||||
CookieDomain string
|
||||
CookieExpire time.Duration
|
||||
CookieHTTPOnly bool
|
||||
CookieName string
|
||||
CookiePath string
|
||||
CookieSecret string
|
||||
CookieSecure bool
|
||||
}
|
||||
|
||||
// Save takes a sessions.SessionState and stores the information from it
|
||||
// within Cookies set on the HTTP response writer
|
||||
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
||||
value, err := utils.CookieForSession(ss, s.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.setSessionCookie(rw, req, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load reads sessions.SessionState information from Cookies within the
|
||||
// HTTP request object
|
||||
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||
c, err := loadCookie(req, s.CookieName)
|
||||
if err != nil {
|
||||
// always http.ErrNoCookie
|
||||
return nil, fmt.Errorf("Cookie %q not present", s.CookieName)
|
||||
}
|
||||
val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire)
|
||||
if !ok {
|
||||
return nil, errors.New("Cookie Signature not valid")
|
||||
}
|
||||
|
||||
session, err := utils.SessionFromCookie(val, s.CookieCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// Clear clears any saved session information by writing a cookie to
|
||||
// clear the session
|
||||
func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||
var cookies []*http.Cookie
|
||||
|
||||
// matches CookieName, CookieName_<number>
|
||||
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName))
|
||||
|
||||
for _, c := range req.Cookies() {
|
||||
if cookieNameRegex.MatchString(c.Name) {
|
||||
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
|
||||
|
||||
http.SetCookie(rw, clearCookie)
|
||||
cookies = append(cookies, clearCookie)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setSessionCookie adds the user's session cookie to the response
|
||||
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) {
|
||||
http.SetCookie(rw, c)
|
||||
}
|
||||
}
|
||||
|
||||
// makeSessionCookie creates an http.Cookie containing the authenticated user's
|
||||
// authentication details
|
||||
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
||||
if value != "" {
|
||||
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
|
||||
}
|
||||
c := s.makeCookie(req, s.CookieName, value, expiration)
|
||||
if len(c.Value) > 4096-len(s.CookieName) {
|
||||
return splitCookie(c)
|
||||
}
|
||||
return []*http.Cookie{c}
|
||||
}
|
||||
|
||||
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
|
||||
return cookies.MakeCookie(
|
||||
req,
|
||||
name,
|
||||
value,
|
||||
s.CookiePath,
|
||||
s.CookieDomain,
|
||||
s.CookieHTTPOnly,
|
||||
s.CookieSecure,
|
||||
expiration,
|
||||
time.Now(),
|
||||
)
|
||||
}
|
||||
|
||||
// NewCookieSessionStore initialises a new instance of the SessionStore from
|
||||
// the configuration given
|
||||
func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||
var cipher *cookie.Cipher
|
||||
if len(cookieOpts.CookieSecret) > 0 {
|
||||
var err error
|
||||
cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create cipher: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &SessionStore{
|
||||
CookieCipher: cipher,
|
||||
CookieDomain: cookieOpts.CookieDomain,
|
||||
CookieExpire: cookieOpts.CookieExpire,
|
||||
CookieHTTPOnly: cookieOpts.CookieHTTPOnly,
|
||||
CookieName: cookieOpts.CookieName,
|
||||
CookiePath: cookieOpts.CookiePath,
|
||||
CookieSecret: cookieOpts.CookieSecret,
|
||||
CookieSecure: cookieOpts.CookieSecure,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// splitCookie reads the full cookie generated to store the session and splits
|
||||
// it into a slice of cookies which fit within the 4kb cookie limit indexing
|
||||
// the cookies from 0
|
||||
func splitCookie(c *http.Cookie) []*http.Cookie {
|
||||
if len(c.Value) < maxCookieLength {
|
||||
return []*http.Cookie{c}
|
||||
}
|
||||
cookies := []*http.Cookie{}
|
||||
valueBytes := []byte(c.Value)
|
||||
count := 0
|
||||
for len(valueBytes) > 0 {
|
||||
new := copyCookie(c)
|
||||
new.Name = fmt.Sprintf("%s_%d", c.Name, count)
|
||||
count++
|
||||
if len(valueBytes) < maxCookieLength {
|
||||
new.Value = string(valueBytes)
|
||||
valueBytes = []byte{}
|
||||
} else {
|
||||
newValue := valueBytes[:maxCookieLength]
|
||||
valueBytes = valueBytes[maxCookieLength:]
|
||||
new.Value = string(newValue)
|
||||
}
|
||||
cookies = append(cookies, new)
|
||||
}
|
||||
return cookies
|
||||
}
|
||||
|
||||
// loadCookie retreieves the sessions state cookie from the http request.
|
||||
// If a single cookie is present this will be returned, otherwise it attempts
|
||||
// to reconstruct a cookie split up by splitCookie
|
||||
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
||||
c, err := req.Cookie(cookieName)
|
||||
if err == nil {
|
||||
return c, nil
|
||||
}
|
||||
cookies := []*http.Cookie{}
|
||||
err = nil
|
||||
count := 0
|
||||
for err == nil {
|
||||
var c *http.Cookie
|
||||
c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count))
|
||||
if err == nil {
|
||||
cookies = append(cookies, c)
|
||||
count++
|
||||
}
|
||||
}
|
||||
if len(cookies) == 0 {
|
||||
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
|
||||
}
|
||||
return joinCookies(cookies)
|
||||
}
|
||||
|
||||
// joinCookies takes a slice of cookies from the request and reconstructs the
|
||||
// full session cookie
|
||||
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
||||
if len(cookies) == 0 {
|
||||
return nil, fmt.Errorf("list of cookies must be > 0")
|
||||
}
|
||||
if len(cookies) == 1 {
|
||||
return cookies[0], nil
|
||||
}
|
||||
c := copyCookie(cookies[0])
|
||||
for i := 1; i < len(cookies); i++ {
|
||||
c.Value += cookies[i].Value
|
||||
}
|
||||
c.Name = strings.TrimRight(c.Name, "_0")
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func copyCookie(c *http.Cookie) *http.Cookie {
|
||||
return &http.Cookie{
|
||||
Name: c.Name,
|
||||
Value: c.Value,
|
||||
Path: c.Path,
|
||||
Domain: c.Domain,
|
||||
Expires: c.Expires,
|
||||
RawExpires: c.RawExpires,
|
||||
MaxAge: c.MaxAge,
|
||||
Secure: c.Secure,
|
||||
HttpOnly: c.HttpOnly,
|
||||
Raw: c.Raw,
|
||||
Unparsed: c.Unparsed,
|
||||
}
|
||||
}
|
19
pkg/sessions/session_store.go
Normal file
19
pkg/sessions/session_store.go
Normal file
@ -0,0 +1,19 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||
)
|
||||
|
||||
// NewSessionStore creates a SessionStore from the provided configuration
|
||||
func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||
switch opts.Type {
|
||||
case options.CookieSessionStoreType:
|
||||
return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown session store type '%s'", opts.Type)
|
||||
}
|
||||
}
|
254
pkg/sessions/session_store_test.go
Normal file
254
pkg/sessions/session_store_test.go
Normal file
@ -0,0 +1,254 @@
|
||||
package sessions_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||
)
|
||||
|
||||
func TestSessionStore(t *testing.T) {
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "SessionStore")
|
||||
}
|
||||
|
||||
var _ = Describe("NewSessionStore", func() {
|
||||
var opts *options.SessionOptions
|
||||
var cookieOpts *options.CookieOptions
|
||||
|
||||
var request *http.Request
|
||||
var response *httptest.ResponseRecorder
|
||||
var session *sessionsapi.SessionState
|
||||
var ss sessionsapi.SessionStore
|
||||
|
||||
CheckCookieOptions := func() {
|
||||
Context("the cookies returned", func() {
|
||||
var cookies []*http.Cookie
|
||||
BeforeEach(func() {
|
||||
cookies = response.Result().Cookies()
|
||||
})
|
||||
|
||||
It("have the correct name set", func() {
|
||||
if len(cookies) == 1 {
|
||||
Expect(cookies[0].Name).To(Equal(cookieOpts.CookieName))
|
||||
} else {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Name).To(ContainSubstring(cookieOpts.CookieName))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct path set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Path).To(Equal(cookieOpts.CookiePath))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct domain set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct HTTPOnly set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.HttpOnly).To(Equal(cookieOpts.CookieHTTPOnly))
|
||||
}
|
||||
})
|
||||
|
||||
It("have the correct secure set", func() {
|
||||
for _, cookie := range cookies {
|
||||
Expect(cookie.Secure).To(Equal(cookieOpts.CookieSecure))
|
||||
}
|
||||
})
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
SessionStoreInterfaceTests := func() {
|
||||
Context("when Save is called", func() {
|
||||
BeforeEach(func() {
|
||||
err := ss.Save(response, request, session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets a `set-cookie` header in the response", func() {
|
||||
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
CheckCookieOptions()
|
||||
})
|
||||
|
||||
Context("when Clear is called", func() {
|
||||
BeforeEach(func() {
|
||||
cookie := cookies.MakeCookie(request,
|
||||
cookieOpts.CookieName,
|
||||
"foo",
|
||||
cookieOpts.CookiePath,
|
||||
cookieOpts.CookieDomain,
|
||||
cookieOpts.CookieHTTPOnly,
|
||||
cookieOpts.CookieSecure,
|
||||
cookieOpts.CookieExpire,
|
||||
time.Now(),
|
||||
)
|
||||
request.AddCookie(cookie)
|
||||
err := ss.Clear(response, request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("sets a `set-cookie` header in the response", func() {
|
||||
Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
|
||||
})
|
||||
|
||||
CheckCookieOptions()
|
||||
})
|
||||
|
||||
Context("when Load is called", func() {
|
||||
var loadedSession *sessionsapi.SessionState
|
||||
BeforeEach(func() {
|
||||
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
err := ss.Save(resp, req, session)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
for _, cookie := range resp.Result().Cookies() {
|
||||
request.AddCookie(cookie)
|
||||
}
|
||||
loadedSession, err = ss.Load(request)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
It("loads a session equal to the original session", func() {
|
||||
if cookieOpts.CookieSecret == "" {
|
||||
// Only Email and User stored in session when encrypted
|
||||
Expect(loadedSession.Email).To(Equal(session.Email))
|
||||
Expect(loadedSession.User).To(Equal(session.User))
|
||||
} else {
|
||||
// All fields stored in session if encrypted
|
||||
|
||||
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
|
||||
l := *loadedSession
|
||||
l.ExpiresOn = time.Time{}
|
||||
s := *session
|
||||
s.ExpiresOn = time.Time{}
|
||||
Expect(l).To(Equal(s))
|
||||
|
||||
// Compare time.Time separately
|
||||
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
RunSessionTests := func() {
|
||||
Context("with default options", func() {
|
||||
BeforeEach(func() {
|
||||
var err error
|
||||
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
SessionStoreInterfaceTests()
|
||||
})
|
||||
|
||||
Context("with non-default options", func() {
|
||||
BeforeEach(func() {
|
||||
cookieOpts = &options.CookieOptions{
|
||||
CookieName: "_cookie_name",
|
||||
CookiePath: "/path",
|
||||
CookieExpire: time.Duration(72) * time.Hour,
|
||||
CookieRefresh: time.Duration(3600),
|
||||
CookieSecure: false,
|
||||
CookieHTTPOnly: false,
|
||||
CookieDomain: "example.com",
|
||||
}
|
||||
|
||||
var err error
|
||||
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
SessionStoreInterfaceTests()
|
||||
})
|
||||
|
||||
Context("with a cookie-secret set", func() {
|
||||
BeforeEach(func() {
|
||||
secret := make([]byte, 32)
|
||||
_, err := rand.Read(secret)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret)
|
||||
|
||||
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
})
|
||||
|
||||
SessionStoreInterfaceTests()
|
||||
})
|
||||
}
|
||||
|
||||
BeforeEach(func() {
|
||||
ss = nil
|
||||
opts = &options.SessionOptions{}
|
||||
|
||||
// Set default options in CookieOptions
|
||||
cookieOpts = &options.CookieOptions{
|
||||
CookieName: "_oauth2_proxy",
|
||||
CookiePath: "/",
|
||||
CookieExpire: time.Duration(168) * time.Hour,
|
||||
CookieRefresh: time.Duration(0),
|
||||
CookieSecure: true,
|
||||
CookieHTTPOnly: true,
|
||||
}
|
||||
|
||||
session = &sessionsapi.SessionState{
|
||||
AccessToken: "AccessToken",
|
||||
IDToken: "IDToken",
|
||||
ExpiresOn: time.Now().Add(1 * time.Hour),
|
||||
RefreshToken: "RefreshToken",
|
||||
Email: "john.doe@example.com",
|
||||
User: "john.doe",
|
||||
}
|
||||
|
||||
request = httptest.NewRequest("GET", "http://example.com/", nil)
|
||||
response = httptest.NewRecorder()
|
||||
})
|
||||
|
||||
Context("with type 'cookie'", func() {
|
||||
BeforeEach(func() {
|
||||
opts.Type = options.CookieSessionStoreType
|
||||
})
|
||||
|
||||
It("creates a cookie.SessionStore", func() {
|
||||
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{}))
|
||||
})
|
||||
|
||||
Context("the cookie.SessionStore", func() {
|
||||
RunSessionTests()
|
||||
})
|
||||
})
|
||||
|
||||
Context("with an invalid type", func() {
|
||||
BeforeEach(func() {
|
||||
opts.Type = "invalid-type"
|
||||
})
|
||||
|
||||
It("returns an error", func() {
|
||||
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||
Expect(err).To(HaveOccurred())
|
||||
Expect(err.Error()).To(Equal("unknown session store type 'invalid-type'"))
|
||||
Expect(ss).To(BeNil())
|
||||
})
|
||||
})
|
||||
})
|
41
pkg/sessions/utils/utils.go
Normal file
41
pkg/sessions/utils/utils.go
Normal file
@ -0,0 +1,41 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// CookieForSession serializes a session state for storage in a cookie
|
||||
func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
|
||||
return s.EncodeSessionState(c)
|
||||
}
|
||||
|
||||
// SessionFromCookie deserializes a session from a cookie value
|
||||
func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
|
||||
return sessions.DecodeSessionState(v, c)
|
||||
}
|
||||
|
||||
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
|
||||
func SecretBytes(secret string) []byte {
|
||||
b, err := base64.URLEncoding.DecodeString(addPadding(secret))
|
||||
if err == nil {
|
||||
return []byte(addPadding(string(b)))
|
||||
}
|
||||
return []byte(secret)
|
||||
}
|
||||
|
||||
func addPadding(secret string) string {
|
||||
padding := len(secret) % 4
|
||||
switch padding {
|
||||
case 1:
|
||||
return secret + "==="
|
||||
case 2:
|
||||
return secret + "=="
|
||||
case 3:
|
||||
return secret + "="
|
||||
default:
|
||||
return secret
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// AzureProvider represents an Azure based Identity Provider
|
||||
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
var email string
|
||||
var err error
|
||||
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||
assert.Equal(t, "", email)
|
||||
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||
assert.Equal(t, "", email)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// FacebookProvider represents an Facebook based Identity Provider
|
||||
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// GitHubProvider represents an GitHub based Identity Provider
|
||||
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
|
||||
var emails []struct {
|
||||
Email string `json:"email"`
|
||||
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
}
|
||||
|
||||
// GetUserName returns the Account user name
|
||||
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
||||
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
|
||||
var user struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Empty(t, "", email)
|
||||
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
p.Org = "testorg1"
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetUserName(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "mbland", email)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// GitLabProvider represents an GitLab based Identity Provider
|
||||
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
|
||||
req, err := http.NewRequest("GET",
|
||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s = &SessionState{
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
|
||||
*ProviderData
|
||||
}
|
||||
|
||||
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// Note that we're testing the internal validateToken() used to implement
|
||||
// several Provider's ValidateSessionState() implementations
|
||||
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// LinkedInProvider represents an LinkedIn based Identity Provider
|
||||
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@linkedin.com", email)
|
||||
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
|
||||
}
|
||||
|
||||
// Store the data that we found in the session state
|
||||
s = &SessionState{
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
|
@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDCProvider represents an OIDC based Identity Provider
|
||||
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
ctx := context.Background()
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
||||
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) {
|
||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
}
|
||||
|
||||
return &SessionState{
|
||||
return &sessions.SessionState{
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
||||
}
|
||||
|
||||
// ValidateSessionState checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
ctx := context.Background()
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
if err != nil {
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
if err == nil {
|
||||
s = &SessionState{
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
}
|
||||
return
|
||||
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
return
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
s = &SessionState{AccessToken: a}
|
||||
s = &sessions.SessionState{AccessToken: a}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", body)
|
||||
}
|
||||
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
|
||||
}
|
||||
|
||||
// CookieForSession serializes a session state for storage in a cookie
|
||||
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
|
||||
func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
|
||||
return s.EncodeSessionState(c)
|
||||
}
|
||||
|
||||
// SessionFromCookie deserializes a session from a cookie value
|
||||
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
||||
return DecodeSessionState(v, c)
|
||||
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
|
||||
return sessions.DecodeSessionState(v, c)
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// GetUserName returns the Account username
|
||||
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
|
||||
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, nil)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||
// do nothing if a refresh is not required
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -4,12 +4,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
p := &ProviderData{}
|
||||
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
|
||||
refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
|
||||
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
|
||||
})
|
||||
assert.Equal(t, false, refreshed)
|
||||
|
@ -2,20 +2,21 @@ package providers
|
||||
|
||||
import (
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// Provider represents an upstream identity provider implementation
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetEmailAddress(*SessionState) (string, error)
|
||||
GetUserName(*SessionState) (string, error)
|
||||
Redeem(string, string) (*SessionState, error)
|
||||
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||
GetUserName(*sessions.SessionState) (string, error)
|
||||
Redeem(string, string) (*sessions.SessionState, error)
|
||||
ValidateGroup(string) bool
|
||||
ValidateSessionState(*SessionState) bool
|
||||
ValidateSessionState(*sessions.SessionState) bool
|
||||
GetLoginURL(redirectURI, finalRedirect string) string
|
||||
RefreshSessionIfNeeded(*SessionState) (bool, error)
|
||||
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
|
||||
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
|
||||
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||
SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
|
||||
CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
|
||||
}
|
||||
|
||||
// New provides a new Provider based on the configured provider string
|
||||
|
Loading…
Reference in New Issue
Block a user