Merge pull request #14 from pusher/oidc

OIDC ID Token, Authorization Headers, Refreshing and Verification
This commit is contained in:
Joel Speed 2019-01-22 15:56:37 +00:00 committed by GitHub
commit 440d2f32bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 227 additions and 40 deletions

View File

@ -2,6 +2,11 @@
## Changes since v3.0.0
- [#14](https://github.com/pusher/oauth2_proxy/pull/14) OIDC ID Token, Authorization Headers, Refreshing and Verification (@joelspeed)
- Implement `pass-authorization-header` and `set-authorization-header` flags
- Implement token refreshing in OIDC provider
- Split cookies larger than 4k limit into multiple cookies
- Implement token validation in OIDC provider
- [#21](https://github.com/pusher/oauth2_proxy/pull/21) Docker Improvement (@yaegashi)
- Move Docker base image from debian to alpine
- Install ca-certificates in docker image

View File

@ -212,6 +212,7 @@ Usage of oauth2_proxy:
-https-address string: <addr>:<port> to listen on for HTTPS clients (default ":443")
-login-url string: Authentication endpoint
-pass-access-token: pass OAuth access_token to upstream via X-Forwarded-Access-Token header
-pass-authorization-header: pass OIDC IDToken to upstream via Authorization Bearer header
-pass-basic-auth: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream (default true)
-pass-host-header: pass the request Host Header to upstream (default true)
-pass-user-headers: pass X-Forwarded-User and X-Forwarded-Email information to upstream (default true)
@ -225,6 +226,7 @@ Usage of oauth2_proxy:
-resource string: The resource that is protected (Azure AD only)
-scope string: OAuth scope specification
-set-xauthrequest: set X-Auth-Request-User and X-Auth-Request-Email response headers (useful in Nginx auth_request mode)
-set-authorization-header: set Authorization Bearer response header (useful in Nginx auth_request mode)
-signature-key string: GAP-Signature request signature key (algorithm:secretkey)
-skip-auth-preflight: will skip authentication for OPTIONS requests
-skip-auth-regex value: bypass authentication for requests path's that match (may be given multiple times)

View File

@ -37,6 +37,8 @@ func main() {
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream")
flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)")
flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start")
flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests")

View File

@ -26,6 +26,10 @@ const (
httpScheme = "http"
httpsScheme = "https"
// Cookies are limited to 4kb including the length of the cookie name,
// the cookie name can be up to 256 bytes
maxCookieLength = 3840
)
// SignatureHeaders contains the headers to be signed by the hmac algorithm
@ -76,6 +80,8 @@ type OAuthProxy struct {
PassUserHeaders bool
BasicAuthPassword string
PassAccessToken bool
SetAuthorization bool
PassAuthorization bool
CookieCipher *cookie.Cipher
skipAuthRegex []string
skipAuthPreflight bool
@ -183,7 +189,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh)
var cipher *cookie.Cipher
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) {
var err error
cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
if err != nil {
@ -222,6 +228,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
PassUserHeaders: opts.PassUserHeaders,
BasicAuthPassword: opts.BasicAuthPassword,
PassAccessToken: opts.PassAccessToken,
SetAuthorization: opts.SetAuthorization,
PassAuthorization: opts.PassAuthorization,
SkipProviderButton: opts.SkipProviderButton,
CookieCipher: cipher,
templates: loadTemplates(opts.CustomTemplatesDir),
@ -278,15 +286,100 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
// MakeSessionCookie creates an http.Cookie containing the authenticated user's
// authentication details
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
if value != "" {
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
if len(value) > 4096 {
// Cookies cannot be larger than 4kb
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
}
c := p.makeCookie(req, p.CookieName, value, expiration, now)
if len(c.Value) > 4096-len(p.CookieName) {
return splitCookie(c)
}
return []*http.Cookie{c}
}
func copyCookie(c *http.Cookie) *http.Cookie {
return &http.Cookie{
Name: c.Name,
Value: c.Value,
Path: c.Path,
Domain: c.Domain,
Expires: c.Expires,
RawExpires: c.RawExpires,
MaxAge: c.MaxAge,
Secure: c.Secure,
HttpOnly: c.HttpOnly,
Raw: c.Raw,
Unparsed: c.Unparsed,
}
}
return p.makeCookie(req, p.CookieName, value, expiration, now)
// splitCookie reads the full cookie generated to store the session and splits
// it into a slice of cookies which fit within the 4kb cookie limit indexing
// the cookies from 0
func splitCookie(c *http.Cookie) []*http.Cookie {
if len(c.Value) < maxCookieLength {
return []*http.Cookie{c}
}
cookies := []*http.Cookie{}
valueBytes := []byte(c.Value)
count := 0
for len(valueBytes) > 0 {
new := copyCookie(c)
new.Name = fmt.Sprintf("%s-%d", c.Name, count)
count++
if len(valueBytes) < maxCookieLength {
new.Value = string(valueBytes)
valueBytes = []byte{}
} else {
newValue := valueBytes[:maxCookieLength]
valueBytes = valueBytes[maxCookieLength:]
new.Value = string(newValue)
}
cookies = append(cookies, new)
}
return cookies
}
// joinCookies takes a slice of cookies from the request and reconstructs the
// full session cookie
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
if len(cookies) == 0 {
return nil, fmt.Errorf("list of cookies must be > 0")
}
if len(cookies) == 1 {
return cookies[0], nil
}
c := copyCookie(cookies[0])
for i := 1; i < len(cookies); i++ {
c.Value += cookies[i].Value
}
c.Name = strings.TrimRight(c.Name, "-0")
return c, nil
}
// loadCookie retreieves the sessions state cookie from the http request.
// If a single cookie is present this will be returned, otherwise it attempts
// to reconstruct a cookie split up by splitCookie
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
c, err := req.Cookie(cookieName)
if err == nil {
return c, nil
}
cookies := []*http.Cookie{}
err = nil
count := 0
for err == nil {
var c *http.Cookie
c, err = req.Cookie(fmt.Sprintf("%s-%d", cookieName, count))
if err == nil {
cookies = append(cookies, c)
count++
}
}
if len(cookies) == 0 {
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
}
return joinCookies(cookies)
}
// MakeCSRFCookie creates a cookie for CSRF
@ -330,12 +423,14 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
// ClearSessionCookie creates a cookie to unset the user's authentication cookie
// stored in the user's session
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
for _, clr := range cookies {
http.SetCookie(rw, clr)
}
// ugly hack because default domain changed
if p.CookieDomain == "" {
clr2 := *clr
if p.CookieDomain == "" && len(cookies) > 0 {
clr2 := *cookies[0]
clr2.Domain = req.Host
http.SetCookie(rw, &clr2)
}
@ -343,13 +438,15 @@ func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Reques
// SetSessionCookie adds the user's session cookie to the response
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) {
http.SetCookie(rw, c)
}
}
// LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
var age time.Duration
c, err := req.Cookie(p.CookieName)
c, err := loadCookie(req, p.CookieName)
if err != nil {
// always http.ErrNoCookie
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
@ -750,6 +847,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
if p.PassAccessToken && session.AccessToken != "" {
req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
}
if p.PassAuthorization && session.IDToken != "" {
req.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", session.IDToken)}
}
if p.SetAuthorization && session.IDToken != "" {
rw.Header().Set("Authorization", fmt.Sprintf("Bearer %s", session.IDToken))
}
if session.Email == "" {
rw.Header().Set("GAP-Auth", session.User)
} else {

View File

@ -502,7 +502,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
})
}
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
}
@ -511,7 +511,9 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
if err != nil {
return err
}
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) {
p.req.AddCookie(c)
}
return nil
}
@ -800,8 +802,9 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
if err != nil {
panic(err)
}
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
req.AddCookie(cookie)
for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) {
req.AddCookie(c)
}
// This is used by the upstream to validate the signature.
st.authenticator.auth = hmacauth.NewHmacAuth(
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)

