Merge remote-tracking branch 'upstream/master' into verified

This commit is contained in:
Carlos Alexandro Becker 2019-03-20 13:46:04 -03:00
commit 95ee4358b2
No known key found for this signature in database
GPG Key ID: E61E2F7DC14AB940
15 changed files with 1091 additions and 148 deletions

9
.github/CODEOWNERS vendored
View File

@ -1,3 +1,12 @@
# Default owner should be a Pusher cloud-team member unless overridden by later # Default owner should be a Pusher cloud-team member unless overridden by later
# rules in this file # rules in this file
* @pusher/cloud-team * @pusher/cloud-team
# login.gov provider
# Note: If @timothy-spencer terms out of his appointment, your best bet
# for finding somebody who can test the oauth2_proxy would be to ask somebody
# in the login.gov team (https://login.gov/developers/), the cloud.gov team
# (https://cloud.gov/docs/help/), or the 18F org (https://18f.gsa.gov/contact/
# or the public devops channel at https://chat.18f.gov/).
providers/logingov.go @timothy-spencer
providers/logingov_test.go @timothy-spencer

View File

@ -3,6 +3,13 @@
## Changes since v3.1.0 ## Changes since v3.1.0
- [#96](https://github.com/bitly/oauth2_proxy/pull/96) Check if email is verified on GitHub (@caarlos0) - [#96](https://github.com/bitly/oauth2_proxy/pull/96) Check if email is verified on GitHub (@caarlos0)
- [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi)
- Use JSON to encode session state to be stored in browser cookies
- Implement legacy decode function to support existing cookies generated by older versions
- Add detailed table driven tests in session_state_test.go
- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer)
- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer)
- [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr)
- [#92](https://github.com/pusher/oauth2_proxy/pull/92) Merge websocket proxy feature from openshift/oauth-proxy (@butzist) - [#92](https://github.com/pusher/oauth2_proxy/pull/92) Merge websocket proxy feature from openshift/oauth-proxy (@butzist)
- [#57](https://github.com/pusher/oauth2_proxy/pull/57) Fall back to using OIDC Subject instead of Email (@aigarius) - [#57](https://github.com/pusher/oauth2_proxy/pull/57) Fall back to using OIDC Subject instead of Email (@aigarius)
- [#85](https://github.com/pusher/oauth2_proxy/pull/85) Use non-root user in docker images (@kskewes) - [#85](https://github.com/pusher/oauth2_proxy/pull/85) Use non-root user in docker images (@kskewes)

9
Gopkg.lock generated
View File

@ -41,6 +41,14 @@
revision = "346938d642f2ec3594ed81d874461961cd0faa76" revision = "346938d642f2ec3594ed81d874461961cd0faa76"
version = "v1.1.0" version = "v1.1.0"
[[projects]]
digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6"
name = "github.com/dgrijalva/jwt-go"
packages = ["."]
pruneopts = ""
revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e"
version = "v3.2.0"
[[projects]] [[projects]]
branch = "master" branch = "master"
digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4" digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4"
@ -201,6 +209,7 @@
"github.com/BurntSushi/toml", "github.com/BurntSushi/toml",
"github.com/bitly/go-simplejson", "github.com/bitly/go-simplejson",
"github.com/coreos/go-oidc", "github.com/coreos/go-oidc",
"github.com/dgrijalva/jwt-go",
"github.com/mbland/hmacauth", "github.com/mbland/hmacauth",
"github.com/mreiferson/go-options", "github.com/mreiferson/go-options",
"github.com/stretchr/testify/assert", "github.com/stretchr/testify/assert",

View File

@ -48,6 +48,7 @@ Valid providers are :
- [GitHub](#github-auth-provider) - [GitHub](#github-auth-provider)
- [GitLab](#gitlab-auth-provider) - [GitLab](#gitlab-auth-provider)
- [LinkedIn](#linkedin-auth-provider) - [LinkedIn](#linkedin-auth-provider)
- [login.gov](#login.gov-provider)
The provider can be selected using the `provider` configuration value. The provider can be selected using the `provider` configuration value.
@ -166,6 +167,54 @@ OpenID Connect is a spec for OAUTH 2.0 + identity that is implemented by many ma
-cookie-secure=false -cookie-secure=false
-email-domain example.com -email-domain example.com
### login.gov Provider
login.gov is an OIDC provider for the US Government.
If you are a US Government agency, you can contact the login.gov team through the contact information
that you can find on https://login.gov/developers/ and work with them to understand how to get login.gov
accounts for integration/test and production access.
A developer guide is available here: https://developers.login.gov/, though this proxy handles everything
but the data you need to create to register your application in the login.gov dashboard.
As a demo, we will assume that you are running your application that you want to secure locally on
http://localhost:3000/, that you will be starting your proxy up on http://localhost:4180/, and that
you have an agency integration account for testing.
First, register your application in the dashboard. The important bits are:
* Identity protocol: make this `Openid connect`
* Issuer: do what they say for OpenID Connect. We will refer to this string as `${LOGINGOV_ISSUER}`.
* Public key: This is a self-signed certificate in .pem format generated from a 2048 bit RSA private key.
A quick way to do this is `openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -days 3650 -nodes -subj '/C=US/ST=Washington/L=DC/O=GSA/OU=18F/CN=localhost'`,
The contents of the `key.pem` shall be referred to as `${OAUTH2_PROXY_JWT_KEY}`.
* Return to App URL: Make this be `http://localhost:4180/`
* Redirect URIs: Make this be `http://localhost:4180/oauth2/callback`.
* Attribute Bundle: Make sure that email is selected.
Now start the proxy up with the following options:
```
./oauth2_proxy -provider login.gov \
-client-id=${LOGINGOV_ISSUER} \
-redirect-url=http://localhost:4180/oauth2/callback \
-oidc-issuer-url=https://idp.int.identitysandbox.gov/ \
-cookie-secure=false \
-email-domain=gsa.gov \
-upstream=http://localhost:3000/ \
-cookie-secret=somerandomstring12341234567890AB \
-cookie-domain=localhost \
-skip-provider-button=true \
-pubjwk-url=https://idp.int.identitysandbox.gov/api/openid_connect/certs \
-profile-url=https://idp.int.identitysandbox.gov/api/openid_connect/userinfo \
-jwt-key="${OAUTH2_PROXY_JWT_KEY}"
```
You can also set all these options with environment variables, for use in cloud/docker environments.
Once it is running, you should be able to go to `http://localhost:4180/` in your browser,
get authenticated by the login.gov integration server, and then get proxied on to your
application running on `http://localhost:3000/`. In a real deployment, you would secure
your application with a firewall or something so that it was only accessible from the
proxy, and you would use real hostnames everywhere.
#### Skip OIDC discovery #### Skip OIDC discovery
Some providers do not support OIDC discovery via their issuer URL, so oauth2_proxy cannot simply grab the authorization, token and jwks URI endpoints from the provider's metadata. Some providers do not support OIDC discovery via their issuer URL, so oauth2_proxy cannot simply grab the authorization, token and jwks URI endpoints from the provider's metadata.

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings"
"testing" "testing"
"time" "time"
) )
@ -16,7 +17,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
Format, Format,
ExpectedLogMessage string ExpectedLogMessage string
}{ }{
{defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", ts.Format("02/Jan/2006:15:04:05 -0700"))}, {defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0", ts.Format("02/Jan/2006:15:04:05 -0700"))},
{"{{.RequestMethod}}", "GET\n"}, {"{{.RequestMethod}}", "GET\n"},
} }
@ -35,8 +36,8 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) {
h.ServeHTTP(httptest.NewRecorder(), r) h.ServeHTTP(httptest.NewRecorder(), r)
actual := buf.String() actual := buf.String()
if actual != test.ExpectedLogMessage { if !strings.Contains(actual, test.ExpectedLogMessage) {
t.Errorf("Log message was\n%s\ninstead of expected \n%s", actual, test.ExpectedLogMessage) t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage)
} }
} }
} }

