Merge pull request #147 from pusher/session-store
Add initial session-store interface and implementation
This commit is contained in:
commit
17e97ab884
@ -10,6 +10,10 @@
|
|||||||
|
|
||||||
## Changes since v3.2.0
|
## Changes since v3.2.0
|
||||||
|
|
||||||
|
- [#147](https://github.com/pusher/outh2_proxy/pull/147) Add SessionStore interfaces and initial implementation (@JoelSpeed)
|
||||||
|
- Allows for multiple different session storage implementations including client and server side
|
||||||
|
- Adds tests suite for interface to ensure consistency across implementations
|
||||||
|
- Refactor some configuration options (around cookies) into packages
|
||||||
- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings)
|
- [#114](https://github.com/pusher/oauth2_proxy/pull/114), [#154](https://github.com/pusher/oauth2_proxy/pull/154) Documentation is now available live at our [docs website](https://pusher.github.io/oauth2_proxy/) (@JoelSpeed, @icelynjennings)
|
||||||
- [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath)
|
- [#146](https://github.com/pusher/oauth2_proxy/pull/146) Use full email address as `User` if the auth response did not contain a `User` field (@gargath)
|
||||||
- [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes)
|
- [#144](https://github.com/pusher/oauth2_proxy/pull/144) Use GO 1.12 for ARM builds (@kskewes)
|
||||||
|
129
Gopkg.lock
generated
129
Gopkg.lock
generated
@ -57,6 +57,20 @@
|
|||||||
pruneopts = ""
|
pruneopts = ""
|
||||||
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
|
revision = "1e59b77b52bf8e4b449a57e6f79f21226d571845"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
digest = "1:b3c5b95e56c06f5aa72cb2500e6ee5f44fcd122872d4fec2023a488e561218bc"
|
||||||
|
name = "github.com/hpcloud/tail"
|
||||||
|
packages = [
|
||||||
|
".",
|
||||||
|
"ratelimiter",
|
||||||
|
"util",
|
||||||
|
"watch",
|
||||||
|
"winfile",
|
||||||
|
]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "a30252cb686a21eb2d0b98132633053ec2f7f1e5"
|
||||||
|
version = "v1.0.0"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f"
|
digest = "1:af67386ca553c04c6222f7b5b2f17bc97a5dfb3b81b706882c7fd8c72c30cf8f"
|
||||||
name = "github.com/mbland/hmacauth"
|
name = "github.com/mbland/hmacauth"
|
||||||
@ -73,6 +87,54 @@
|
|||||||
pruneopts = ""
|
pruneopts = ""
|
||||||
revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95"
|
revision = "20ba7d382d05facb01e02eb777af0c5f229c5c95"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
digest = "1:a3735b0978a8b53fc2ac97a6f46ca1189f0712a00df86d6ec4cf26c1a25e6d77"
|
||||||
|
name = "github.com/onsi/ginkgo"
|
||||||
|
packages = [
|
||||||
|
".",
|
||||||
|
"config",
|
||||||
|
"internal/codelocation",
|
||||||
|
"internal/containernode",
|
||||||
|
"internal/failer",
|
||||||
|
"internal/leafnodes",
|
||||||
|
"internal/remote",
|
||||||
|
"internal/spec",
|
||||||
|
"internal/spec_iterator",
|
||||||
|
"internal/specrunner",
|
||||||
|
"internal/suite",
|
||||||
|
"internal/testingtproxy",
|
||||||
|
"internal/writer",
|
||||||
|
"reporters",
|
||||||
|
"reporters/stenographer",
|
||||||
|
"reporters/stenographer/support/go-colorable",
|
||||||
|
"reporters/stenographer/support/go-isatty",
|
||||||
|
"types",
|
||||||
|
]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "eea6ad008b96acdaa524f5b409513bf062b500ad"
|
||||||
|
version = "v1.8.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
digest = "1:dbafce2fddb1ca331646fe2ac9c9413980368b19a60a4406a6e5861680bd73be"
|
||||||
|
name = "github.com/onsi/gomega"
|
||||||
|
packages = [
|
||||||
|
".",
|
||||||
|
"format",
|
||||||
|
"internal/assertion",
|
||||||
|
"internal/asyncassertion",
|
||||||
|
"internal/oraclematcher",
|
||||||
|
"internal/testingtsupport",
|
||||||
|
"matchers",
|
||||||
|
"matchers/support/goraph/bipartitegraph",
|
||||||
|
"matchers/support/goraph/edge",
|
||||||
|
"matchers/support/goraph/node",
|
||||||
|
"matchers/support/goraph/util",
|
||||||
|
"types",
|
||||||
|
]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "90e289841c1ed79b7a598a7cd9959750cb5e89e2"
|
||||||
|
version = "v1.5.0"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
|
digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411"
|
||||||
name = "github.com/pmezard/go-difflib"
|
name = "github.com/pmezard/go-difflib"
|
||||||
@ -131,6 +193,9 @@
|
|||||||
packages = [
|
packages = [
|
||||||
"context",
|
"context",
|
||||||
"context/ctxhttp",
|
"context/ctxhttp",
|
||||||
|
"html",
|
||||||
|
"html/atom",
|
||||||
|
"html/charset",
|
||||||
"websocket",
|
"websocket",
|
||||||
]
|
]
|
||||||
pruneopts = ""
|
pruneopts = ""
|
||||||
@ -150,6 +215,42 @@
|
|||||||
pruneopts = ""
|
pruneopts = ""
|
||||||
revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402"
|
revision = "9ff8ebcc8e241d46f52ecc5bff0e5a2f2dbef402"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
digest = "1:67a6e61e60283fd7dce50eba228080bff8805d9d69b2f121d7ec2260d120c4a8"
|
||||||
|
name = "golang.org/x/sys"
|
||||||
|
packages = ["unix"]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "ca7f33d4116e3a1f9425755d6a44e7ed9b4c97df"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
digest = "1:740b51a55815493a8d0f2b1e0d0ae48fe48953bf7eaf3fcc4198823bf67768c0"
|
||||||
|
name = "golang.org/x/text"
|
||||||
|
packages = [
|
||||||
|
"encoding",
|
||||||
|
"encoding/charmap",
|
||||||
|
"encoding/htmlindex",
|
||||||
|
"encoding/internal",
|
||||||
|
"encoding/internal/identifier",
|
||||||
|
"encoding/japanese",
|
||||||
|
"encoding/korean",
|
||||||
|
"encoding/simplifiedchinese",
|
||||||
|
"encoding/traditionalchinese",
|
||||||
|
"encoding/unicode",
|
||||||
|
"internal/gen",
|
||||||
|
"internal/language",
|
||||||
|
"internal/language/compact",
|
||||||
|
"internal/tag",
|
||||||
|
"internal/utf8internal",
|
||||||
|
"language",
|
||||||
|
"runes",
|
||||||
|
"transform",
|
||||||
|
"unicode/cldr",
|
||||||
|
]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "342b2e1fbaa52c93f31447ad2c6abc048c63e475"
|
||||||
|
version = "v0.3.2"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed"
|
digest = "1:dc1fb726dbbe79c86369941eae1e3b431b8fc6f11dbd37f7899dc758a43cc3ed"
|
||||||
@ -182,6 +283,15 @@
|
|||||||
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
|
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
|
||||||
version = "v1.0.0"
|
version = "v1.0.0"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
digest = "1:eb53021a8aa3f599d29c7102e65026242bdedce998a54837dc67f14b6a97c5fd"
|
||||||
|
name = "gopkg.in/fsnotify.v1"
|
||||||
|
packages = ["."]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "c2828203cd70a50dcccfb2761f8b1f8ceef9a8e9"
|
||||||
|
source = "https://github.com/fsnotify/fsnotify.git"
|
||||||
|
version = "v1.4.7"
|
||||||
|
|
||||||
[[projects]]
|
[[projects]]
|
||||||
digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2"
|
digest = "1:cb5b2a45a3dd41c01ff779c54ae4c8aab0271d6d3b3f734c8a8bd2c890299ef2"
|
||||||
name = "gopkg.in/fsnotify/fsnotify.v1"
|
name = "gopkg.in/fsnotify/fsnotify.v1"
|
||||||
@ -210,6 +320,22 @@
|
|||||||
revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1"
|
revision = "f8f38de21b4dcd69d0413faf231983f5fd6634b1"
|
||||||
version = "v2.1.3"
|
version = "v2.1.3"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "v1"
|
||||||
|
digest = "1:a96d16bd088460f2e0685d46c39bcf1208ba46e0a977be2df49864ec7da447dd"
|
||||||
|
name = "gopkg.in/tomb.v1"
|
||||||
|
packages = ["."]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "dd632973f1e7218eb1089048e0798ec9ae7dceb8"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
digest = "1:cedccf16b71e86db87a24f8d4c70b0a855872eb967cb906a66b95de56aefbd0d"
|
||||||
|
name = "gopkg.in/yaml.v2"
|
||||||
|
packages = ["."]
|
||||||
|
pruneopts = ""
|
||||||
|
revision = "51d6538a90f86fe93ac480b35f37b2be17fef232"
|
||||||
|
version = "v2.2.2"
|
||||||
|
|
||||||
[solve-meta]
|
[solve-meta]
|
||||||
analyzer-name = "dep"
|
analyzer-name = "dep"
|
||||||
analyzer-version = 1
|
analyzer-version = 1
|
||||||
@ -220,6 +346,8 @@
|
|||||||
"github.com/dgrijalva/jwt-go",
|
"github.com/dgrijalva/jwt-go",
|
||||||
"github.com/mbland/hmacauth",
|
"github.com/mbland/hmacauth",
|
||||||
"github.com/mreiferson/go-options",
|
"github.com/mreiferson/go-options",
|
||||||
|
"github.com/onsi/ginkgo",
|
||||||
|
"github.com/onsi/gomega",
|
||||||
"github.com/stretchr/testify/assert",
|
"github.com/stretchr/testify/assert",
|
||||||
"github.com/stretchr/testify/require",
|
"github.com/stretchr/testify/require",
|
||||||
"github.com/yhat/wsutil",
|
"github.com/yhat/wsutil",
|
||||||
@ -231,6 +359,7 @@
|
|||||||
"google.golang.org/api/googleapi",
|
"google.golang.org/api/googleapi",
|
||||||
"gopkg.in/fsnotify/fsnotify.v1",
|
"gopkg.in/fsnotify/fsnotify.v1",
|
||||||
"gopkg.in/natefinch/lumberjack.v2",
|
"gopkg.in/natefinch/lumberjack.v2",
|
||||||
|
"gopkg.in/square/go-jose.v2",
|
||||||
]
|
]
|
||||||
solver-name = "gps-cdcl"
|
solver-name = "gps-cdcl"
|
||||||
solver-version = 1
|
solver-version = 1
|
||||||
|
@ -35,6 +35,10 @@
|
|||||||
name = "gopkg.in/fsnotify/fsnotify.v1"
|
name = "gopkg.in/fsnotify/fsnotify.v1"
|
||||||
version = "~1.2.0"
|
version = "~1.2.0"
|
||||||
|
|
||||||
|
[[override]]
|
||||||
|
name = "gopkg.in/fsnotify.v1"
|
||||||
|
source = "https://github.com/fsnotify/fsnotify.git"
|
||||||
|
|
||||||
[[constraint]]
|
[[constraint]]
|
||||||
branch = "master"
|
branch = "master"
|
||||||
name = "golang.org/x/crypto"
|
name = "golang.org/x/crypto"
|
||||||
|
1
Makefile
1
Makefile
@ -33,6 +33,7 @@ lint: $(GOMETALINTER)
|
|||||||
--enable=deadcode \
|
--enable=deadcode \
|
||||||
--enable=gofmt \
|
--enable=gofmt \
|
||||||
--enable=goimports \
|
--enable=goimports \
|
||||||
|
--deadline=120s \
|
||||||
--tests ./...
|
--tests ./...
|
||||||
|
|
||||||
.PHONY: dep
|
.PHONY: dep
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
---
|
---
|
||||||
layout: default
|
layout: default
|
||||||
title: Configuration
|
title: Configuration
|
||||||
permalink: /configuration
|
permalink: /docs/configuration
|
||||||
|
has_children: true
|
||||||
nav_order: 3
|
nav_order: 3
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -78,6 +79,7 @@ Usage of oauth2_proxy:
|
|||||||
-request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below)
|
-request-logging-format: Template for request log lines (see "Logging Configuration" paragraph below)
|
||||||
-resource string: The resource that is protected (Azure AD only)
|
-resource string: The resource that is protected (Azure AD only)
|
||||||
-scope string: OAuth scope specification
|
-scope string: OAuth scope specification
|
||||||
|
-session-store-type: Session data storage backend (default: cookie)
|
||||||
-set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)
|
-set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)
|
||||||
-set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode)
|
-set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode)
|
||||||
-signature-key string: GAP-Signature request signature key (algorithm:secretkey)
|
-signature-key string: GAP-Signature request signature key (algorithm:secretkey)
|
34
docs/configuration/sessions.md
Normal file
34
docs/configuration/sessions.md
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
---
|
||||||
|
layout: default
|
||||||
|
title: Sessions
|
||||||
|
permalink: /configuration
|
||||||
|
parent: Configuration
|
||||||
|
nav_order: 3
|
||||||
|
---
|
||||||
|
|
||||||
|
## Sessions
|
||||||
|
|
||||||
|
Sessions allow a user's authentication to be tracked between multiple HTTP
|
||||||
|
requests to a service.
|
||||||
|
|
||||||
|
The OAuth2 Proxy uses a Cookie to track user sessions and will store the session
|
||||||
|
data in one of the available session storage backends.
|
||||||
|
|
||||||
|
At present the available backends are (as passed to `--session-store-type`):
|
||||||
|
- [cookie](cookie-storage) (deafult)
|
||||||
|
|
||||||
|
### Cookie Storage
|
||||||
|
|
||||||
|
The Cookie storage backend is the default backend implementation and has
|
||||||
|
been used in the OAuth2 Proxy historically.
|
||||||
|
|
||||||
|
With the Cookie storage backend, all session information is stored in client
|
||||||
|
side cookies and transferred with each and every request.
|
||||||
|
|
||||||
|
The following should be known when using this implementation:
|
||||||
|
- Since all state is stored client side, this storage backend means that the OAuth2 Proxy is completely stateless
|
||||||
|
- Cookies are signed server side to prevent modification client-side
|
||||||
|
- It is recommended to set a `cookie-secret` which will ensure data is encrypted within the cookie data.
|
||||||
|
- Since multiple requests can be made concurrently to the OAuth2 Proxy, this session implementation
|
||||||
|
cannot lock sessions and while updating and refreshing sessions, there can be conflicts which force
|
||||||
|
users to re-authenticate
|
@ -15,14 +15,27 @@ type EnvOptions map[string]interface{}
|
|||||||
// Fields in the options struct must have an `env` and `cfg` tag to be read
|
// Fields in the options struct must have an `env` and `cfg` tag to be read
|
||||||
// from the environment
|
// from the environment
|
||||||
func (cfg EnvOptions) LoadEnvForStruct(options interface{}) {
|
func (cfg EnvOptions) LoadEnvForStruct(options interface{}) {
|
||||||
val := reflect.ValueOf(options).Elem()
|
val := reflect.ValueOf(options)
|
||||||
typ := val.Type()
|
var typ reflect.Type
|
||||||
|
if val.Kind() == reflect.Ptr {
|
||||||
|
typ = val.Elem().Type()
|
||||||
|
} else {
|
||||||
|
typ = val.Type()
|
||||||
|
}
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
// pull out the struct tags:
|
// pull out the struct tags:
|
||||||
// flag - the name of the command line flag
|
// flag - the name of the command line flag
|
||||||
// deprecated - (optional) the name of the deprecated command line flag
|
// deprecated - (optional) the name of the deprecated command line flag
|
||||||
// cfg - (optional, defaults to underscored flag) the name of the config file option
|
// cfg - (optional, defaults to underscored flag) the name of the config file option
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
|
fieldV := reflect.Indirect(val).Field(i)
|
||||||
|
|
||||||
|
if field.Type.Kind() == reflect.Struct && field.Anonymous {
|
||||||
|
cfg.LoadEnvForStruct(fieldV.Interface())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
flagName := field.Tag.Get("flag")
|
flagName := field.Tag.Get("flag")
|
||||||
envName := field.Tag.Get("env")
|
envName := field.Tag.Get("env")
|
||||||
cfgName := field.Tag.Get("cfg")
|
cfgName := field.Tag.Get("cfg")
|
||||||
|
@ -1,26 +1,46 @@
|
|||||||
package main
|
package main_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
proxy "github.com/pusher/oauth2_proxy"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
type envTest struct {
|
type EnvTest struct {
|
||||||
testField string `cfg:"target_field" env:"TEST_ENV_FIELD"`
|
TestField string `cfg:"target_field" env:"TEST_ENV_FIELD"`
|
||||||
|
EnvTestEmbed
|
||||||
|
}
|
||||||
|
|
||||||
|
type EnvTestEmbed struct {
|
||||||
|
TestFieldEmbed string `cfg:"target_field_embed" env:"TEST_ENV_FIELD_EMBED"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadEnvForStruct(t *testing.T) {
|
func TestLoadEnvForStruct(t *testing.T) {
|
||||||
|
|
||||||
cfg := make(EnvOptions)
|
cfg := make(proxy.EnvOptions)
|
||||||
cfg.LoadEnvForStruct(&envTest{})
|
cfg.LoadEnvForStruct(&EnvTest{})
|
||||||
|
|
||||||
_, ok := cfg["target_field"]
|
_, ok := cfg["target_field"]
|
||||||
assert.Equal(t, ok, false)
|
assert.Equal(t, ok, false)
|
||||||
|
|
||||||
os.Setenv("TEST_ENV_FIELD", "1234abcd")
|
os.Setenv("TEST_ENV_FIELD", "1234abcd")
|
||||||
cfg.LoadEnvForStruct(&envTest{})
|
cfg.LoadEnvForStruct(&EnvTest{})
|
||||||
v := cfg["target_field"]
|
v := cfg["target_field"]
|
||||||
assert.Equal(t, v, "1234abcd")
|
assert.Equal(t, v, "1234abcd")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLoadEnvForStructWithEmbeddedFields(t *testing.T) {
|
||||||
|
|
||||||
|
cfg := make(proxy.EnvOptions)
|
||||||
|
cfg.LoadEnvForStruct(&EnvTest{})
|
||||||
|
|
||||||
|
_, ok := cfg["target_field_embed"]
|
||||||
|
assert.Equal(t, ok, false)
|
||||||
|
|
||||||
|
os.Setenv("TEST_ENV_FIELD_EMBED", "1234abcd")
|
||||||
|
cfg.LoadEnvForStruct(&EnvTest{})
|
||||||
|
v := cfg["target_field_embed"]
|
||||||
|
assert.Equal(t, v, "1234abcd")
|
||||||
|
}
|
||||||
|
2
main.go
2
main.go
@ -75,6 +75,8 @@ func main() {
|
|||||||
flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag")
|
flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag")
|
||||||
flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag")
|
flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag")
|
||||||
|
|
||||||
|
flagSet.String("session-store-type", "cookie", "the session storage provider to use")
|
||||||
|
|
||||||
flagSet.String("logging-filename", "", "File to log requests to, empty for stdout")
|
flagSet.String("logging-filename", "", "File to log requests to, empty for stdout")
|
||||||
flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation")
|
flagSet.Int("logging-max-size", 100, "Maximum size in megabytes of the log file before rotation")
|
||||||
flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files")
|
flagSet.Int("logging-max-age", 7, "Maximum number of days to retain old log files")
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"github.com/yhat/wsutil"
|
"github.com/yhat/wsutil"
|
||||||
)
|
)
|
||||||
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
|
|||||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
|
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("missing code")
|
return nil, errors.New("missing code")
|
||||||
}
|
}
|
||||||
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoadCookiedSession reads the user's authentication details from the request
|
// LoadCookiedSession reads the user's authentication details from the request
|
||||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
|
||||||
var age time.Duration
|
var age time.Duration
|
||||||
c, err := loadCookie(req, p.CookieName)
|
c, err := loadCookie(req, p.CookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession creates a new session cookie value and sets this on the response
|
// SaveSession creates a new session cookie value and sets this on the response
|
||||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
|
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
user, ok := p.ManualSignIn(rw, req)
|
user, ok := p.ManualSignIn(rw, req)
|
||||||
if ok {
|
if ok {
|
||||||
session := &providers.SessionState{User: user}
|
session := &sessions.SessionState{User: user}
|
||||||
p.SaveSession(rw, req, session)
|
p.SaveSession(rw, req, session)
|
||||||
http.Redirect(rw, req, redirect, 302)
|
http.Redirect(rw, req, redirect, 302)
|
||||||
} else {
|
} else {
|
||||||
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
|||||||
|
|
||||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||||
// credentials and authenticates these against the proxies HtpasswdFile
|
// credentials and authenticates these against the proxies HtpasswdFile
|
||||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
|
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
|
||||||
if p.HtpasswdFile == nil {
|
if p.HtpasswdFile == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
|
|||||||
}
|
}
|
||||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||||
return &providers.SessionState{User: pair[0]}, nil
|
return &sessions.SessionState{User: pair[0]}, nil
|
||||||
}
|
}
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
|
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
|
||||||
return tp.EmailAddress, nil
|
return tp.EmailAddress, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
|
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
||||||
return tp.ValidToken
|
return tp.ValidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
|
|||||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
|
||||||
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
|
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
|
||||||
return p.proxy.LoadCookiedSession(p.req)
|
return p.proxy.LoadCookiedSession(p.req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadCookiedSession(t *testing.T) {
|
func TestLoadCookiedSession(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
|
|
||||||
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, time.Now())
|
pcTest.SaveSession(startSession, time.Now())
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, _, err := pcTest.LoadCookiedSession()
|
||||||
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|||||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||||
|
|
||||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession, reference)
|
||||||
|
|
||||||
session, age, err := pcTest.LoadCookiedSession()
|
session, age, err := pcTest.LoadCookiedSession()
|
||||||
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession, reference)
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, _, err := pcTest.LoadCookiedSession()
|
||||||
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
|||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession, reference)
|
||||||
|
|
||||||
pcTest.proxy.CookieRefresh = time.Hour
|
pcTest.proxy.CookieRefresh = time.Hour
|
||||||
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
|||||||
|
|
||||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
test.SaveSession(startSession, time.Now())
|
test.SaveSession(startSession, time.Now())
|
||||||
|
|
||||||
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
test.SaveSession(startSession, reference)
|
test.SaveSession(startSession, reference)
|
||||||
|
|
||||||
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
test.SaveSession(startSession, time.Now())
|
test.SaveSession(startSession, time.Now())
|
||||||
test.validateUser = false
|
test.validateUser = false
|
||||||
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|||||||
pcTest.req, _ = http.NewRequest("GET",
|
pcTest.req, _ = http.NewRequest("GET",
|
||||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||||
|
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
||||||
pcTest.SaveSession(startSession, time.Now())
|
pcTest.SaveSession(startSession, time.Now())
|
||||||
|
|
||||||
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
|||||||
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
||||||
req.Header = st.header
|
req.Header = st.header
|
||||||
|
|
||||||
state := &providers.SessionState{
|
state := &sessions.SessionState{
|
||||||
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
||||||
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
36
options.go
36
options.go
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"gopkg.in/natefinch/lumberjack.v2"
|
"gopkg.in/natefinch/lumberjack.v2"
|
||||||
)
|
)
|
||||||
@ -49,14 +50,11 @@ type Options struct {
|
|||||||
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"`
|
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"`
|
||||||
Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"`
|
Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"`
|
||||||
|
|
||||||
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"`
|
// Embed CookieOptions
|
||||||
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"`
|
options.CookieOptions
|
||||||
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
|
|
||||||
CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"`
|
// Embed SessionOptions
|
||||||
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
|
options.SessionOptions
|
||||||
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
|
|
||||||
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
|
|
||||||
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
|
|
||||||
|
|
||||||
Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"`
|
Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"`
|
||||||
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"`
|
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"`
|
||||||
@ -126,16 +124,18 @@ type SignatureData struct {
|
|||||||
// NewOptions constructs a new Options with defaulted values
|
// NewOptions constructs a new Options with defaulted values
|
||||||
func NewOptions() *Options {
|
func NewOptions() *Options {
|
||||||
return &Options{
|
return &Options{
|
||||||
ProxyPrefix: "/oauth2",
|
ProxyPrefix: "/oauth2",
|
||||||
ProxyWebSockets: true,
|
ProxyWebSockets: true,
|
||||||
HTTPAddress: "127.0.0.1:4180",
|
HTTPAddress: "127.0.0.1:4180",
|
||||||
HTTPSAddress: ":443",
|
HTTPSAddress: ":443",
|
||||||
DisplayHtpasswdForm: true,
|
DisplayHtpasswdForm: true,
|
||||||
CookieName: "_oauth2_proxy",
|
CookieOptions: options.CookieOptions{
|
||||||
CookieSecure: true,
|
CookieName: "_oauth2_proxy",
|
||||||
CookieHTTPOnly: true,
|
CookieSecure: true,
|
||||||
CookieExpire: time.Duration(168) * time.Hour,
|
CookieHTTPOnly: true,
|
||||||
CookieRefresh: time.Duration(0),
|
CookieExpire: time.Duration(168) * time.Hour,
|
||||||
|
CookieRefresh: time.Duration(0),
|
||||||
|
},
|
||||||
SetXAuthRequest: false,
|
SetXAuthRequest: false,
|
||||||
SkipAuthPreflight: false,
|
SkipAuthPreflight: false,
|
||||||
PassBasicAuth: true,
|
PassBasicAuth: true,
|
||||||
|
15
pkg/apis/options/cookie.go
Normal file
15
pkg/apis/options/cookie.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package options
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// CookieOptions contains configuration options relating to Cookie configuration
|
||||||
|
type CookieOptions struct {
|
||||||
|
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"`
|
||||||
|
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"`
|
||||||
|
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
|
||||||
|
CookiePath string `flag:"cookie-path" cfg:"cookie_path" env:"OAUTH2_PROXY_COOKIE_PATH"`
|
||||||
|
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
|
||||||
|
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
|
||||||
|
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
|
||||||
|
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
|
||||||
|
}
|
14
pkg/apis/options/sessions.go
Normal file
14
pkg/apis/options/sessions.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package options
|
||||||
|
|
||||||
|
// SessionOptions contains configuration options for the SessionStore providers.
|
||||||
|
type SessionOptions struct {
|
||||||
|
Type string `flag:"session-store-type" cfg:"session_store_type" env:"OAUTH2_PROXY_SESSION_STORE_TYPE"`
|
||||||
|
CookieStoreOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookieSessionStoreType is used to indicate the CookieSessionStore should be
|
||||||
|
// used for storing sessions.
|
||||||
|
var CookieSessionStoreType = "cookie"
|
||||||
|
|
||||||
|
// CookieStoreOptions contains configuration options for the CookieSessionStore.
|
||||||
|
type CookieStoreOptions struct{}
|
12
pkg/apis/sessions/interfaces.go
Normal file
12
pkg/apis/sessions/interfaces.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionStore is an interface to storing user sessions in the proxy
|
||||||
|
type SessionStore interface {
|
||||||
|
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
|
||||||
|
Load(req *http.Request) (*SessionState, error)
|
||||||
|
Clear(rw http.ResponseWriter, req *http.Request) error
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package providers
|
package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
@ -1,4 +1,4 @@
|
|||||||
package providers
|
package sessions_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
IDToken: "rawtoken1234",
|
IDToken: "rawtoken1234",
|
||||||
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@domain.com", ss.User)
|
assert.Equal(t, "user@domain.com", ss.User)
|
||||||
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
ss, err = DecodeSessionState(encoded, c2)
|
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||||
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
ss, err = DecodeSessionState(encoded, c2)
|
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, s.User, ss.User)
|
assert.NotEqual(t, s.User, ss.User)
|
||||||
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := DecodeSessionState(encoded, nil)
|
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@domain.com", ss.User)
|
assert.Equal(t, "user@domain.com", ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := DecodeSessionState(encoded, nil)
|
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExpired(t *testing.T) {
|
func TestExpired(t *testing.T) {
|
||||||
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||||
assert.Equal(t, true, s.IsExpired())
|
assert.Equal(t, true, s.IsExpired())
|
||||||
|
|
||||||
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
||||||
assert.Equal(t, false, s.IsExpired())
|
assert.Equal(t, false, s.IsExpired())
|
||||||
|
|
||||||
s = &SessionState{}
|
s = &sessions.SessionState{}
|
||||||
assert.Equal(t, false, s.IsExpired())
|
assert.Equal(t, false, s.IsExpired())
|
||||||
}
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
SessionState
|
sessions.SessionState
|
||||||
Encoded string
|
Encoded string
|
||||||
Cipher *cookie.Cipher
|
Cipher *cookie.Cipher
|
||||||
Error bool
|
Error bool
|
||||||
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||||
if tc.Error {
|
if tc.Error {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Empty(t, encoded)
|
assert.Empty(t, encoded)
|
||||||
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDecodeSessionState tests DecodeSessionState with the test vector
|
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||||
func TestDecodeSessionState(t *testing.T) {
|
func TestDecodeSessionState(t *testing.T) {
|
||||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
eJSON, _ := e.MarshalJSON()
|
eJSON, _ := e.MarshalJSON()
|
||||||
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "user@domain.com",
|
User: "user@domain.com",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com"}`,
|
Encoded: `{"Email":"user@domain.com"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"User":"just-user"}`,
|
Encoded: `{"User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Cipher: c,
|
Cipher: c,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Error: true,
|
Error: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
},
|
},
|
||||||
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Error: true,
|
Error: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Cipher: c,
|
Cipher: c,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||||
if tc.Error {
|
if tc.Error {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, ss)
|
assert.Nil(t, ss)
|
34
pkg/cookies/cookies.go
Normal file
34
pkg/cookies/cookies.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package cookies
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MakeCookie constructs a cookie from the given parameters,
|
||||||
|
// discovering the domain from the request if not specified.
|
||||||
|
func MakeCookie(req *http.Request, name string, value string, path string, domain string, httpOnly bool, secure bool, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
if domain != "" {
|
||||||
|
host := req.Host
|
||||||
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = h
|
||||||
|
}
|
||||||
|
if !strings.HasSuffix(host, domain) {
|
||||||
|
logger.Printf("Warning: request host is %q but using configured cookie domain of %q", host, domain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: path,
|
||||||
|
Domain: domain,
|
||||||
|
HttpOnly: httpOnly,
|
||||||
|
Secure: secure,
|
||||||
|
Expires: now.Add(expiration),
|
||||||
|
}
|
||||||
|
}
|
232
pkg/sessions/cookie/session_store.go
Normal file
232
pkg/sessions/cookie/session_store.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
package cookie
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Cookies are limited to 4kb including the length of the cookie name,
|
||||||
|
// the cookie name can be up to 256 bytes
|
||||||
|
maxCookieLength = 3840
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure CookieSessionStore implements the interface
|
||||||
|
var _ sessions.SessionStore = &SessionStore{}
|
||||||
|
|
||||||
|
// SessionStore is an implementation of the sessions.SessionStore
|
||||||
|
// interface that stores sessions in client side cookies
|
||||||
|
type SessionStore struct {
|
||||||
|
CookieCipher *cookie.Cipher
|
||||||
|
CookieDomain string
|
||||||
|
CookieExpire time.Duration
|
||||||
|
CookieHTTPOnly bool
|
||||||
|
CookieName string
|
||||||
|
CookiePath string
|
||||||
|
CookieSecret string
|
||||||
|
CookieSecure bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save takes a sessions.SessionState and stores the information from it
|
||||||
|
// within Cookies set on the HTTP response writer
|
||||||
|
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
|
||||||
|
value, err := utils.CookieForSession(ss, s.CookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.setSessionCookie(rw, req, value)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads sessions.SessionState information from Cookies within the
|
||||||
|
// HTTP request object
|
||||||
|
func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
|
||||||
|
c, err := loadCookie(req, s.CookieName)
|
||||||
|
if err != nil {
|
||||||
|
// always http.ErrNoCookie
|
||||||
|
return nil, fmt.Errorf("Cookie %q not present", s.CookieName)
|
||||||
|
}
|
||||||
|
val, _, ok := cookie.Validate(c, s.CookieSecret, s.CookieExpire)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("Cookie Signature not valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := utils.SessionFromCookie(val, s.CookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear clears any saved session information by writing a cookie to
|
||||||
|
// clear the session
|
||||||
|
func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
|
||||||
|
var cookies []*http.Cookie
|
||||||
|
|
||||||
|
// matches CookieName, CookieName_<number>
|
||||||
|
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.CookieName))
|
||||||
|
|
||||||
|
for _, c := range req.Cookies() {
|
||||||
|
if cookieNameRegex.MatchString(c.Name) {
|
||||||
|
clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1)
|
||||||
|
|
||||||
|
http.SetCookie(rw, clearCookie)
|
||||||
|
cookies = append(cookies, clearCookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSessionCookie adds the user's session cookie to the response
|
||||||
|
func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
|
for _, c := range s.makeSessionCookie(req, val, s.CookieExpire, time.Now()) {
|
||||||
|
http.SetCookie(rw, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeSessionCookie creates an http.Cookie containing the authenticated user's
|
||||||
|
// authentication details
|
||||||
|
func (s *SessionStore) makeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
||||||
|
if value != "" {
|
||||||
|
value = cookie.SignedValue(s.CookieSecret, s.CookieName, value, now)
|
||||||
|
}
|
||||||
|
c := s.makeCookie(req, s.CookieName, value, expiration)
|
||||||
|
if len(c.Value) > 4096-len(s.CookieName) {
|
||||||
|
return splitCookie(c)
|
||||||
|
}
|
||||||
|
return []*http.Cookie{c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration) *http.Cookie {
|
||||||
|
return cookies.MakeCookie(
|
||||||
|
req,
|
||||||
|
name,
|
||||||
|
value,
|
||||||
|
s.CookiePath,
|
||||||
|
s.CookieDomain,
|
||||||
|
s.CookieHTTPOnly,
|
||||||
|
s.CookieSecure,
|
||||||
|
expiration,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCookieSessionStore initialises a new instance of the SessionStore from
|
||||||
|
// the configuration given
|
||||||
|
func NewCookieSessionStore(opts options.CookieStoreOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||||
|
var cipher *cookie.Cipher
|
||||||
|
if len(cookieOpts.CookieSecret) > 0 {
|
||||||
|
var err error
|
||||||
|
cipher, err = cookie.NewCipher(utils.SecretBytes(cookieOpts.CookieSecret))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to create cipher: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SessionStore{
|
||||||
|
CookieCipher: cipher,
|
||||||
|
CookieDomain: cookieOpts.CookieDomain,
|
||||||
|
CookieExpire: cookieOpts.CookieExpire,
|
||||||
|
CookieHTTPOnly: cookieOpts.CookieHTTPOnly,
|
||||||
|
CookieName: cookieOpts.CookieName,
|
||||||
|
CookiePath: cookieOpts.CookiePath,
|
||||||
|
CookieSecret: cookieOpts.CookieSecret,
|
||||||
|
CookieSecure: cookieOpts.CookieSecure,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitCookie reads the full cookie generated to store the session and splits
|
||||||
|
// it into a slice of cookies which fit within the 4kb cookie limit indexing
|
||||||
|
// the cookies from 0
|
||||||
|
func splitCookie(c *http.Cookie) []*http.Cookie {
|
||||||
|
if len(c.Value) < maxCookieLength {
|
||||||
|
return []*http.Cookie{c}
|
||||||
|
}
|
||||||
|
cookies := []*http.Cookie{}
|
||||||
|
valueBytes := []byte(c.Value)
|
||||||
|
count := 0
|
||||||
|
for len(valueBytes) > 0 {
|
||||||
|
new := copyCookie(c)
|
||||||
|
new.Name = fmt.Sprintf("%s_%d", c.Name, count)
|
||||||
|
count++
|
||||||
|
if len(valueBytes) < maxCookieLength {
|
||||||
|
new.Value = string(valueBytes)
|
||||||
|
valueBytes = []byte{}
|
||||||
|
} else {
|
||||||
|
newValue := valueBytes[:maxCookieLength]
|
||||||
|
valueBytes = valueBytes[maxCookieLength:]
|
||||||
|
new.Value = string(newValue)
|
||||||
|
}
|
||||||
|
cookies = append(cookies, new)
|
||||||
|
}
|
||||||
|
return cookies
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadCookie retreieves the sessions state cookie from the http request.
|
||||||
|
// If a single cookie is present this will be returned, otherwise it attempts
|
||||||
|
// to reconstruct a cookie split up by splitCookie
|
||||||
|
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
||||||
|
c, err := req.Cookie(cookieName)
|
||||||
|
if err == nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
cookies := []*http.Cookie{}
|
||||||
|
err = nil
|
||||||
|
count := 0
|
||||||
|
for err == nil {
|
||||||
|
var c *http.Cookie
|
||||||
|
c, err = req.Cookie(fmt.Sprintf("%s_%d", cookieName, count))
|
||||||
|
if err == nil {
|
||||||
|
cookies = append(cookies, c)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(cookies) == 0 {
|
||||||
|
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
|
||||||
|
}
|
||||||
|
return joinCookies(cookies)
|
||||||
|
}
|
||||||
|
|
||||||
|
// joinCookies takes a slice of cookies from the request and reconstructs the
|
||||||
|
// full session cookie
|
||||||
|
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
||||||
|
if len(cookies) == 0 {
|
||||||
|
return nil, fmt.Errorf("list of cookies must be > 0")
|
||||||
|
}
|
||||||
|
if len(cookies) == 1 {
|
||||||
|
return cookies[0], nil
|
||||||
|
}
|
||||||
|
c := copyCookie(cookies[0])
|
||||||
|
for i := 1; i < len(cookies); i++ {
|
||||||
|
c.Value += cookies[i].Value
|
||||||
|
}
|
||||||
|
c.Name = strings.TrimRight(c.Name, "_0")
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyCookie(c *http.Cookie) *http.Cookie {
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
Expires: c.Expires,
|
||||||
|
RawExpires: c.RawExpires,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HttpOnly: c.HttpOnly,
|
||||||
|
Raw: c.Raw,
|
||||||
|
Unparsed: c.Unparsed,
|
||||||
|
}
|
||||||
|
}
|
19
pkg/sessions/session_store.go
Normal file
19
pkg/sessions/session_store.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewSessionStore creates a SessionStore from the provided configuration
|
||||||
|
func NewSessionStore(opts *options.SessionOptions, cookieOpts *options.CookieOptions) (sessions.SessionStore, error) {
|
||||||
|
switch opts.Type {
|
||||||
|
case options.CookieSessionStoreType:
|
||||||
|
return cookie.NewCookieSessionStore(opts.CookieStoreOptions, cookieOpts)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown session store type '%s'", opts.Type)
|
||||||
|
}
|
||||||
|
}
|
254
pkg/sessions/session_store_test.go
Normal file
254
pkg/sessions/session_store_test.go
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
package sessions_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||||
|
sessionsapi "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/cookies"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/sessions/cookie"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSessionStore(t *testing.T) {
|
||||||
|
RegisterFailHandler(Fail)
|
||||||
|
RunSpecs(t, "SessionStore")
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ = Describe("NewSessionStore", func() {
|
||||||
|
var opts *options.SessionOptions
|
||||||
|
var cookieOpts *options.CookieOptions
|
||||||
|
|
||||||
|
var request *http.Request
|
||||||
|
var response *httptest.ResponseRecorder
|
||||||
|
var session *sessionsapi.SessionState
|
||||||
|
var ss sessionsapi.SessionStore
|
||||||
|
|
||||||
|
CheckCookieOptions := func() {
|
||||||
|
Context("the cookies returned", func() {
|
||||||
|
var cookies []*http.Cookie
|
||||||
|
BeforeEach(func() {
|
||||||
|
cookies = response.Result().Cookies()
|
||||||
|
})
|
||||||
|
|
||||||
|
It("have the correct name set", func() {
|
||||||
|
if len(cookies) == 1 {
|
||||||
|
Expect(cookies[0].Name).To(Equal(cookieOpts.CookieName))
|
||||||
|
} else {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
Expect(cookie.Name).To(ContainSubstring(cookieOpts.CookieName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("have the correct path set", func() {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
Expect(cookie.Path).To(Equal(cookieOpts.CookiePath))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("have the correct domain set", func() {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
Expect(cookie.Domain).To(Equal(cookieOpts.CookieDomain))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("have the correct HTTPOnly set", func() {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
Expect(cookie.HttpOnly).To(Equal(cookieOpts.CookieHTTPOnly))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
It("have the correct secure set", func() {
|
||||||
|
for _, cookie := range cookies {
|
||||||
|
Expect(cookie.Secure).To(Equal(cookieOpts.CookieSecure))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
SessionStoreInterfaceTests := func() {
|
||||||
|
Context("when Save is called", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
err := ss.Save(response, request, session)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets a `set-cookie` header in the response", func() {
|
||||||
|
Expect(response.Header().Get("set-cookie")).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
CheckCookieOptions()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when Clear is called", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
cookie := cookies.MakeCookie(request,
|
||||||
|
cookieOpts.CookieName,
|
||||||
|
"foo",
|
||||||
|
cookieOpts.CookiePath,
|
||||||
|
cookieOpts.CookieDomain,
|
||||||
|
cookieOpts.CookieHTTPOnly,
|
||||||
|
cookieOpts.CookieSecure,
|
||||||
|
cookieOpts.CookieExpire,
|
||||||
|
time.Now(),
|
||||||
|
)
|
||||||
|
request.AddCookie(cookie)
|
||||||
|
err := ss.Clear(response, request)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("sets a `set-cookie` header in the response", func() {
|
||||||
|
Expect(response.Header().Get("Set-Cookie")).ToNot(BeEmpty())
|
||||||
|
})
|
||||||
|
|
||||||
|
CheckCookieOptions()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("when Load is called", func() {
|
||||||
|
var loadedSession *sessionsapi.SessionState
|
||||||
|
BeforeEach(func() {
|
||||||
|
req := httptest.NewRequest("GET", "http://example.com/", nil)
|
||||||
|
resp := httptest.NewRecorder()
|
||||||
|
err := ss.Save(resp, req, session)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
|
||||||
|
for _, cookie := range resp.Result().Cookies() {
|
||||||
|
request.AddCookie(cookie)
|
||||||
|
}
|
||||||
|
loadedSession, err = ss.Load(request)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
It("loads a session equal to the original session", func() {
|
||||||
|
if cookieOpts.CookieSecret == "" {
|
||||||
|
// Only Email and User stored in session when encrypted
|
||||||
|
Expect(loadedSession.Email).To(Equal(session.Email))
|
||||||
|
Expect(loadedSession.User).To(Equal(session.User))
|
||||||
|
} else {
|
||||||
|
// All fields stored in session if encrypted
|
||||||
|
|
||||||
|
// Can't compare time.Time using Equal() so remove ExpiresOn from sessions
|
||||||
|
l := *loadedSession
|
||||||
|
l.ExpiresOn = time.Time{}
|
||||||
|
s := *session
|
||||||
|
s.ExpiresOn = time.Time{}
|
||||||
|
Expect(l).To(Equal(s))
|
||||||
|
|
||||||
|
// Compare time.Time separately
|
||||||
|
Expect(loadedSession.ExpiresOn.Equal(session.ExpiresOn)).To(BeTrue())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
RunSessionTests := func() {
|
||||||
|
Context("with default options", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
var err error
|
||||||
|
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
SessionStoreInterfaceTests()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with non-default options", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
cookieOpts = &options.CookieOptions{
|
||||||
|
CookieName: "_cookie_name",
|
||||||
|
CookiePath: "/path",
|
||||||
|
CookieExpire: time.Duration(72) * time.Hour,
|
||||||
|
CookieRefresh: time.Duration(3600),
|
||||||
|
CookieSecure: false,
|
||||||
|
CookieHTTPOnly: false,
|
||||||
|
CookieDomain: "example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
SessionStoreInterfaceTests()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with a cookie-secret set", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
secret := make([]byte, 32)
|
||||||
|
_, err := rand.Read(secret)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
cookieOpts.CookieSecret = base64.URLEncoding.EncodeToString(secret)
|
||||||
|
|
||||||
|
ss, err = sessions.NewSessionStore(opts, cookieOpts)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
})
|
||||||
|
|
||||||
|
SessionStoreInterfaceTests()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
ss = nil
|
||||||
|
opts = &options.SessionOptions{}
|
||||||
|
|
||||||
|
// Set default options in CookieOptions
|
||||||
|
cookieOpts = &options.CookieOptions{
|
||||||
|
CookieName: "_oauth2_proxy",
|
||||||
|
CookiePath: "/",
|
||||||
|
CookieExpire: time.Duration(168) * time.Hour,
|
||||||
|
CookieRefresh: time.Duration(0),
|
||||||
|
CookieSecure: true,
|
||||||
|
CookieHTTPOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
session = &sessionsapi.SessionState{
|
||||||
|
AccessToken: "AccessToken",
|
||||||
|
IDToken: "IDToken",
|
||||||
|
ExpiresOn: time.Now().Add(1 * time.Hour),
|
||||||
|
RefreshToken: "RefreshToken",
|
||||||
|
Email: "john.doe@example.com",
|
||||||
|
User: "john.doe",
|
||||||
|
}
|
||||||
|
|
||||||
|
request = httptest.NewRequest("GET", "http://example.com/", nil)
|
||||||
|
response = httptest.NewRecorder()
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with type 'cookie'", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
opts.Type = options.CookieSessionStoreType
|
||||||
|
})
|
||||||
|
|
||||||
|
It("creates a cookie.SessionStore", func() {
|
||||||
|
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||||
|
Expect(err).NotTo(HaveOccurred())
|
||||||
|
Expect(ss).To(BeAssignableToTypeOf(&cookie.SessionStore{}))
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("the cookie.SessionStore", func() {
|
||||||
|
RunSessionTests()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Context("with an invalid type", func() {
|
||||||
|
BeforeEach(func() {
|
||||||
|
opts.Type = "invalid-type"
|
||||||
|
})
|
||||||
|
|
||||||
|
It("returns an error", func() {
|
||||||
|
ss, err := sessions.NewSessionStore(opts, cookieOpts)
|
||||||
|
Expect(err).To(HaveOccurred())
|
||||||
|
Expect(err.Error()).To(Equal("unknown session store type 'invalid-type'"))
|
||||||
|
Expect(ss).To(BeNil())
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
41
pkg/sessions/utils/utils.go
Normal file
41
pkg/sessions/utils/utils.go
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CookieForSession serializes a session state for storage in a cookie
|
||||||
|
func CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
|
||||||
|
return s.EncodeSessionState(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionFromCookie deserializes a session from a cookie value
|
||||||
|
func SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
|
||||||
|
return sessions.DecodeSessionState(v, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecretBytes attempts to base64 decode the secret, if that fails it treats the secret as binary
|
||||||
|
func SecretBytes(secret string) []byte {
|
||||||
|
b, err := base64.URLEncoding.DecodeString(addPadding(secret))
|
||||||
|
if err == nil {
|
||||||
|
return []byte(addPadding(string(b)))
|
||||||
|
}
|
||||||
|
return []byte(secret)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addPadding(secret string) string {
|
||||||
|
padding := len(secret) % 4
|
||||||
|
switch padding {
|
||||||
|
case 1:
|
||||||
|
return secret + "==="
|
||||||
|
case 2:
|
||||||
|
return secret + "=="
|
||||||
|
case 3:
|
||||||
|
return secret + "="
|
||||||
|
default:
|
||||||
|
return secret
|
||||||
|
}
|
||||||
|
}
|
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/bitly/go-simplejson"
|
"github.com/bitly/go-simplejson"
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AzureProvider represents an Azure based Identity Provider
|
// AzureProvider represents an Azure based Identity Provider
|
||||||
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
var email string
|
var email string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", email)
|
||||||
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", email)
|
||||||
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", email)
|
||||||
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FacebookProvider represents an Facebook based Identity Provider
|
// FacebookProvider represents an Facebook based Identity Provider
|
||||||
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
|
func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
|
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitHubProvider represents an GitHub based Identity Provider
|
// GitHubProvider represents an GitHub based Identity Provider
|
||||||
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
|
|
||||||
var emails []struct {
|
var emails []struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account user name
|
// GetUserName returns the Account user name
|
||||||
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
|
||||||
var user struct {
|
var user struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Empty(t, "", email)
|
assert.Empty(t, "", email)
|
||||||
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
|||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
p.Org = "testorg1"
|
p.Org = "testorg1"
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// We'll trigger a request failure by using an unexpected access
|
// We'll trigger a request failure by using an unexpected access
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetUserName(session)
|
email, err := p.GetUserName(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "mbland", email)
|
assert.Equal(t, "mbland", email)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitLabProvider represents an GitLab based Identity Provider
|
// GitLabProvider represents an GitLab based Identity Provider
|
||||||
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
|
|
||||||
req, err := http.NewRequest("GET",
|
req, err := http.NewRequest("GET",
|
||||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitLabProvider(bURL.Host)
|
p := testGitLabProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// We'll trigger a request failure by using an unexpected access
|
// We'll trigger a request failure by using an unexpected access
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitLabProvider(bURL.Host)
|
p := testGitLabProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
return
|
return
|
||||||
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s = &SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
|
|||||||
*ProviderData
|
*ProviderData
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that we're testing the internal validateToken() used to implement
|
// Note that we're testing the internal validateToken() used to implement
|
||||||
// several Provider's ValidateSessionState() implementations
|
// several Provider's ValidateSessionState() implementations
|
||||||
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
|
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LinkedInProvider represents an LinkedIn based Identity Provider
|
// LinkedInProvider represents an LinkedIn based Identity Provider
|
||||||
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
|
func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
|
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testLinkedInProvider(bURL.Host)
|
p := testLinkedInProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@linkedin.com", email)
|
assert.Equal(t, "user@linkedin.com", email)
|
||||||
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// We'll trigger a request failure by using an unexpected access
|
// We'll trigger a request failure by using an unexpected access
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testLinkedInProvider(bURL.Host)
|
p := testLinkedInProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
return
|
return
|
||||||
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store the data that we found in the session state
|
// Store the data that we found in the session state
|
||||||
s = &SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
|
@ -5,9 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
oidc "github.com/coreos/go-oidc"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OIDCProvider represents an OIDC based Identity Provider
|
// OIDCProvider represents an OIDC based Identity Provider
|
||||||
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
ClientID: p.ClientID,
|
ClientID: p.ClientID,
|
||||||
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
ClientID: p.ClientID,
|
ClientID: p.ClientID,
|
||||||
ClientSecret: p.ClientSecret,
|
ClientSecret: p.ClientSecret,
|
||||||
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) {
|
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||||
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
|||||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
IDToken: rawIDToken,
|
IDToken: rawIDToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState checks that the session's IDToken is still valid
|
// ValidateSessionState checks that the session's IDToken is still valid
|
||||||
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool {
|
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -10,10 +10,11 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||||
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
return
|
return
|
||||||
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &jsonResponse)
|
err = json.Unmarshal(body, &jsonResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s = &SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if a := v.Get("access_token"); a != "" {
|
if a := v.Get("access_token"); a != "" {
|
||||||
s = &SessionState{AccessToken: a}
|
s = &sessions.SessionState{AccessToken: a}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("no access token found %s", body)
|
err = fmt.Errorf("no access token found %s", body)
|
||||||
}
|
}
|
||||||
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CookieForSession serializes a session state for storage in a cookie
|
// CookieForSession serializes a session state for storage in a cookie
|
||||||
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
|
func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
|
||||||
return s.EncodeSessionState(c)
|
return s.EncodeSessionState(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionFromCookie deserializes a session from a cookie value
|
// SessionFromCookie deserializes a session from a cookie value
|
||||||
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
|
||||||
return DecodeSessionState(v, c)
|
return sessions.DecodeSessionState(v, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account username
|
// GetUserName returns the Account username
|
||||||
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
|
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
|
||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
|
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return validateToken(p, s.AccessToken, nil)
|
return validateToken(p, s.AccessToken, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||||
// do nothing if a refresh is not required
|
// do nothing if a refresh is not required
|
||||||
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRefresh(t *testing.T) {
|
func TestRefresh(t *testing.T) {
|
||||||
p := &ProviderData{}
|
p := &ProviderData{}
|
||||||
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
|
refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
|
||||||
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
|
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
|
||||||
})
|
})
|
||||||
assert.Equal(t, false, refreshed)
|
assert.Equal(t, false, refreshed)
|
||||||
|
@ -2,20 +2,21 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider represents an upstream identity provider implementation
|
// Provider represents an upstream identity provider implementation
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
GetEmailAddress(*SessionState) (string, error)
|
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||||
GetUserName(*SessionState) (string, error)
|
GetUserName(*sessions.SessionState) (string, error)
|
||||||
Redeem(string, string) (*SessionState, error)
|
Redeem(string, string) (*sessions.SessionState, error)
|
||||||
ValidateGroup(string) bool
|
ValidateGroup(string) bool
|
||||||
ValidateSessionState(*SessionState) bool
|
ValidateSessionState(*sessions.SessionState) bool
|
||||||
GetLoginURL(redirectURI, finalRedirect string) string
|
GetLoginURL(redirectURI, finalRedirect string) string
|
||||||
RefreshSessionIfNeeded(*SessionState) (bool, error)
|
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||||
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
|
SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
|
||||||
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
|
CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New provides a new Provider based on the configured provider string
|
// New provides a new Provider based on the configured provider string
|
||||||
|
Loading…
Reference in New Issue
Block a user