diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 2651908..47d0e56 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,3 +1,12 @@ # Default owner should be a Pusher cloud-team member unless overridden by later # rules in this file * @pusher/cloud-team + +# login.gov provider +# Note: If @timothy-spencer terms out of his appointment, your best bet +# for finding somebody who can test the oauth2_proxy would be to ask somebody +# in the login.gov team (https://login.gov/developers/), the cloud.gov team +# (https://cloud.gov/docs/help/), or the 18F org (https://18f.gsa.gov/contact/ +# or the public devops channel at https://chat.18f.gov/). +providers/logingov.go @timothy-spencer +providers/logingov_test.go @timothy-spencer diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d011a5..f0d92d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,13 @@ ## Changes since v3.1.0 - [#96](https://github.com/bitly/oauth2_proxy/pull/96) Check if email is verified on GitHub (@caarlos0) +- [#63](https://github.com/pusher/oauth2_proxy/pull/63) Use encoding/json for SessionState serialization (@yaegashi) + - Use JSON to encode session state to be stored in browser cookies + - Implement legacy decode function to support existing cookies generated by older versions + - Add detailed table driven tests in session_state_test.go +- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added login.gov provider (@timothy-spencer) +- [#55](https://github.com/pusher/oauth2_proxy/pull/55) Added environment variables for all config options (@timothy-spencer) +- [#70](https://github.com/pusher/oauth2_proxy/pull/70) Fix handling of splitted cookies (@einfachchr) - [#92](https://github.com/pusher/oauth2_proxy/pull/92) Merge websocket proxy feature from openshift/oauth-proxy (@butzist) - [#57](https://github.com/pusher/oauth2_proxy/pull/57) Fall back to using OIDC Subject instead of Email (@aigarius) - [#85](https://github.com/pusher/oauth2_proxy/pull/85) Use non-root user in docker images (@kskewes) diff --git a/Gopkg.lock b/Gopkg.lock index 6d27a3b..51d8342 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -41,6 +41,14 @@ revision = "346938d642f2ec3594ed81d874461961cd0faa76" version = "v1.1.0" +[[projects]] + digest = "1:6098222470fe0172157ce9bbef5d2200df4edde17ee649c5d6e48330e4afa4c6" + name = "github.com/dgrijalva/jwt-go" + packages = ["."] + pruneopts = "" + revision = "06ea1031745cb8b3dab3f6a236daf2b0aa468b7e" + version = "v3.2.0" + [[projects]] branch = "master" digest = "1:3b760d3b93f994df8eb1d9ebfad17d3e9e37edcb7f7efaa15b427c0d7a64f4e4" @@ -201,6 +209,7 @@ "github.com/BurntSushi/toml", "github.com/bitly/go-simplejson", "github.com/coreos/go-oidc", + "github.com/dgrijalva/jwt-go", "github.com/mbland/hmacauth", "github.com/mreiferson/go-options", "github.com/stretchr/testify/assert", diff --git a/README.md b/README.md index 67ec599..1bb63be 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ Valid providers are : - [GitHub](#github-auth-provider) - [GitLab](#gitlab-auth-provider) - [LinkedIn](#linkedin-auth-provider) +- [login.gov](#login.gov-provider) The provider can be selected using the `provider` configuration value. @@ -166,6 +167,54 @@ OpenID Connect is a spec for OAUTH 2.0 + identity that is implemented by many ma -cookie-secure=false -email-domain example.com +### login.gov Provider + +login.gov is an OIDC provider for the US Government. +If you are a US Government agency, you can contact the login.gov team through the contact information +that you can find on https://login.gov/developers/ and work with them to understand how to get login.gov +accounts for integration/test and production access. + +A developer guide is available here: https://developers.login.gov/, though this proxy handles everything +but the data you need to create to register your application in the login.gov dashboard. + +As a demo, we will assume that you are running your application that you want to secure locally on +http://localhost:3000/, that you will be starting your proxy up on http://localhost:4180/, and that +you have an agency integration account for testing. + +First, register your application in the dashboard. The important bits are: + * Identity protocol: make this `Openid connect` + * Issuer: do what they say for OpenID Connect. We will refer to this string as `${LOGINGOV_ISSUER}`. + * Public key: This is a self-signed certificate in .pem format generated from a 2048 bit RSA private key. + A quick way to do this is `openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -days 3650 -nodes -subj '/C=US/ST=Washington/L=DC/O=GSA/OU=18F/CN=localhost'`, + The contents of the `key.pem` shall be referred to as `${OAUTH2_PROXY_JWT_KEY}`. + * Return to App URL: Make this be `http://localhost:4180/` + * Redirect URIs: Make this be `http://localhost:4180/oauth2/callback`. + * Attribute Bundle: Make sure that email is selected. + +Now start the proxy up with the following options: +``` +./oauth2_proxy -provider login.gov \ + -client-id=${LOGINGOV_ISSUER} \ + -redirect-url=http://localhost:4180/oauth2/callback \ + -oidc-issuer-url=https://idp.int.identitysandbox.gov/ \ + -cookie-secure=false \ + -email-domain=gsa.gov \ + -upstream=http://localhost:3000/ \ + -cookie-secret=somerandomstring12341234567890AB \ + -cookie-domain=localhost \ + -skip-provider-button=true \ + -pubjwk-url=https://idp.int.identitysandbox.gov/api/openid_connect/certs \ + -profile-url=https://idp.int.identitysandbox.gov/api/openid_connect/userinfo \ + -jwt-key="${OAUTH2_PROXY_JWT_KEY}" +``` +You can also set all these options with environment variables, for use in cloud/docker environments. + +Once it is running, you should be able to go to `http://localhost:4180/` in your browser, +get authenticated by the login.gov integration server, and then get proxied on to your +application running on `http://localhost:3000/`. In a real deployment, you would secure +your application with a firewall or something so that it was only accessible from the +proxy, and you would use real hostnames everywhere. + #### Skip OIDC discovery Some providers do not support OIDC discovery via their issuer URL, so oauth2_proxy cannot simply grab the authorization, token and jwks URI endpoints from the provider's metadata. diff --git a/logging_handler_test.go b/logging_handler_test.go index 9717cd6..de0efcc 100644 --- a/logging_handler_test.go +++ b/logging_handler_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" "time" ) @@ -16,7 +17,7 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { Format, ExpectedLogMessage string }{ - {defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0.000\n", ts.Format("02/Jan/2006:15:04:05 -0700"))}, + {defaultRequestLoggingFormat, fmt.Sprintf("127.0.0.1 - - [%s] test-server GET - \"/foo/bar\" HTTP/1.1 \"\" 200 4 0", ts.Format("02/Jan/2006:15:04:05 -0700"))}, {"{{.RequestMethod}}", "GET\n"}, } @@ -35,8 +36,8 @@ func TestLoggingHandler_ServeHTTP(t *testing.T) { h.ServeHTTP(httptest.NewRecorder(), r) actual := buf.String() - if actual != test.ExpectedLogMessage { - t.Errorf("Log message was\n%s\ninstead of expected \n%s", actual, test.ExpectedLogMessage) + if !strings.Contains(actual, test.ExpectedLogMessage) { + t.Errorf("Log message was\n%s\ninstead of matching \n%s", actual, test.ExpectedLogMessage) } } } diff --git a/main.go b/main.go index 1992fc0..5625259 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "log" + "math/rand" "os" "runtime" "strings" @@ -88,6 +89,9 @@ func main() { flagSet.String("approval-prompt", "force", "OAuth approval_prompt") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") + flagSet.String("acr-values", "http://idmanagement.gov/ns/assurance/loa/1", "acr values string: optional, used by login.gov") + flagSet.String("jwt-key", "", "private key used to sign JWT: required by login.gov") + flagSet.String("pubjwk-url", "", "JWK pubkey access endpoint: required by login.gov") flagSet.Parse(os.Args[1:]) @@ -133,6 +137,8 @@ func main() { } } + rand.Seed(time.Now().UnixNano()) + s := &Server{ Handler: LoggingHandler(os.Stdout, oauthproxy, opts.RequestLogging, opts.RequestLoggingFormat), Opts: opts, diff --git a/oauthproxy.go b/oauthproxy.go index 7b6ba88..561dce3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -204,7 +204,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { } redirectURL := opts.redirectURL - if redirectURL.String() == "" { + if redirectURL.Path == "" { redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix) } @@ -241,7 +241,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix), SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix), OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix), - OAuthCallbackPath: redirectURL.Path, + OAuthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix), AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), ProxyPrefix: opts.ProxyPrefix, @@ -452,9 +452,18 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va // ClearSessionCookie creates a cookie to unset the user's authentication cookie // stored in the user's session func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { - cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) - for _, clr := range cookies { - http.SetCookie(rw, clr) + var cookies []*http.Cookie + + // matches CookieName, CookieName_ + var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", p.CookieName)) + + for _, c := range req.Cookies() { + if cookieNameRegex.MatchString(c.Name) { + clearCookie := p.makeCookie(req, c.Name, "", time.Hour*-1, time.Now()) + + http.SetCookie(rw, clearCookie) + cookies = append(cookies, clearCookie) + } } // ugly hack because default domain changed diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 757180a..45bacd7 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1064,3 +1064,47 @@ func TestAjaxForbiddendRequest(t *testing.T) { mime := rh.Get("Content-Type") assert.NotEqual(t, applicationJSON, mime) } + +func TestClearSplitCookie(t *testing.T) { + p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + var rw = httptest.NewRecorder() + req := httptest.NewRequest("get", "/", nil) + + req.AddCookie(&http.Cookie{ + Name: "test1", + Value: "test1", + }) + req.AddCookie(&http.Cookie{ + Name: "oauth2_0", + Value: "oauth2_0", + }) + req.AddCookie(&http.Cookie{ + Name: "oauth2_1", + Value: "oauth2_1", + }) + + p.ClearSessionCookie(rw, req) + header := rw.Header() + + assert.Equal(t, 2, len(header["Set-Cookie"]), "should have 3 set-cookie header entries") +} + +func TestClearSingleCookie(t *testing.T) { + p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + var rw = httptest.NewRecorder() + req := httptest.NewRequest("get", "/", nil) + + req.AddCookie(&http.Cookie{ + Name: "test1", + Value: "test1", + }) + req.AddCookie(&http.Cookie{ + Name: "oauth2", + Value: "oauth2", + }) + + p.ClearSessionCookie(rw, req) + header := rw.Header() + + assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries") +} diff --git a/options.go b/options.go index efa8596..90af3d3 100644 --- a/options.go +++ b/options.go @@ -14,6 +14,7 @@ import ( "time" oidc "github.com/coreos/go-oidc" + "github.com/dgrijalva/jwt-go" "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/providers" ) @@ -21,71 +22,74 @@ import ( // Options holds Configuration Options that can be set by Command Line Flag, // or Config File type Options struct { - ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` - ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets"` - HTTPAddress string `flag:"http-address" cfg:"http_address"` - HTTPSAddress string `flag:"https-address" cfg:"https_address"` - RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` + ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix" env:"OAUTH2_PROXY_PROXY_PREFIX"` + ProxyWebSockets bool `flag:"proxy-websockets" cfg:"proxy_websockets" env:"OAUTH2_PROXY_PROXY_WEBSOCKETS"` + HTTPAddress string `flag:"http-address" cfg:"http_address" env:"OAUTH2_PROXY_HTTP_ADDRESS"` + HTTPSAddress string `flag:"https-address" cfg:"https_address" env:"OAUTH2_PROXY_HTTPS_ADDRESS"` + RedirectURL string `flag:"redirect-url" cfg:"redirect_url" env:"OAUTH2_PROXY_REDIRECT_URL"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` - TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file"` - TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` + TLSCertFile string `flag:"tls-cert" cfg:"tls_cert_file" env:"OAUTH2_PROXY_TLS_CERT_FILE"` + TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file" env:"OAUTH2_PROXY_TLS_KEY_FILE"` - AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` - AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` - EmailDomains []string `flag:"email-domain" cfg:"email_domains"` + AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file" env:"OAUTH2_PROXY_AUTHENTICATED_EMAILS_FILE"` + AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant" env:"OAUTH2_PROXY_AZURE_TENANT"` + EmailDomains []string `flag:"email-domain" cfg:"email_domains" env:"OAUTH2_PROXY_EMAIL_DOMAINS"` WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains" env:"OAUTH2_PROXY_WHITELIST_DOMAINS"` - GitHubOrg string `flag:"github-org" cfg:"github_org"` - GitHubTeam string `flag:"github-team" cfg:"github_team"` - GoogleGroups []string `flag:"google-group" cfg:"google_group"` - GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email"` - GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json"` - HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file"` - DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form"` - CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir"` - Footer string `flag:"footer" cfg:"footer"` + GitHubOrg string `flag:"github-org" cfg:"github_org" env:"OAUTH2_PROXY_GITHUB_ORG"` + GitHubTeam string `flag:"github-team" cfg:"github_team" env:"OAUTH2_PROXY_GITHUB_TEAM"` + GoogleGroups []string `flag:"google-group" cfg:"google_group" env:"OAUTH2_PROXY_GOOGLE_GROUPS"` + GoogleAdminEmail string `flag:"google-admin-email" cfg:"google_admin_email" env:"OAUTH2_PROXY_GOOGLE_ADMIN_EMAIL"` + GoogleServiceAccountJSON string `flag:"google-service-account-json" cfg:"google_service_account_json" env:"OAUTH2_PROXY_GOOGLE_SERVICE_ACCOUNT_JSON"` + HtpasswdFile string `flag:"htpasswd-file" cfg:"htpasswd_file" env:"OAUTH2_PROXY_HTPASSWD_FILE"` + DisplayHtpasswdForm bool `flag:"display-htpasswd-form" cfg:"display_htpasswd_form" env:"OAUTH2_PROXY_DISPLAY_HTPASSWD_FORM"` + CustomTemplatesDir string `flag:"custom-templates-dir" cfg:"custom_templates_dir" env:"OAUTH2_PROXY_CUSTOM_TEMPLATES_DIR"` + Footer string `flag:"footer" cfg:"footer" env:"OAUTH2_PROXY_FOOTER"` CookieName string `flag:"cookie-name" cfg:"cookie_name" env:"OAUTH2_PROXY_COOKIE_NAME"` CookieSecret string `flag:"cookie-secret" cfg:"cookie_secret" env:"OAUTH2_PROXY_COOKIE_SECRET"` CookieDomain string `flag:"cookie-domain" cfg:"cookie_domain" env:"OAUTH2_PROXY_COOKIE_DOMAIN"` CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` - CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` - CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` + CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure" env:"OAUTH2_PROXY_COOKIE_SECURE"` + CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly" env:"OAUTH2_PROXY_COOKIE_HTTPONLY"` - Upstreams []string `flag:"upstream" cfg:"upstreams"` - SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` - PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` - BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password"` - PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` - PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` - SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button"` - PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` - SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` - SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` - SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"` - PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"` - SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` - FlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval"` + Upstreams []string `flag:"upstream" cfg:"upstreams" env:"OAUTH2_PROXY_UPSTREAMS"` + SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex" env:"OAUTH2_PROXY_SKIP_AUTH_REGEX"` + PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth" env:"OAUTH2_PROXY_PASS_BASIC_AUTH"` + BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password" env:"OAUTH2_PROXY_BASIC_AUTH_PASSWORD"` + PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token" env:"OAUTH2_PROXY_PASS_ACCESS_TOKEN"` + PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header" env:"OAUTH2_PROXY_PASS_HOST_HEADER"` + SkipProviderButton bool `flag:"skip-provider-button" cfg:"skip_provider_button" env:"OAUTH2_PROXY_SKIP_PROVIDER_BUTTON"` + PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers" env:"OAUTH2_PROXY_PASS_USER_HEADERS"` + SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify" env:"OAUTH2_PROXY_SSL_INSECURE_SKIP_VERIFY"` + SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest" env:"OAUTH2_PROXY_SET_XAUTHREQUEST"` + SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header" env:"OAUTH2_PROXY_SET_AUTHORIZATION_HEADER"` + PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header" env:"OAUTH2_PROXY_PASS_AUTHORIZATION_HEADER"` + SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight" env:"OAUTH2_PROXY_SKIP_AUTH_PREFLIGHT"` + FlushInterval time.Duration `flag:"flush-interval" cfg:"flush_interval" env:"OAUTH2_PROXY_FLUSH_INTERVAL"` // These options allow for other providers besides Google, with // potential overrides. - Provider string `flag:"provider" cfg:"provider"` - OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url"` - SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery"` - OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url"` - LoginURL string `flag:"login-url" cfg:"login_url"` - RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` - ProfileURL string `flag:"profile-url" cfg:"profile_url"` - ProtectedResource string `flag:"resource" cfg:"resource"` - ValidateURL string `flag:"validate-url" cfg:"validate_url"` - Scope string `flag:"scope" cfg:"scope"` - ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` + Provider string `flag:"provider" cfg:"provider" env:"OAUTH2_PROXY_PROVIDER"` + OIDCIssuerURL string `flag:"oidc-issuer-url" cfg:"oidc_issuer_url" env:"OAUTH2_PROXY_OIDC_ISSUER_URL"` + SkipOIDCDiscovery bool `flag:"skip-oidc-discovery" cfg:"skip_oidc_discovery" env:"OAUTH2_SKIP_OIDC_DISCOVERY"` + OIDCJwksURL string `flag:"oidc-jwks-url" cfg:"oidc_jwks_url" env:"OAUTH2_OIDC_JWKS_URL"` + LoginURL string `flag:"login-url" cfg:"login_url" env:"OAUTH2_PROXY_LOGIN_URL"` + RedeemURL string `flag:"redeem-url" cfg:"redeem_url" env:"OAUTH2_PROXY_REDEEM_URL"` + ProfileURL string `flag:"profile-url" cfg:"profile_url" env:"OAUTH2_PROXY_PROFILE_URL"` + ProtectedResource string `flag:"resource" cfg:"resource" env:"OAUTH2_PROXY_RESOURCE"` + ValidateURL string `flag:"validate-url" cfg:"validate_url" env:"OAUTH2_PROXY_VALIDATE_URL"` + Scope string `flag:"scope" cfg:"scope" env:"OAUTH2_PROXY_SCOPE"` + ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt" env:"OAUTH2_PROXY_APPROVAL_PROMPT"` - RequestLogging bool `flag:"request-logging" cfg:"request_logging"` - RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format"` + RequestLogging bool `flag:"request-logging" cfg:"request_logging" env:"OAUTH2_PROXY_REQUEST_LOGGING"` + RequestLoggingFormat string `flag:"request-logging-format" cfg:"request_logging_format" env:"OAUTH2_PROXY_REQUEST_LOGGING_FORMAT"` SignatureKey string `flag:"signature-key" cfg:"signature_key" env:"OAUTH2_PROXY_SIGNATURE_KEY"` + AcrValues string `flag:"acr-values" cfg:"acr_values" env:"OAUTH2_PROXY_ACR_VALUES"` + JWTKey string `flag:"jwt-key" cfg:"jwt_key" env:"OAUTH2_PROXY_JWT_KEY"` + PubJWKURL string `flag:"pubjwk-url" cfg:"pubjwk_url" env:"OAUTH2_PROXY_PUBJWK_URL"` // internal values that are set after config validation redirectURL *url.URL @@ -157,7 +161,8 @@ func (o *Options) Validate() error { if o.ClientID == "" { msgs = append(msgs, "missing setting: client-id") } - if o.ClientSecret == "" { + // login.gov uses a signed JWT to authenticate, not a client-secret + if o.ClientSecret == "" && o.Provider != "login.gov" { msgs = append(msgs, "missing setting: client-secret") } if o.AuthenticatedEmailsFile == "" && len(o.EmailDomains) == 0 && o.HtpasswdFile == "" { @@ -318,6 +323,19 @@ func parseProviderInfo(o *Options, msgs []string) []string { } else { p.Verifier = o.oidcVerifier } + case *providers.LoginGovProvider: + p.AcrValues = o.AcrValues + p.PubJWKURL, msgs = parseURL(o.PubJWKURL, "pubjwk", msgs) + if o.JWTKey == "" { + msgs = append(msgs, "login.gov provider requires a private key for signing JWTs") + } else { + signKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(o.JWTKey)) + if err != nil { + msgs = append(msgs, "could not parse RSA Private Key PEM") + } else { + p.JWTKey = signKey + } + } } return msgs } diff --git a/providers/logingov.go b/providers/logingov.go new file mode 100644 index 0000000..09bd3be --- /dev/null +++ b/providers/logingov.go @@ -0,0 +1,275 @@ +package providers + +import ( + "bytes" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "net/url" + "time" + + "github.com/dgrijalva/jwt-go" + "gopkg.in/square/go-jose.v2" +) + +// LoginGovProvider represents an OIDC based Identity Provider +type LoginGovProvider struct { + *ProviderData + + // TODO (@timothy-spencer): Ideally, the nonce would be in the session state, but the session state + // is created only upon code redemption, not during the auth, when this must be supplied. + Nonce string + AcrValues string + JWTKey *rsa.PrivateKey + PubJWKURL *url.URL +} + +// For generating a nonce +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randSeq(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +// NewLoginGovProvider initiates a new LoginGovProvider +func NewLoginGovProvider(p *ProviderData) *LoginGovProvider { + p.ProviderName = "login.gov" + + if p.LoginURL == nil || p.LoginURL.String() == "" { + p.LoginURL = &url.URL{ + Scheme: "https", + Host: "secure.login.gov", + Path: "/openid_connect/authorize", + } + } + if p.RedeemURL == nil || p.RedeemURL.String() == "" { + p.RedeemURL = &url.URL{ + Scheme: "https", + Host: "secure.login.gov", + Path: "/api/openid_connect/token", + } + } + if p.ProfileURL == nil || p.ProfileURL.String() == "" { + p.ProfileURL = &url.URL{ + Scheme: "https", + Host: "secure.login.gov", + Path: "/api/openid_connect/userinfo", + } + } + if p.Scope == "" { + p.Scope = "email openid" + } + + return &LoginGovProvider{ + ProviderData: p, + Nonce: randSeq(32), + } +} + +type loginGovCustomClaims struct { + Acr string `json:"acr"` + Nonce string `json:"nonce"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + Birthdate string `json:"birthdate"` + AtHash string `json:"at_hash"` + CHash string `json:"c_hash"` + jwt.StandardClaims +} + +// checkNonce checks the nonce in the id_token +func checkNonce(idToken string, p *LoginGovProvider) (err error) { + token, err := jwt.ParseWithClaims(idToken, &loginGovCustomClaims{}, func(token *jwt.Token) (interface{}, error) { + resp, myerr := http.Get(p.PubJWKURL.String()) + if myerr != nil { + return nil, myerr + } + if resp.StatusCode != 200 { + myerr = fmt.Errorf("got %d from %q", resp.StatusCode, p.PubJWKURL.String()) + return nil, myerr + } + body, myerr := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if myerr != nil { + return nil, myerr + } + + var pubkeys jose.JSONWebKeySet + myerr = json.Unmarshal(body, &pubkeys) + if myerr != nil { + return nil, myerr + } + pubkey := pubkeys.Keys[0] + + return pubkey.Key, nil + }) + if err != nil { + return + } + + claims := token.Claims.(*loginGovCustomClaims) + if claims.Nonce != p.Nonce { + err = fmt.Errorf("nonce validation failed") + return + } + return +} + +func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email string, err error) { + // query the user info endpoint for user attributes + var req *http.Request + req, err = http.NewRequest("GET", userInfoEndpoint, nil) + if err != nil { + return + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, userInfoEndpoint, body) + return + } + + // parse the user attributes from the data we got and make sure that + // the email address has been validated. + var emailData struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + } + err = json.Unmarshal(body, &emailData) + if err != nil { + return + } + if emailData.Email == "" { + err = fmt.Errorf("missing email") + return + } + email = emailData.Email + if !emailData.EmailVerified { + err = fmt.Errorf("email %s not listed as verified", email) + return + } + return +} + +// Redeem exchanges the OAuth2 authentication token for an ID token +func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { + if code == "" { + err = errors.New("missing code") + return + } + + claims := &jwt.StandardClaims{ + Issuer: p.ClientID, + Subject: p.ClientID, + Audience: p.RedeemURL.String(), + ExpiresAt: int64(time.Now().Add(time.Duration(5 * time.Minute)).Unix()), + Id: randSeq(32), + } + token := jwt.NewWithClaims(jwt.GetSigningMethod("RS256"), claims) + ss, err := token.SignedString(p.JWTKey) + if err != nil { + return + } + + params := url.Values{} + params.Add("client_assertion", ss) + params.Add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + params.Add("code", code) + params.Add("grant_type", "authorization_code") + + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) + return + } + + // Get the token from the body that we got from the token endpoint. + var jsonResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + } + err = json.Unmarshal(body, &jsonResponse) + if err != nil { + return + } + + // check nonce here + err = checkNonce(jsonResponse.IDToken, p) + if err != nil { + return + } + + // Get the email address + var email string + email, err = emailFromUserInfo(jsonResponse.AccessToken, p.ProfileURL.String()) + if err != nil { + return + } + + // Store the data that we found in the session state + s = &SessionState{ + AccessToken: jsonResponse.AccessToken, + IDToken: jsonResponse.IDToken, + ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), + Email: email, + } + return +} + +// GetLoginURL overrides GetLoginURL to add login.gov parameters +func (p *LoginGovProvider) GetLoginURL(redirectURI, state string) string { + var a url.URL + a = *p.LoginURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", redirectURI) + params.Set("approval_prompt", p.ApprovalPrompt) + params.Add("scope", p.Scope) + params.Set("client_id", p.ClientID) + params.Set("response_type", "code") + params.Add("state", state) + params.Add("acr_values", p.AcrValues) + params.Add("nonce", p.Nonce) + a.RawQuery = params.Encode() + return a.String() +} diff --git a/providers/logingov_test.go b/providers/logingov_test.go new file mode 100644 index 0000000..29808d0 --- /dev/null +++ b/providers/logingov_test.go @@ -0,0 +1,290 @@ +package providers + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/assert" + "gopkg.in/square/go-jose.v2" +) + +type MyKeyData struct { + PubKey crypto.PublicKey + PrivKey *rsa.PrivateKey + PubJWK jose.JSONWebKey +} + +func newLoginGovServer(body []byte) (*url.URL, *httptest.Server) { + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write(body) + })) + u, _ := url.Parse(s.URL) + return u, s +} + +func newLoginGovProvider() (l *LoginGovProvider, serverKey *MyKeyData, err error) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + serverKey = &MyKeyData{ + PubKey: key.Public(), + PrivKey: key, + PubJWK: jose.JSONWebKey{ + Key: key.Public(), + KeyID: "testkey", + Algorithm: string(jose.RS256), + Use: "sig", + }, + } + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return + } + + l = NewLoginGovProvider( + &ProviderData{ + ProviderName: "", + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, + Scope: ""}) + l.JWTKey = privateKey + l.Nonce = "fakenonce" + return +} + +func TestLoginGovProviderDefaults(t *testing.T) { + p, _, err := newLoginGovProvider() + assert.NotEqual(t, nil, p) + assert.NoError(t, err) + assert.Equal(t, "login.gov", p.Data().ProviderName) + assert.Equal(t, "https://secure.login.gov/openid_connect/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://secure.login.gov/api/openid_connect/token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://secure.login.gov/api/openid_connect/userinfo", + p.Data().ProfileURL.String()) + assert.Equal(t, "email openid", p.Data().Scope) +} + +func TestLoginGovProviderOverrides(t *testing.T) { + p := NewLoginGovProvider( + &ProviderData{ + LoginURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/auth"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/token"}, + ProfileURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/profile"}, + Scope: "profile"}) + assert.NotEqual(t, nil, p) + assert.Equal(t, "login.gov", p.Data().ProviderName) + assert.Equal(t, "https://example.com/oauth/auth", + p.Data().LoginURL.String()) + assert.Equal(t, "https://example.com/oauth/token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://example.com/oauth/profile", + p.Data().ProfileURL.String()) + assert.Equal(t, "profile", p.Data().Scope) +} + +func TestLoginGovProviderSessionData(t *testing.T) { + p, serverkey, err := newLoginGovProvider() + assert.NotEqual(t, nil, p) + assert.NoError(t, err) + + // Set up the redeem endpoint here + type loginGovRedeemResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token"` + } + expiresIn := int64(60) + type MyCustomClaims struct { + Acr string `json:"acr"` + Nonce string `json:"nonce"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + Birthdate string `json:"birthdate"` + AtHash string `json:"at_hash"` + CHash string `json:"c_hash"` + jwt.StandardClaims + } + claims := MyCustomClaims{ + "http://idmanagement.gov/ns/assurance/loa/1", + "fakenonce", + "timothy.spencer@gsa.gov", + true, + "", + "", + "", + "", + "", + jwt.StandardClaims{ + Audience: "Audience", + ExpiresAt: time.Now().Unix() + expiresIn, + Id: "foo", + IssuedAt: time.Now().Unix(), + Issuer: "https://idp.int.login.gov", + NotBefore: time.Now().Unix() - 1, + Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", + }, + } + idtoken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedidtoken, err := idtoken.SignedString(serverkey.PrivKey) + assert.NoError(t, err) + body, err := json.Marshal(loginGovRedeemResponse{ + AccessToken: "a1234", + TokenType: "Bearer", + ExpiresIn: expiresIn, + IDToken: signedidtoken, + }) + assert.NoError(t, err) + var server *httptest.Server + p.RedeemURL, server = newLoginGovServer(body) + defer server.Close() + + // Set up the user endpoint here + type loginGovUserResponse struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Subject string `json:"sub"` + } + userbody, err := json.Marshal(loginGovUserResponse{ + Email: "timothy.spencer@gsa.gov", + EmailVerified: true, + Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", + }) + assert.NoError(t, err) + var userserver *httptest.Server + p.ProfileURL, userserver = newLoginGovServer(userbody) + defer userserver.Close() + + // Set up the PubJWKURL endpoint here used to verify the JWT + var pubkeys jose.JSONWebKeySet + pubkeys.Keys = append(pubkeys.Keys, serverkey.PubJWK) + pubjwkbody, err := json.Marshal(pubkeys) + assert.NoError(t, err) + var pubjwkserver *httptest.Server + p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) + defer pubjwkserver.Close() + + session, err := p.Redeem("http://redirect/", "code1234") + assert.NoError(t, err) + assert.NotEqual(t, session, nil) + assert.Equal(t, "timothy.spencer@gsa.gov", session.Email) + assert.Equal(t, "a1234", session.AccessToken) + + // The test ought to run in under 2 seconds. If not, you may need to bump this up. + assert.InDelta(t, session.ExpiresOn.Unix(), time.Now().Unix()+expiresIn, 2) +} + +func TestLoginGovProviderBadNonce(t *testing.T) { + p, serverkey, err := newLoginGovProvider() + assert.NotEqual(t, nil, p) + assert.NoError(t, err) + + // Set up the redeem endpoint here + type loginGovRedeemResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token"` + } + expiresIn := int64(60) + type MyCustomClaims struct { + Acr string `json:"acr"` + Nonce string `json:"nonce"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + Birthdate string `json:"birthdate"` + AtHash string `json:"at_hash"` + CHash string `json:"c_hash"` + jwt.StandardClaims + } + claims := MyCustomClaims{ + "http://idmanagement.gov/ns/assurance/loa/1", + "badfakenonce", + "timothy.spencer@gsa.gov", + true, + "", + "", + "", + "", + "", + jwt.StandardClaims{ + Audience: "Audience", + ExpiresAt: time.Now().Unix() + expiresIn, + Id: "foo", + IssuedAt: time.Now().Unix(), + Issuer: "https://idp.int.login.gov", + NotBefore: time.Now().Unix() - 1, + Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", + }, + } + idtoken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signedidtoken, err := idtoken.SignedString(serverkey.PrivKey) + assert.NoError(t, err) + body, err := json.Marshal(loginGovRedeemResponse{ + AccessToken: "a1234", + TokenType: "Bearer", + ExpiresIn: expiresIn, + IDToken: signedidtoken, + }) + assert.NoError(t, err) + var server *httptest.Server + p.RedeemURL, server = newLoginGovServer(body) + defer server.Close() + + // Set up the user endpoint here + type loginGovUserResponse struct { + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Subject string `json:"sub"` + } + userbody, err := json.Marshal(loginGovUserResponse{ + Email: "timothy.spencer@gsa.gov", + EmailVerified: true, + Subject: "b2d2d115-1d7e-4579-b9d6-f8e84f4f56ca", + }) + assert.NoError(t, err) + var userserver *httptest.Server + p.ProfileURL, userserver = newLoginGovServer(userbody) + defer userserver.Close() + + // Set up the PubJWKURL endpoint here used to verify the JWT + var pubkeys jose.JSONWebKeySet + pubkeys.Keys = append(pubkeys.Keys, serverkey.PubJWK) + pubjwkbody, err := json.Marshal(pubkeys) + assert.NoError(t, err) + var pubjwkserver *httptest.Server + p.PubJWKURL, pubjwkserver = newLoginGovServer(pubjwkbody) + defer pubjwkserver.Close() + + _, err = p.Redeem("http://redirect/", "code1234") + + // The "badfakenonce" in the idtoken above should cause this to error out + assert.Error(t, err) +} diff --git a/providers/providers.go b/providers/providers.go index 15e6421..4616153 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -33,6 +33,8 @@ func New(provider string, p *ProviderData) Provider { return NewGitLabProvider(p) case "oidc": return NewOIDCProvider(p) + case "login.gov": + return NewLoginGovProvider(p) default: return NewGoogleProvider(p) } diff --git a/providers/session_state.go b/providers/session_state.go index 2862cdd..4741b4a 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -1,6 +1,7 @@ package providers import ( + "encoding/json" "fmt" "strconv" "strings" @@ -11,12 +12,18 @@ import ( // SessionState is used to store information about the currently authenticated user session type SessionState struct { - AccessToken string - IDToken string - ExpiresOn time.Time - RefreshToken string - Email string - User string + AccessToken string `json:",omitempty"` + IDToken string `json:",omitempty"` + ExpiresOn time.Time `json:"-"` + RefreshToken string `json:",omitempty"` + Email string `json:",omitempty"` + User string `json:",omitempty"` +} + +// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value +type SessionStateJSON struct { + *SessionState + ExpiresOn *time.Time `json:",omitempty"` } // IsExpired checks whether the session has expired @@ -29,7 +36,7 @@ func (s *SessionState) IsExpired() bool { // String constructs a summary of the session state func (s *SessionState) String() string { - o := fmt.Sprintf("Session{%s", s.accountInfo()) + o := fmt.Sprintf("Session{email:%s user:%s", s.Email, s.User) if s.AccessToken != "" { o += " token:true" } @@ -47,95 +54,145 @@ func (s *SessionState) String() string { // EncodeSessionState returns string representation of the current session func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { - if c == nil || s.AccessToken == "" { - return s.accountInfo(), nil - } - return s.EncryptedString(c) -} - -func (s *SessionState) accountInfo() string { - return fmt.Sprintf("email:%s user:%s", s.Email, s.User) -} - -// EncryptedString encrypts the session state into a cookie string -func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { - var err error + var ss SessionState if c == nil { - panic("error. missing cipher") - } - a := s.AccessToken - if a != "" { - if a, err = c.Encrypt(a); err != nil { - return "", err + // Store only Email and User when cipher is unavailable + ss.Email = s.Email + ss.User = s.User + } else { + ss = *s + var err error + if ss.AccessToken != "" { + ss.AccessToken, err = c.Encrypt(ss.AccessToken) + if err != nil { + return "", err + } + } + if ss.IDToken != "" { + ss.IDToken, err = c.Encrypt(ss.IDToken) + if err != nil { + return "", err + } + } + if ss.RefreshToken != "" { + ss.RefreshToken, err = c.Encrypt(ss.RefreshToken) + if err != nil { + return "", err + } } } - i := s.IDToken - if i != "" { - if i, err = c.Encrypt(i); err != nil { - return "", err - } + // Embed SessionState and ExpiresOn pointer into SessionStateJSON + ssj := &SessionStateJSON{SessionState: &ss} + if !ss.ExpiresOn.IsZero() { + ssj.ExpiresOn = &ss.ExpiresOn } - r := s.RefreshToken - if r != "" { - if r, err = c.Encrypt(r); err != nil { - return "", err - } - } - return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil + b, err := json.Marshal(ssj) + return string(b), err } -func decodeSessionStatePlain(v string) (s *SessionState, err error) { +// legacyDecodeSessionStatePlain decodes older plain session state string +func legacyDecodeSessionStatePlain(v string) (*SessionState, error) { chunks := strings.Split(v, " ") if len(chunks) != 2 { - return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) + return nil, fmt.Errorf("invalid session state (legacy: expected 2 chunks for user/email got %d)", len(chunks)) } - email := strings.TrimPrefix(chunks[0], "email:") user := strings.TrimPrefix(chunks[1], "user:") - if user == "" { - user = strings.Split(email, "@")[0] - } + email := strings.TrimPrefix(chunks[0], "email:") return &SessionState{User: user, Email: email}, nil } -// DecodeSessionState decodes the session cookie string into a SessionState -func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { - if c == nil { - return decodeSessionStatePlain(v) - } - +// legacyDecodeSessionState attempts to decode the session state string +// generated by v3.1.0 or older +func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { chunks := strings.Split(v, "|") - if len(chunks) != 5 { - err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) - return + + if c == nil { + if len(chunks) != 1 { + return nil, fmt.Errorf("invalid session state (legacy: expected 1 chunk for plain got %d)", len(chunks)) + } + return legacyDecodeSessionStatePlain(chunks[0]) } - sessionState, err := decodeSessionStatePlain(chunks[0]) + if len(chunks) != 4 && len(chunks) != 5 { + return nil, fmt.Errorf("invalid session state (legacy: expected 4 or 5 chunks for full got %d)", len(chunks)) + } + + i := 0 + ss, err := legacyDecodeSessionStatePlain(chunks[i]) if err != nil { return nil, err } - if chunks[1] != "" { - if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { - return nil, err - } + i++ + ss.AccessToken = chunks[i] + + if len(chunks) == 5 { + // SessionState with IDToken in v3.1.0 + i++ + ss.IDToken = chunks[i] } - if chunks[2] != "" { - if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil { - return nil, err - } + i++ + ts, err := strconv.Atoi(chunks[i]) + if err != nil { + return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err) } + ss.ExpiresOn = time.Unix(int64(ts), 0) - ts, _ := strconv.Atoi(chunks[3]) - sessionState.ExpiresOn = time.Unix(int64(ts), 0) + i++ + ss.RefreshToken = chunks[i] - if chunks[4] != "" { - if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { - return nil, err - } - } - - return sessionState, nil + return ss, nil +} + +// DecodeSessionState decodes the session cookie string into a SessionState +func DecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error) { + var ssj SessionStateJSON + var ss *SessionState + err := json.Unmarshal([]byte(v), &ssj) + if err == nil && ssj.SessionState != nil { + // Extract SessionState and ExpiresOn value from SessionStateJSON + ss = ssj.SessionState + if ssj.ExpiresOn != nil { + ss.ExpiresOn = *ssj.ExpiresOn + } + } else { + // Try to decode a legacy string when json.Unmarshal failed + ss, err = legacyDecodeSessionState(v, c) + if err != nil { + return nil, err + } + } + if c == nil { + // Load only Email and User when cipher is unavailable + ss = &SessionState{ + Email: ss.Email, + User: ss.User, + } + } else { + if ss.AccessToken != "" { + ss.AccessToken, err = c.Decrypt(ss.AccessToken) + if err != nil { + return nil, err + } + } + if ss.IDToken != "" { + ss.IDToken, err = c.Decrypt(ss.IDToken) + if err != nil { + return nil, err + } + } + if ss.RefreshToken != "" { + ss.RefreshToken, err = c.Decrypt(ss.RefreshToken) + if err != nil { + return nil, err + } + } + } + if ss.User == "" { + ss.User = strings.Split(ss.Email, "@")[0] + } + return ss, nil } diff --git a/providers/session_state_test.go b/providers/session_state_test.go index 504228f..9557eea 100644 --- a/providers/session_state_test.go +++ b/providers/session_state_test.go @@ -2,7 +2,6 @@ package providers import ( "fmt" - "strings" "testing" "time" @@ -27,7 +26,6 @@ func TestSessionStateSerialization(t *testing.T) { } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) @@ -65,7 +63,6 @@ func TestSessionStateSerializationWithUser(t *testing.T) { } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) @@ -96,8 +93,6 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) - expected := fmt.Sprintf("email:%s user:", s.Email) - assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) @@ -118,8 +113,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) - expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User) - assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) @@ -130,19 +123,6 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { assert.Equal(t, "", ss.RefreshToken) } -func TestSessionStateAccountInfo(t *testing.T) { - s := &SessionState{ - Email: "user@domain.com", - User: "just-user", - } - expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User) - assert.Equal(t, expected, s.accountInfo()) - - s.Email = "" - expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User) - assert.Equal(t, expected, s.accountInfo()) -} - func TestExpired(t *testing.T) { s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} assert.Equal(t, true, s.IsExpired()) @@ -153,3 +133,185 @@ func TestExpired(t *testing.T) { s = &SessionState{} assert.Equal(t, false, s.IsExpired()) } + +type testCase struct { + SessionState + Encoded string + Cipher *cookie.Cipher + Error bool +} + +// TestEncodeSessionState tests EncodeSessionState with the test vector +// +// Currently only tests without cipher here because we have no way to mock +// the random generator used in EncodeSessionState. +func TestEncodeSessionState(t *testing.T) { + e := time.Now().Add(time.Duration(1) * time.Hour) + + testCases := []testCase{ + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + }, + } + + for i, tc := range testCases { + encoded, err := tc.EncodeSessionState(tc.Cipher) + t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) + if tc.Error { + assert.Error(t, err) + assert.Empty(t, encoded) + continue + } + assert.NoError(t, err) + assert.JSONEq(t, tc.Encoded, encoded) + } +} + +// TestDecodeSessionState tests DecodeSessionState with the test vector +func TestDecodeSessionState(t *testing.T) { + e := time.Now().Add(time.Duration(1) * time.Hour) + eJSON, _ := e.MarshalJSON() + eString := string(eJSON) + eUnix := e.Unix() + + c, err := cookie.NewCipher([]byte(secret)) + assert.NoError(t, err) + + testCases := []testCase{ + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "user", + }, + Encoded: `{"Email":"user@domain.com"}`, + }, + { + SessionState: SessionState{ + User: "just-user", + }, + Encoded: `{"User":"just-user"}`, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), + Cipher: c, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + }, + Encoded: `{"Email":"user@domain.com","User":"just-user"}`, + Cipher: c, + }, + { + Encoded: `{"Email":"user@domain.com","User":"just-user","AccessToken":"X"}`, + Cipher: c, + Error: true, + }, + { + Encoded: `{"Email":"user@domain.com","User":"just-user","IDToken":"XXXX"}`, + Cipher: c, + Error: true, + }, + { + SessionState: SessionState{ + User: "just-user", + Email: "user@domain.com", + }, + Encoded: "email:user@domain.com user:just-user", + }, + { + Encoded: "email:user@domain.com user:just-user||||", + Error: true, + }, + { + Encoded: "email:user@domain.com user:just-user", + Cipher: c, + Error: true, + }, + { + Encoded: "email:user@domain.com user:just-user|||99999999999999999999|", + Cipher: c, + Error: true, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix), + Cipher: c, + }, + { + SessionState: SessionState{ + Email: "user@domain.com", + User: "just-user", + AccessToken: "token1234", + IDToken: "rawtoken1234", + ExpiresOn: e, + RefreshToken: "refresh4321", + }, + Encoded: fmt.Sprintf("email:user@domain.com user:just-user|I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==|xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==|%d|qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K", eUnix), + Cipher: c, + }, + } + + for i, tc := range testCases { + ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) + t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) + if tc.Error { + assert.Error(t, err) + assert.Nil(t, ss) + continue + } + assert.NoError(t, err) + if assert.NotNil(t, ss) { + assert.Equal(t, tc.User, ss.User) + assert.Equal(t, tc.Email, ss.Email) + assert.Equal(t, tc.AccessToken, ss.AccessToken) + assert.Equal(t, tc.RefreshToken, ss.RefreshToken) + assert.Equal(t, tc.IDToken, ss.IDToken) + assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + } + } +} diff --git a/string_array.go b/string_array.go index 5a624be..a6e1d96 100644 --- a/string_array.go +++ b/string_array.go @@ -7,6 +7,11 @@ import ( // StringArray is a type alias for a slice of strings type StringArray []string +// Get returns the slice of strings +func (a *StringArray) Get() interface{} { + return []string(*a) +} + // Set appends a string to the StringArray func (a *StringArray) Set(s string) error { *a = append(*a, s)