View File

@ -4,6 +4,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"math/rand"
"os" "os"
"runtime" "runtime"
"strings" "strings"
@ -88,6 +89,9 @@ func main() {
flagSet.String("approval-prompt", "force", "OAuth approval_prompt") flagSet.String("approval-prompt", "force", "OAuth approval_prompt")
flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)")
flagSet.String("acr-values", "http://idmanagement.gov/ns/assurance/loa/1", "acr values string: optional, used by login.gov")
flagSet.String("jwt-key", "", "private key used to sign JWT: required by login.gov")
flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov")
flagSet.Parse(os.Args[1:]) flagSet.Parse(os.Args[1:])
@ -133,6 +137,8 @@ func main() {
} }
} }
rand.Seed(time.Now().UnixNano())
s := &Server{ s := &Server{
Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat), Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat),
Opts: opts, Opts: opts,

View File

@ -204,7 +204,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
} }
redirectURL := opts.redirectURL redirectURL := opts.redirectURL
if redirectURL.String() == "" { if redirectURL.Path == "" {
redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix)
} }
@ -241,7 +241,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix),
OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix),
OAuthCallbackPath: redirectURL.Path, OAuthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix),
AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix),
ProxyPrefix: opts.ProxyPrefix, ProxyPrefix: opts.ProxyPrefix,
@ -452,9 +452,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
// ClearSessionCookie creates a cookie to unset the user's authentication cookie // ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session // stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) var cookies []*http.Cookie
for _, clr := range cookies {
http.SetCookie(rw, clr) // matches CookieName, CookieName_<number>
var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName))
for _, c := range req.Cookies() {
if cookieNameRegex.MatchString(c.Name) {
clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
http.SetCookie(rw, clearCookie)
cookies = append(cookies, clearCookie)
}
} }
// ugly hack because default domain changed // ugly hack because default domain changed

View File