View File

@ -61,6 +61,8 @@ type Options struct {
PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"`
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"`
SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"`
SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"`
PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"`
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"`
// These options allow for other providers besides Google, with
@ -113,6 +115,8 @@ func NewOptions() *Options {
PassUserHeaders: true,
PassAccessToken: false,
PassHostHeader: true,
SetAuthorization: false,
PassAuthorization: false,
ApprovalPrompt: "force",
RequestLogging: true,
RequestLoggingFormat: defaultRequestLoggingFormat,

View File

@ -145,6 +145,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
}
s = &SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
RefreshToken: jsonResponse.RefreshToken,
Email: email,

View File

@ -38,7 +38,61 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
if err != nil {
return nil, fmt.Errorf("token exchange: %v", err)
}
s, err = p.createSessionState(ctx, token)
if err != nil {
return nil, fmt.Errorf("unable to update session: %v", err)
}
return
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
origExpiration := s.ExpiresOn
err := p.redeemRefreshToken(s)
if err != nil {
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
}
fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
return true, nil
}
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
c := oauth2.Config{
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
Endpoint: oauth2.Endpoint{
TokenURL: p.RedeemURL.String(),
},
}
ctx := context.Background()
t := &oauth2.Token{
RefreshToken: s.RefreshToken,
Expiry: time.Now().Add(-time.Hour),
}
token, err := c.TokenSource(ctx, t).Token()
if err != nil {
return fmt.Errorf("failed to get token: %v", err)
}
newSession, err := p.createSessionState(ctx, token)
if err != nil {
return fmt.Errorf("unable to update session: %v", err)
}
s.AccessToken = newSession.AccessToken
s.IDToken = newSession.IDToken
s.RefreshToken = newSession.RefreshToken
s.ExpiresOn = newSession.ExpiresOn
s.Email = newSession.Email
return
}
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("token response did not contain an id_token")
@ -66,28 +120,22 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
s = &SessionState{
return &SessionState{
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
ExpiresOn: token.Expiry,
Email: claims.Email,
}, nil
}
return
// ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool {
ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil {
return false
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
//
// WARNGING: This implementation is broken and does not check with the upstream
// OIDC provider before refreshing the session
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
origExpiration := s.ExpiresOn
s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second)
fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration)
return false, nil
return true
}

