diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 34bf030..35ed59a 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -229,8 +229,9 @@ func TestIsValidRedirect(t *testing.T) { type TestProvider struct { *providers.ProviderData - EmailAddress string - ValidToken bool + EmailAddress string + ValidToken bool + GroupValidator func(string) bool } func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { @@ -255,6 +256,9 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { Scope: "profile.email", }, EmailAddress: emailAddress, + GroupValidator: func(s string) bool { + return true + }, } } @@ -266,6 +270,13 @@ func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) boo return tp.ValidToken } +func (tp *TestProvider) ValidateGroup(email string) bool { + if tp.GroupValidator != nil { + return tp.GroupValidator(email) + } + return true +} + func TestBasicAuthPassword(t *testing.T) { providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { logger.Printf("%#v", r) @@ -791,6 +802,25 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { assert.Equal(t, "unauthorized request\n", string(bodyBytes)) } +func TestAuthOnlyEndpointUnauthorizedOnProviderGroupValidationFailure(t *testing.T) { + test := NewAuthOnlyEndpointTest() + startSession := &sessions.SessionState{ + Email: "michael.bland@gsa.gov", AccessToken: "my_access_token", CreatedAt: time.Now()} + test.SaveSession(startSession) + provider := &TestProvider{ + ValidToken: true, + GroupValidator: func(s string) bool { + return false + }, + } + + test.proxy.provider = provider + test.proxy.ServeHTTP(test.rw, test.req) + assert.Equal(t, http.StatusUnauthorized, test.rw.Code) + bodyBytes, _ := ioutil.ReadAll(test.rw.Body) + assert.Equal(t, "unauthorized request\n", string(bodyBytes)) +} + func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { var pcTest ProcessCookieTest @@ -1168,69 +1198,80 @@ func TestGetJwtSession(t *testing.T) { keyset := NoOpKeySet{} verifier := oidc.NewVerifier("https://issuer.example.com", keyset, &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) - p := OAuthProxy{} - p.jwtBearerVerifiers = append(p.jwtBearerVerifiers, verifier) - req, _ := http.NewRequest("GET", "/", strings.NewReader("")) + test := NewAuthOnlyEndpointTest(func(opts *Options) { + opts.PassAuthorization = true + opts.SetAuthorization = true + opts.SetXAuthRequest = true + opts.SkipJwtBearerTokens = true + opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier) + }) + tp, _ := test.proxy.provider.(*TestProvider) + tp.GroupValidator = func(s string) bool { + return true + } + authHeader := fmt.Sprintf("Bearer %s", goodJwt) - req.Header = map[string][]string{ + test.req.Header = map[string][]string{ "Authorization": {authHeader}, } // Bearer - session, _ := p.GetJwtSession(req) + session, _ := test.proxy.GetJwtSession(test.req) assert.Equal(t, session.User, "john@example.com") assert.Equal(t, session.Email, "john@example.com") assert.Equal(t, session.ExpiresOn, time.Unix(1912151821, 0)) assert.Equal(t, session.IDToken, goodJwt) - jwtProviderServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Printf("%#v", r) - var payload string - payload = r.Header.Get("Authorization") - if payload == "" { - payload = "No Authorization header found." - } - w.WriteHeader(200) - w.Write([]byte(payload)) - })) - - opts := NewOptions() - opts.Upstreams = append(opts.Upstreams, jwtProviderServer.URL) - opts.PassAuthorization = true - opts.SetAuthorization = true - opts.SetXAuthRequest = true - opts.CookieSecret = "0123456789abcdef0123" - opts.SkipJwtBearerTokens = true - opts.Validate() - - // We can't actually use opts.Validate() because it will attempt to find a jwks URI - opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier) - - providerURL, _ := url.Parse(jwtProviderServer.URL) - const emailAddress = "john@example.com" - - opts.provider = NewTestProvider(providerURL, emailAddress) - jwtTestProxy := NewOAuthProxy(opts, func(email string) bool { - return email == emailAddress - }) - - rw := httptest.NewRecorder() - jwtTestProxy.ServeHTTP(rw, req) - if rw.Code >= 400 { - t.Fatalf("expected 3xx got %d", rw.Code) + test.proxy.ServeHTTP(test.rw, test.req) + if test.rw.Code >= 400 { + t.Fatalf("expected 3xx got %d", test.rw.Code) } // Check PassAuthorization, should overwrite Basic header - assert.Equal(t, req.Header.Get("Authorization"), authHeader) - assert.Equal(t, req.Header.Get("X-Forwarded-User"), "john@example.com") - assert.Equal(t, req.Header.Get("X-Forwarded-Email"), "john@example.com") + assert.Equal(t, test.req.Header.Get("Authorization"), authHeader) + assert.Equal(t, test.req.Header.Get("X-Forwarded-User"), "john@example.com") + assert.Equal(t, test.req.Header.Get("X-Forwarded-Email"), "john@example.com") // SetAuthorization and SetXAuthRequest - assert.Equal(t, rw.Header().Get("Authorization"), authHeader) - assert.Equal(t, rw.Header().Get("X-Auth-Request-User"), "john@example.com") - assert.Equal(t, rw.Header().Get("X-Auth-Request-Email"), "john@example.com") + assert.Equal(t, test.rw.Header().Get("Authorization"), authHeader) + assert.Equal(t, test.rw.Header().Get("X-Auth-Request-User"), "john@example.com") + assert.Equal(t, test.rw.Header().Get("X-Auth-Request-Email"), "john@example.com") +} +func TestJwtUnauthorizedOnGroupValidationFailure(t *testing.T) { + goodJwt := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + + "eyJzdWIiOiIxMjM0NTY3ODkwIiwiYXVkIjoiaHR0cHM6Ly90ZXN0Lm15YXBwLmNvbSIsIm5hbWUiOiJKb2huIERvZSIsImVtY" + + "WlsIjoiam9obkBleGFtcGxlLmNvbSIsImlzcyI6Imh0dHBzOi8vaXNzdWVyLmV4YW1wbGUuY29tIiwiaWF0IjoxNTUzNjkxMj" + + "E1LCJleHAiOjE5MTIxNTE4MjF9." + + "rLVyzOnEldUq_pNkfa-WiV8TVJYWyZCaM2Am_uo8FGg11zD7l-qmz3x1seTvqpH6Y0Ty00fmv6dJnGnC8WMnPXQiodRTfhBSe" + + "OKZMu0HkMD2sg52zlKkbfLTO6ic5VnbVgwjjrB8am_Ta6w7kyFUaB5C1BsIrrLMldkWEhynbb8" + + keyset := NoOpKeySet{} + verifier := oidc.NewVerifier("https://issuer.example.com", keyset, + &oidc.Config{ClientID: "https://test.myapp.com", SkipExpiryCheck: true}) + + test := NewAuthOnlyEndpointTest(func(opts *Options) { + opts.PassAuthorization = true + opts.SetAuthorization = true + opts.SetXAuthRequest = true + opts.SkipJwtBearerTokens = true + opts.jwtBearerVerifiers = append(opts.jwtBearerVerifiers, verifier) + }) + tp, _ := test.proxy.provider.(*TestProvider) + // Verify ValidateGroup fails JWT authorization + tp.GroupValidator = func(s string) bool { + return false + } + + authHeader := fmt.Sprintf("Bearer %s", goodJwt) + test.req.Header = map[string][]string{ + "Authorization": {authHeader}, + } + test.proxy.ServeHTTP(test.rw, test.req) + if test.rw.Code != http.StatusUnauthorized { + t.Fatalf("expected 401 got %d", test.rw.Code) + } } func TestFindJwtBearerToken(t *testing.T) {