@ -1064,3 +1064,47 @@ func TestAjaxForbiddendRequest(t *testing.T) {
mime := rh.Get("Content-Type") mime := rh.Get("Content-Type")
assert.NotEqual(t, applicationJSON, mime) assert.NotEqual(t, applicationJSON, mime)
} }
func TestClearSplitCookie(t *testing.T) {
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)
req.AddCookie(&http.Cookie{
Name: "test1",
Value: "test1",
})
req.AddCookie(&http.Cookie{
Name: "oauth2_0",
Value: "oauth2_0",
})
req.AddCookie(&http.Cookie{
Name: "oauth2_1",
Value: "oauth2_1",
})
p.ClearSessionCookie(rw, req)
header := rw.Header()
assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries")
}
func TestClearSingleCookie(t *testing.T) {
p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"}
var rw = httptest.NewRecorder()
req := httptest.NewRequest("get", "/", nil)
req.AddCookie(&http.Cookie{
Name: "test1",
Value: "test1",
})
req.AddCookie(&http.Cookie{
Name: "oauth2",
Value: "oauth2",
})
p.ClearSessionCookie(rw, req)
header := rw.Header()
assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries")
}

View File

@ -14,6 +14,7 @@ import (
"time" "time"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/dgrijalva/jwt-go"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
) )
@ -21,71 +22,74 @@ import (
// Options holds Configuration Options that can be set by Command Line Flag, // Options holds Configuration Options that can be set by Command Line Flag,
// or Config File // or Config File
type Options struct { type Options struct {
ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix" env:"OAUTH2_PROXY_PROXY_PREFIX"`
ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets" env:"OAUTH2_PROXY_PROXY_WEBSOCKETS"`
HTTPAddress string `flag:"http-address" cfg:"http_address"` HTTPAddress string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"`
HTTPSAddress string `flag:"https-address" cfg:"https_address"` HTTPSAddress string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"`
RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` RedirectURL string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"`
ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"`
ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"`
TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file" env:"OAUTH2_PROXY_TLS_CERT_FILE"`
TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file" env:"OAUTH2_PROXY_TLS_KEY_FILE"`
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file" env:"OAUTH2_PROXY_AUTHENTICATED_EMAILS_FILE"`
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant" env:"OAUTH2_PROXY_AZURE_TENANT"`
EmailDomains []string `flag:"email-domain" cfg:"email_domains"` EmailDomains []string `flag:"email-domain" cfg:"email_domains" env:"OAUTH2_PROXY_EMAIL_DOMAINS"`
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains" env:"OAUTH2_PROXY_WHITELIST_DOMAINS"` WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains" env:"OAUTH2_PROXY_WHITELIST_DOMAINS"`
GitHubOrg string `flag:"github-org" cfg:"github_org"` GitHubOrg string `flag:"github-org" cfg:"github_org" env:"OAUTH2_PROXY_GITHUB_ORG"`
GitHubTeam string `flag:"github-team" cfg:"github_team"` GitHubTeam string `flag:"github-team" cfg:"github_team" env:"OAUTH2_PROXY_GITHUB_TEAM"`
GoogleGroups []string `flag:"google-group" cfg:"google_group"` GoogleGroups []string `flag:"google-group" cfg:"google_group" env:"OAUTH2_PROXY_GOOGLE_GROUPS"`
GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"` GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email" env:"OAUTH2_PROXY_GOOGLE_ADMIN_EMAIL"`
GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"` GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json" env:"OAUTH2_PROXY_GOOGLE_SERVICE_ACCOUNT_JSON"`
HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file" env:"OAUTH2_PROXY_HTPASSWD_FILE"`
DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form" env:"OAUTH2_PROXY_DISPLAY_HTPASSWD_FORM"`
CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"`
Footer string `flag:"footer" cfg:"footer"` Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"`
CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"`
CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"`
CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"`
CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"`
CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"`
CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"`
CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"`
Upstreams []string `flag:"upstream" cfg:"upstreams"` Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"`
PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth" env:"OAUTH2_PROXY_PASS_BASIC_AUTH"`
BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password" env:"OAUTH2_PROXY_BASIC_AUTH_PASSWORD"`
PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token" env:"OAUTH2_PROXY_PASS_ACCESS_TOKEN"`
PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header" env:"OAUTH2_PROXY_PASS_HOST_HEADER"`
SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button" env:"OAUTH2_PROXY_SKIP_PROVIDER_BUTTON"`
PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers" env:"OAUTH2_PROXY_PASS_USER_HEADERS"`
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify" env:"OAUTH2_PROXY_SSL_INSECURE_SKIP_VERIFY"`
SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest" env:"OAUTH2_PROXY_SET_XAUTHREQUEST"`
SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"` SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header" env:"OAUTH2_PROXY_SET_AUTHORIZATION_HEADER"`
PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"` PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header" env:"OAUTH2_PROXY_PASS_AUTHORIZATION_HEADER"`
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight" env:"OAUTH2_PROXY_SKIP_AUTH_PREFLIGHT"`
FlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval"` FlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval" env:"OAUTH2_PROXY_FLUSH_INTERVAL"`
// These options allow for other providers besides Google, with // These options allow for other providers besides Google, with
// potential overrides. // potential overrides.
Provider string `flag:"provider" cfg:"provider"` Provider string `flag:"provider" cfg:"provider" env:"OAUTH2_PROXY_PROVIDER"`
OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url"` OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url" env:"OAUTH2_PROXY_OIDC_ISSUER_URL"`
SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"` SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery" env:"OAUTH2_SKIP_OIDC_DISCOVERY"`
OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"` OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url" env:"OAUTH2_OIDC_JWKS_URL"`
LoginURL string `flag:"login-url" cfg:"login_url"` LoginURL string `flag:"login-url" cfg:"login_url" env:"OAUTH2_PROXY_LOGIN_URL"`
RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` RedeemURL string `flag:"redeem-url" cfg:"redeem_url" env:"OAUTH2_PROXY_REDEEM_URL"`
ProfileURL string `flag:"profile-url" cfg:"profile_url"` ProfileURL string `flag:"profile-url" cfg:"profile_url" env:"OAUTH2_PROXY_PROFILE_URL"`
ProtectedResource string `flag:"resource" cfg:"resource"` ProtectedResource string `flag:"resource" cfg:"resource" env:"OAUTH2_PROXY_RESOURCE"`
ValidateURL string `flag:"validate-url" cfg:"validate_url"` ValidateURL string `flag:"validate-url" cfg:"validate_url" env:"OAUTH2_PROXY_VALIDATE_URL"`
Scope string `flag:"scope" cfg:"scope"` Scope string `flag:"scope" cfg:"scope" env:"OAUTH2_PROXY_SCOPE"`
ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt" env:"OAUTH2_PROXY_APPROVAL_PROMPT"`
RequestLogging bool `flag:"request-logging" cfg:"request_logging"` RequestLogging bool `flag:"request-logging" cfg:"request_logging" env:"OAUTH2_PROXY_REQUEST_LOGGING"`
RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format"` RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format" env:"OAUTH2_PROXY_REQUEST_LOGGING_FORMAT"`
SignatureKey string `flag:"signature-key" cfg:"signature_key" env:"OAUTH2_PROXY_SIGNATURE_KEY"` SignatureKey string `flag:"signature-key" cfg:"signature_key" env:"OAUTH2_PROXY_SIGNATURE_KEY"`
AcrValues string `flag:"acr-values" cfg:"acr_values" env:"OAUTH2_PROXY_ACR_VALUES"`
JWTKey string `flag:"jwt-key" cfg:"jwt_key" env:"OAUTH2_PROXY_JWT_KEY"`
PubJWKURL string `flag:"pubjwk-url" cfg:"pubjwk_url" env:"OAUTH2_PROXY_PUBJWK_URL"`
// internal values that are set after config validation // internal values that are set after config validation
redirectURL *url.URL redirectURL *url.URL
@ -157,7 +161,8 @@ func (o *Options) Validate() error {
if o.ClientID == "" { if o.ClientID == "" {
msgs = append(msgs, "missing setting: client-id") msgs = append(msgs, "missing setting: client-id")
} }
if o.ClientSecret == "" { // login.gov uses a signed JWT to authenticate, not a client-secret
if o.ClientSecret == "" && o.Provider != "login.gov" {
msgs = append(msgs, "missing setting: client-secret") msgs = append(msgs, "missing setting: client-secret")
} }
if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" { if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" {
@ -318,6 +323,19 @@ func parseProviderInfo(o *Options, msgs []string) []string {
} else { } else {
p.Verifier = o.oidcVerifier p.Verifier = o.oidcVerifier
} }
case *providers.LoginGovProvider:
p.AcrValues = o.AcrValues
p.PubJWKURL, msgs = parseURL(o.PubJWKURL, "pubjwk", msgs)
if o.JWTKey == "" {
msgs = append(msgs, "login.gov provider requires a private key for signing JWTs")
} else {
signKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(o.JWTKey))
if err != nil {
msgs = append(msgs, "could not parse RSA Private Key PEM")
} else {
p.JWTKey = signKey
}
}
} }
return msgs return msgs
} }

275
providers/logingov.go Normal file
View File

@ -0,0 +1,275 @@
package providers
import (
"bytes"
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"time"
"github.com/dgrijalva/jwt-go"
"gopkg.in/square/go-jose.v2"
)
// LoginGovProvider represents an OIDC based Identity Provider
type LoginGovProvider struct {
*ProviderData
// TODO (@timothy-spencer): Ideally, the nonce would be in the session state, but the session state
// is created only upon code redemption, not during the auth, when this must be supplied.
Nonce string
AcrValues string
JWTKey *rsa.PrivateKey
PubJWKURL *url.URL
}
// For generating a nonce
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
func randSeq(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}
// NewLoginGovProvider initiates a new LoginGovProvider
func NewLoginGovProvider(p *ProviderData) *LoginGovProvider {
p.ProviderName = "login.gov"
if p.LoginURL == nil || p.LoginURL.String() == "" {
p.LoginURL = &url.URL{
Scheme: "https",
Host: "secure.login.gov",
Path: "/openid_connect/authorize",
}
}
if p.RedeemURL == nil || p.RedeemURL.String() == "" {
p.RedeemURL = &url.URL{
Scheme: "https",
Host: "secure.login.gov",
Path: "/api/openid_connect/token",
}
}
if p.ProfileURL == nil || p.ProfileURL.String() == "" {
p.ProfileURL = &url.URL{
Scheme: "https",
Host: "secure.login.gov",
Path: "/api/openid_connect/userinfo",
}
}
if p.Scope == "" {
p.Scope = "email openid"
}
return &LoginGovProvider{
ProviderData: p,
Nonce: randSeq(32),
}
}
type loginGovCustomClaims struct {
Acr string `json:"acr"`
Nonce string `json:"nonce"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Birthdate string `json:"birthdate"`
AtHash string `json:"at_hash"`
CHash string `json:"c_hash"`
jwt.StandardClaims
}
// checkNonce checks the nonce in the id_token
func checkNonce(idToken string, p *LoginGovProvider) (err error) {
token, err := jwt.ParseWithClaims(idToken, &loginGovCustomClaims{}, func(token *jwt.Token) (interface{}, error) {
resp, myerr := http.Get(p.PubJWKURL.String())
if myerr != nil {
return nil, myerr
}
if resp.StatusCode != 200 {
myerr = fmt.Errorf("got %d from %q", resp.StatusCode, p.PubJWKURL.String())
return nil, myerr
}
body, myerr := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if myerr != nil {
return nil, myerr
}
var pubkeys jose.JSONWebKeySet
myerr = json.Unmarshal(body, &pubkeys)
if myerr != nil {
return nil, myerr
}
pubkey := pubkeys.Keys[0]
return pubkey.Key, nil
})
if err != nil {
return
}
claims := token.Claims.(*loginGovCustomClaims)
if claims.Nonce != p.Nonce {
err = fmt.Errorf("nonce validation failed")
return
}
return
}
func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email string, err error) {
// query the user info endpoint for user attributes
var req *http.Request
req, err = http.NewRequest("GET", userInfoEndpoint, nil)
if err != nil {
return
}
req.Header.Set("Authorization", "Bearer "+accessToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body)
return
}
// parse the user attributes from the data we got and make sure that
// the email address has been validated.
var emailData struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
err = json.Unmarshal(body, &emailData)
if err != nil {
return
}
if emailData.Email == "" {
err = fmt.Errorf("missing email")
return
}
email = emailData.Email
if !emailData.EmailVerified {
err = fmt.Errorf("email %s not listed as verified", email)
return
}
return
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
if code == "" {
err = errors.New("missing code")
return
}
claims := &jwt.StandardClaims{
Issuer: p.ClientID,
Subject: p.ClientID,
Audience: p.RedeemURL.String(),
ExpiresAt: int64(time.Now().Add(time.Duration(5 * time.Minute)).Unix()),
Id: randSeq(32),
}
token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims)
ss, err := token.SignedString(p.JWTKey)
if err != nil {
return
}
params := url.Values{}
params.Add("client_assertion", ss)
params.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")
params.Add("code", code)
params.Add("grant_type", "authorization_code")
var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
// Get the token from the body that we got from the token endpoint.
var jsonResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
err = json.Unmarshal(body, &jsonResponse)
if err != nil {
return
}
// check nonce here
err = checkNonce(jsonResponse.IDToken, p)
if err != nil {
return
}
// Get the email address
var email string
email, err = emailFromUserInfo(jsonResponse.AccessToken, p.ProfileURL.String())
if err != nil {
return
}
// Store the data that we found in the session state
s = &SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
Email: email,
}
return
}
// GetLoginURL overrides GetLoginURL to add login.gov parameters
func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string {
var a url.URL
a = *p.LoginURL
params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", redirectURI)
params.Set("approval_prompt", p.ApprovalPrompt)
params.Add("scope", p.Scope)
params.Set("client_id", p.ClientID)
params.Set("response_type", "code")
params.Add("state", state)
params.Add("acr_values", p.AcrValues)
params.Add("nonce", p.Nonce)
a.RawQuery = params.Encode()
return a.String()
}