View File

@ -12,6 +12,7 @@ import (
// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
AccessToken string
IDToken string
ExpiresOn time.Time
RefreshToken string
Email string
@ -32,6 +33,9 @@ func (s *SessionState) String() string {
if s.AccessToken != "" {
o += " token:true"
}
if s.IDToken != "" {
o += " id_token:true"
}
if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
}
@ -65,13 +69,19 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
return "", err
}
}
i := s.IDToken
if i != "" {
if i, err = c.Encrypt(i); err != nil {
return "", err
}
}
r := s.RefreshToken
if r != "" {
if r, err = c.Encrypt(r); err != nil {
return "", err
}
}
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil
}
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
@ -96,8 +106,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
}
chunks := strings.Split(v, "|")
if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
if len(chunks) != 5 {
err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks))
return
}
@ -112,11 +122,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
}
}
ts, _ := strconv.Atoi(chunks[2])
if chunks[2] != "" {
if sessionState.IDToken, err = c.Decrypt(chunks[2]); err != nil {
return nil, err
}
}
ts, _ := strconv.Atoi(chunks[3])
sessionState.ExpiresOn = time.Unix(int64(ts), 0)
if chunks[3] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
if chunks[4] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil {
return nil, err
}
}

View File

@ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) {
s := &SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|"))
assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
@ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IDToken, ss.IDToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
@ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.IDToken, ss.IDToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
@ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|"))
assert.Equal(t, 4, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss)