diff --git a/oauthproxy.go b/oauthproxy.go index f7bbef3..714c7a3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -712,7 +712,7 @@ func (p *OAuthProxy) getAuthenticatedSession(rw http.ResponseWriter, req *http.R if p.skipJwtBearerTokens && req.Header.Get("Authorization") != "" { session, err = p.GetJwtSession(req) if err != nil { - logger.Printf("Error validating JWT token from Authorization header: %s", err) + logger.Printf("Error retrieving session from token in Authorization header: %s", err) } if session != nil { saveSession = false @@ -938,9 +938,9 @@ func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) { if len(s) != 2 { return "", fmt.Errorf("invalid authorization header %s", auth) } - + jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`) var rawBearerToken string - if s[0] == "Bearer" { + if s[0] == "Bearer" && jwtRegex.MatchString(s[1]) { rawBearerToken = s[1] } else if s[0] == "Basic" { // Check if we have a Bearer token masquerading in Basic @@ -955,7 +955,6 @@ func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) { user, password := pair[0], pair[1] // check user, user+password, or just password for a token - jwtRegex := regexp.MustCompile(`^eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]+$`) if jwtRegex.MatchString(user) { // Support blank passwords or magic `x-oauth-basic` passwords - nothing else if password == "" || password == "x-oauth-basic" { @@ -965,8 +964,9 @@ func (p *OAuthProxy) findBearerToken(req *http.Request) (string, error) { // support passwords and ignore user rawBearerToken = password } - } else { - return "", fmt.Errorf("invalid authorization header %s", auth) + } + if rawBearerToken == "" { + return "", fmt.Errorf("no valid bearer token found in authorization header") } return rawBearerToken, nil diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 1d09bbb..493ce72 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -3,6 +3,7 @@ package main import ( "crypto" "encoding/base64" + "fmt" "io" "io/ioutil" "net" @@ -1132,3 +1133,65 @@ func TestClearSingleCookie(t *testing.T) { assert.Equal(t, 1, len(header["Set-Cookie"]), "should have 1 set-cookie header entries") } + +func TestFindJwtBearerToken(t *testing.T) { + p := OAuthProxy{CookieName: "oauth2", CookieDomain: "abc"} + getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}} + + validToken := "eyJfoobar.eyJfoobar.12345asdf" + var token string + + // Bearer + getReq.Header = map[string][]string{ + "Authorization": {fmt.Sprintf("Bearer %s", validToken)}, + } + + token, _ = p.findBearerToken(getReq) + assert.Equal(t, validToken, token) + + // Basic - no password + getReq.SetBasicAuth(token, "") + token, _ = p.findBearerToken(getReq) + assert.Equal(t, validToken, token) + + // Basic - sentinel password + getReq.SetBasicAuth(token, "x-oauth-basic") + token, _ = p.findBearerToken(getReq) + assert.Equal(t, validToken, token) + + // Basic - any username, password matching jwt pattern + getReq.SetBasicAuth("any-username-you-could-wish-for", token) + token, _ = p.findBearerToken(getReq) + assert.Equal(t, validToken, token) + + failures := []string{ + // Too many parts + "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.dGVzdA.dGVzdA", + // Not enough parts + "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA", + // Invalid encrypted key + "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.//////.dGVzdA.dGVzdA.dGVzdA", + // Invalid IV + "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.//////.dGVzdA.dGVzdA", + // Invalid ciphertext + "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.//////.dGVzdA", + // Invalid tag + "eyJhbGciOiJSU0EtT0FFUCIsImVuYyI6IkExMjhHQ00ifQ.dGVzdA.dGVzdA.dGVzdA.//////", + // Invalid header + "W10.dGVzdA.dGVzdA.dGVzdA.dGVzdA", + // Invalid header + "######.dGVzdA.dGVzdA.dGVzdA.dGVzdA", + // Missing alc/enc params + "e30.dGVzdA.dGVzdA.dGVzdA.dGVzdA", + } + + for _, failure := range failures { + getReq.Header = map[string][]string{ + "Authorization": {fmt.Sprintf("Bearer %s", failure)}, + } + _, err := p.findBearerToken(getReq) + assert.Error(t, err) + } + + fmt.Printf("%s", token) +}