290
providers/logingov_test.go Normal file
View File

@ -0,0 +1,290 @@
package providers
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
"gopkg.in/square/go-jose.v2"
)
type MyKeyData struct {
PubKey crypto.PublicKey
PrivKey *rsa.PrivateKey
PubJWK jose.JSONWebKey
}
func newLoginGovServer(body []byte) (*url.URL, *httptest.Server) {
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write(body)
}))
u, _ := url.Parse(s.URL)
return u, s
}
func newLoginGovProvider() (l *LoginGovProvider, serverKey *MyKeyData, err error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return
}
serverKey = &MyKeyData{
PubKey: key.Public(),
PrivKey: key,
PubJWK: jose.JSONWebKey{
Key: key.Public(),
KeyID: "testkey",
Algorithm: string(jose.RS256),
Use: "sig",
},
}
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return
}
l = NewLoginGovProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""})
l.JWTKey = privateKey
l.Nonce = "fakenonce"
return
}
func TestLoginGovProviderDefaults(t *testing.T) {
p, _, err := newLoginGovProvider()
assert.NotEqual(t, nil, p)
assert.NoError(t, err)
assert.Equal(t, "login.gov", p.Data().ProviderName)
assert.Equal(t, "https://secure.login.gov/openid_connect/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://secure.login.gov/api/openid_connect/token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://secure.login.gov/api/openid_connect/userinfo",
p.Data().ProfileURL.String())
assert.Equal(t, "email openid", p.Data().Scope)
}
func TestLoginGovProviderOverrides(t *testing.T) {
p := NewLoginGovProvider(
&ProviderData{
LoginURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/auth"},
RedeemURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/token"},
ProfileURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/profile"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "login.gov", p.Data().ProviderName)
assert.Equal(t, "https://example.com/oauth/auth",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/oauth/token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://example.com/oauth/profile",
p.Data().ProfileURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}
func TestLoginGovProviderSessionData(t *testing.T) {
p, serverkey, err := newLoginGovProvider()
assert.NotEqual(t, nil, p)
assert.NoError(t, err)
// Set up the redeem endpoint here
type loginGovRedeemResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
expiresIn := int64(60)
type MyCustomClaims struct {
Acr string `json:"acr"`
Nonce string `json:"nonce"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Birthdate string `json:"birthdate"`
AtHash string `json:"at_hash"`
CHash string `json:"c_hash"`
jwt.StandardClaims
}
claims := MyCustomClaims{
"http://idmanagement.gov/ns/assurance/loa/1",
"fakenonce",
"timothy.spencer@gsa.gov",
true,
"",
"",
"",
"",
"",
jwt.StandardClaims{
Audience: "Audience",
ExpiresAt: time.Now().Unix() + expiresIn,
Id: "foo",
IssuedAt: time.Now().Unix(),
Issuer: "https://idp.int.login.gov",
NotBefore: time.Now().Unix() - 1,
Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca",
},
}
idtoken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedidtoken, err := idtoken.SignedString(serverkey.PrivKey)
assert.NoError(t, err)
body, err := json.Marshal(loginGovRedeemResponse{
AccessToken: "a1234",
TokenType: "Bearer",
ExpiresIn: expiresIn,
IDToken: signedidtoken,
})
assert.NoError(t, err)
var server *httptest.Server
p.RedeemURL, server = newLoginGovServer(body)
defer server.Close()
// Set up the user endpoint here
type loginGovUserResponse struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Subject string `json:"sub"`
}
userbody, err := json.Marshal(loginGovUserResponse{
Email: "timothy.spencer@gsa.gov",
EmailVerified: true,
Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca",
})
assert.NoError(t, err)
var userserver *httptest.Server
p.ProfileURL, userserver = newLoginGovServer(userbody)
defer userserver.Close()
// Set up the PubJWKURL endpoint here used to verify the JWT
var pubkeys jose.JSONWebKeySet
pubkeys.Keys = append(pubkeys.Keys, serverkey.PubJWK)
pubjwkbody, err := json.Marshal(pubkeys)
assert.NoError(t, err)
var pubjwkserver *httptest.Server
p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody)
defer pubjwkserver.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.NoError(t, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "timothy.spencer@gsa.gov", session.Email)
assert.Equal(t, "a1234", session.AccessToken)
// The test ought to run in under 2 seconds. If not, you may need to bump this up.
assert.InDelta(t, session.ExpiresOn.Unix(), time.Now().Unix()+expiresIn, 2)
}
func TestLoginGovProviderBadNonce(t *testing.T) {
p, serverkey, err := newLoginGovProvider()
assert.NotEqual(t, nil, p)
assert.NoError(t, err)
// Set up the redeem endpoint here
type loginGovRedeemResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
expiresIn := int64(60)
type MyCustomClaims struct {
Acr string `json:"acr"`
Nonce string `json:"nonce"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Birthdate string `json:"birthdate"`
AtHash string `json:"at_hash"`
CHash string `json:"c_hash"`
jwt.StandardClaims
}
claims := MyCustomClaims{
"http://idmanagement.gov/ns/assurance/loa/1",
"badfakenonce",
"timothy.spencer@gsa.gov",
true,
"",
"",
"",
"",
"",
jwt.StandardClaims{
Audience: "Audience",
ExpiresAt: time.Now().Unix() + expiresIn,
Id: "foo",
IssuedAt: time.Now().Unix(),
Issuer: "https://idp.int.login.gov",
NotBefore: time.Now().Unix() - 1,
Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca",
},
}
idtoken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedidtoken, err := idtoken.SignedString(serverkey.PrivKey)
assert.NoError(t, err)
body, err := json.Marshal(loginGovRedeemResponse{
AccessToken: "a1234",
TokenType: "Bearer",
ExpiresIn: expiresIn,
IDToken: signedidtoken,
})
assert.NoError(t, err)
var server *httptest.Server
p.RedeemURL, server = newLoginGovServer(body)
defer server.Close()
// Set up the user endpoint here
type loginGovUserResponse struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Subject string `json:"sub"`
}
userbody, err := json.Marshal(loginGovUserResponse{
Email: "timothy.spencer@gsa.gov",
EmailVerified: true,
Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca",
})
assert.NoError(t, err)
var userserver *httptest.Server
p.ProfileURL, userserver = newLoginGovServer(userbody)
defer userserver.Close()
// Set up the PubJWKURL endpoint here used to verify the JWT
var pubkeys jose.JSONWebKeySet
pubkeys.Keys = append(pubkeys.Keys, serverkey.PubJWK)
pubjwkbody, err := json.Marshal(pubkeys)
assert.NoError(t, err)
var pubjwkserver *httptest.Server
p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody)
defer pubjwkserver.Close()
_, err = p.Redeem("http://redirect/", "code1234")
// The "badfakenonce" in the idtoken above should cause this to error out
assert.Error(t, err)
}

View File

@ -33,6 +33,8 @@ func New(provider string, p *ProviderData) Provider {
return NewGitLabProvider(p) return NewGitLabProvider(p)
case "oidc": case "oidc":
return NewOIDCProvider(p) return NewOIDCProvider(p)
case "login.gov":
return NewLoginGovProvider(p)
default: default:
return NewGoogleProvider(p) return NewGoogleProvider(p)
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"encoding/json"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -11,12 +12,18 @@ import (
// SessionState is used to store information about the currently authenticated user session // SessionState is used to store information about the currently authenticated user session
type SessionState struct { type SessionState struct {
AccessToken string AccessToken string `json:",omitempty"`
IDToken string IDToken string `json:",omitempty"`
ExpiresOn time.Time ExpiresOn time.Time `json:"-"`
RefreshToken string RefreshToken string `json:",omitempty"`
Email string Email string `json:",omitempty"`
User string User string `json:",omitempty"`
}
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
type SessionStateJSON struct {
*SessionState
ExpiresOn *time.Time `json:",omitempty"`
} }
// IsExpired checks whether the session has expired // IsExpired checks whether the session has expired
@ -29,7 +36,7 @@ func (s *SessionState) IsExpired() bool {
// String constructs a summary of the session state // String constructs a summary of the session state
func (s *SessionState) String() string { func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.accountInfo()) o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User)
if s.AccessToken != "" { if s.AccessToken != "" {
o += " token:true" o += " token:true"
} }
@ -47,95 +54,145 @@ func (s *SessionState) String() string {
// EncodeSessionState returns string representation of the current session // EncodeSessionState returns string representation of the current session
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" { var ss SessionState
return s.accountInfo(), nil
}
return s.EncryptedString(c)
}
func (s *SessionState) accountInfo() string {
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
}
// EncryptedString encrypts the session state into a cookie string
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
var err error
if c == nil { if c == nil {
panic("error. missing cipher") // Store only Email and User when cipher is unavailable
} ss.Email = s.Email
a := s.AccessToken ss.User = s.User
if a != "" { } else {
if a, err = c.Encrypt(a); err != nil { ss = *s
return "", err var err error
if ss.AccessToken != "" {
ss.AccessToken, err = c.Encrypt(ss.AccessToken)
if err != nil {
return "", err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Encrypt(ss.IDToken)
if err != nil {
return "", err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Encrypt(ss.RefreshToken)
if err != nil {
return "", err
}
} }
} }
i := s.IDToken // Embed SessionState and ExpiresOn pointer into SessionStateJSON
if i != "" { ssj := &SessionStateJSON{SessionState: &ss}
if i, err = c.Encrypt(i); err != nil { if !ss.ExpiresOn.IsZero() {
return "", err ssj.ExpiresOn = &ss.ExpiresOn
}
} }
r := s.RefreshToken b, err := json.Marshal(ssj)
if r != "" { return string(b), err
if r, err = c.Encrypt(r); err != nil {
return "", err
}
}
return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil
} }
func decodeSessionStatePlain(v string) (s *SessionState, err error) { // legacyDecodeSessionStatePlain decodes older plain session state string
func legacyDecodeSessionStatePlain(v string) (*SessionState, error) {
chunks := strings.Split(v, " ") chunks := strings.Split(v, " ")
if len(chunks) != 2 { if len(chunks) != 2 {
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks))
} }
email := strings.TrimPrefix(chunks[0], "email:")
user := strings.TrimPrefix(chunks[1], "user:") user := strings.TrimPrefix(chunks[1], "user:")
if user == "" { email := strings.TrimPrefix(chunks[0], "email:")
user = strings.Split(email, "@")[0]
}
return &SessionState{User: user, Email: email}, nil return &SessionState{User: user, Email: email}, nil
} }
// DecodeSessionState decodes the session cookie string into a SessionState // legacyDecodeSessionState attempts to decode the session state string
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { // generated by v3.1.0 or older
if c == nil { func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
return decodeSessionStatePlain(v)
}
chunks := strings.Split(v, "|") chunks := strings.Split(v, "|")
if len(chunks) != 5 {
err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) if c == nil {
return if len(chunks) != 1 {
return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks))
}
return legacyDecodeSessionStatePlain(chunks[0])
} }
sessionState, err := decodeSessionStatePlain(chunks[0]) if len(chunks) != 4 && len(chunks) != 5 {
return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks))
}
i := 0
ss, err := legacyDecodeSessionStatePlain(chunks[i])
if err != nil { if err != nil {
return nil, err return nil, err
} }
if chunks[1] != "" { i++
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { ss.AccessToken = chunks[i]
return nil, err
} if len(chunks) == 5 {
// SessionState with IDToken in v3.1.0
i++
ss.IDToken = chunks[i]
} }
if chunks[2] != "" { i++
if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { ts, err := strconv.Atoi(chunks[i])
return nil, err if err != nil {
} return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
} }
ss.ExpiresOn = time.Unix(int64(ts), 0)
ts, _ := strconv.Atoi(chunks[3]) i++
sessionState.ExpiresOn = time.Unix(int64(ts), 0) ss.RefreshToken = chunks[i]
if chunks[4] != "" { return ss, nil
if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { }
return nil, err
} // DecodeSessionState decodes the session cookie string into a SessionState
} func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) {
var ssj SessionStateJSON
return sessionState, nil var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
if err == nil && ssj.SessionState != nil {
// Extract SessionState and ExpiresOn value from SessionStateJSON
ss = ssj.SessionState
if ssj.ExpiresOn != nil {
ss.ExpiresOn = *ssj.ExpiresOn
}
} else {
// Try to decode a legacy string when json.Unmarshal failed
ss, err = legacyDecodeSessionState(v, c)
if err != nil {
return nil, err
}
}
if c == nil {
// Load only Email and User when cipher is unavailable
ss = &SessionState{
Email: ss.Email,
User: ss.User,
}
} else {
if ss.AccessToken != "" {
ss.AccessToken, err = c.Decrypt(ss.AccessToken)
if err != nil {
return nil, err
}
}
if ss.IDToken != "" {
ss.IDToken, err = c.Decrypt(ss.IDToken)
if err != nil {
return nil, err
}
}
if ss.RefreshToken != "" {
ss.RefreshToken, err = c.Decrypt(ss.RefreshToken)
if err != nil {
return nil, err
}
}
}
if ss.User == "" {
ss.User = strings.Split(ss.Email, "@")[0]
}
return ss, nil
} }

View File

@ -2,7 +2,6 @@ package providers
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
"time" "time"
@ -27,7 +26,6 @@ func TestSessionStateSerialization(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
@ -65,7 +63,6 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
@ -96,8 +93,6 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(nil) encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
expected := fmt.Sprintf("email:%s user:", s.Email)
assert.Equal(t, expected, encoded)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := DecodeSessionState(encoded, nil)
@ -118,8 +113,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(nil) encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
assert.Equal(t, expected, encoded)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := DecodeSessionState(encoded, nil)
@ -130,19 +123,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
assert.Equal(t, "", ss.RefreshToken) assert.Equal(t, "", ss.RefreshToken)
} }
func TestSessionStateAccountInfo(t *testing.T) {
s := &SessionState{
Email: "user@domain.com",
User: "just-user",
}
expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
s.Email = ""
expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
}
func TestExpired(t *testing.T) { func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired()) assert.Equal(t, true, s.IsExpired())
@ -153,3 +133,185 @@ func TestExpired(t *testing.T) {
s = &SessionState{} s = &SessionState{}
assert.Equal(t, false, s.IsExpired()) assert.Equal(t, false, s.IsExpired())
} }
type testCase struct {
SessionState
Encoded string
Cipher *cookie.Cipher
Error bool
}
// TestEncodeSessionState tests EncodeSessionState with the test vector
//
// Currently only tests without cipher here because we have no way to mock
// the random generator used in EncodeSessionState.
func TestEncodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
testCases := []testCase{
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
}
for i, tc := range testCases {
encoded, err := tc.EncodeSessionState(tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
if tc.Error {
assert.Error(t, err)
assert.Empty(t, encoded)
continue
}
assert.NoError(t, err)
assert.JSONEq(t, tc.Encoded, encoded)
}
}
// TestDecodeSessionState tests DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON()
eString := string(eJSON)
eUnix := e.Unix()
c, err := cookie.NewCipher([]byte(secret))
assert.NoError(t, err)
testCases := []testCase{
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "user",
},
Encoded: `{"Email":"user@domain.com"}`,
},
{
SessionState: SessionState{
User: "just-user",
},
Encoded: `{"User":"just-user"}`,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
Cipher: c,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
Cipher: c,
},
{
Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`,
Cipher: c,
Error: true,
},
{
Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`,
Cipher: c,
Error: true,
},
{
SessionState: SessionState{
User: "just-user",
Email: "user@domain.com",
},
Encoded: "email:user@domain.com user:just-user",
},
{
Encoded: "email:user@domain.com user:just-user||||",
Error: true,
},
{
Encoded: "email:user@domain.com user:just-user",
Cipher: c,
Error: true,
},
{
Encoded: "email:user@domain.com user:just-user|||99999999999999999999|",
Cipher: c,
Error: true,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
Cipher: c,
},
{
SessionState: SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix),
Cipher: c,
},
}
for i, tc := range testCases {
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
if tc.Error {
assert.Error(t, err)
assert.Nil(t, ss)
continue
}
assert.NoError(t, err)
if assert.NotNil(t, ss) {
assert.Equal(t, tc.User, ss.User)
assert.Equal(t, tc.Email, ss.Email)
assert.Equal(t, tc.AccessToken, ss.AccessToken)
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
assert.Equal(t, tc.IDToken, ss.IDToken)
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
}
}
}

View File

@ -7,6 +7,11 @@ import (
// StringArray is a type alias for a slice of strings // StringArray is a type alias for a slice of strings
type StringArray []string type StringArray []string
// Get returns the slice of strings
func (a *StringArray) Get() interface{} {
return []string(*a)
}
// Set appends a string to the StringArray // Set appends a string to the StringArray
func (a *StringArray) Set(s string) error { func (a *StringArray) Set(s string) error {
*a = append(*a, s) *a = append(*a, s)