From 882fcf0a0108674002fdea2ba89c3d11598f8c2f Mon Sep 17 00:00:00 2001 From: Mark Maglana Date: Tue, 4 Jul 2017 14:36:21 -0700 Subject: [PATCH] providers: iterate across all pages from /user/orgs github endpoint. For some GHE instances where a user can have more than 100 organizations, traversing the other pages is important otherwise oauth2_proxy will consider the user unauthorized. This change traverses the list returned by the API to avoid that. Update github provider tests to include this case. --- providers/github.go | 70 ++++++++++++++++++++++++---------------- providers/github_test.go | 47 +++++++++++++++++++++------ 2 files changed, 80 insertions(+), 37 deletions(-) diff --git a/providers/github.go b/providers/github.go index f3af86f..26526ce 100644 --- a/providers/github.go +++ b/providers/github.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "path" + "strconv" "strings" ) @@ -61,36 +62,51 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { Login string `json:"login"` } - params := url.Values{ - "limit": {"100"}, + type orgsPage []struct { + Login string `json:"login"` } - endpoint := &url.URL{ - Scheme: p.ValidateURL.Scheme, - Host: p.ValidateURL.Host, - Path: path.Join(p.ValidateURL.Path, "/user/orgs"), - RawQuery: params.Encode(), - } - req, _ := http.NewRequest("GET", endpoint.String(), nil) - req.Header.Set("Accept", "application/vnd.github.v3+json") - req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return false, err - } + pn := 1 + for { + params := url.Values{ + "limit": {"200"}, + "page": {strconv.Itoa(pn)}, + } - body, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf( - "got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } + endpoint := &url.URL{ + Scheme: p.ValidateURL.Scheme, + Host: p.ValidateURL.Host, + Path: path.Join(p.ValidateURL.Path, "/user/orgs"), + RawQuery: params.Encode(), + } + req, _ := http.NewRequest("GET", endpoint.String(), nil) + req.Header.Set("Accept", "application/vnd.github.v3+json") + req.Header.Set("Authorization", fmt.Sprintf("token %s", accessToken)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return false, err + } - if err := json.Unmarshal(body, &orgs); err != nil { - return false, err + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return false, err + } + if resp.StatusCode != 200 { + return false, fmt.Errorf( + "got %d from %q %s", resp.StatusCode, endpoint.String(), body) + } + + var op orgsPage + if err := json.Unmarshal(body, &op); err != nil { + return false, err + } + if len(op) == 0 { + break + } + + orgs = append(orgs, op...) + pn += 1 } var presentOrgs []string @@ -118,7 +134,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { } params := url.Values{ - "limit": {"100"}, + "limit": {"200"}, } endpoint := &url.URL{ diff --git a/providers/github_test.go b/providers/github_test.go index 8080525..4810182 100644 --- a/providers/github_test.go +++ b/providers/github_test.go @@ -27,23 +27,32 @@ func testGitHubProvider(hostname string) *GitHubProvider { return p } -func testGitHubBackend(payload string) *httptest.Server { - pathToQueryMap := map[string]string{ - "/user": "", - "/user/emails": "", +func testGitHubBackend(payload []string) *httptest.Server { + pathToQueryMap := map[string][]string{ + "/user": []string{""}, + "/user/emails": []string{""}, + "/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, } return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { url := r.URL query, ok := pathToQueryMap[url.Path] + validQuery := false + index := 0 + for i, q := range query { + if q == url.RawQuery { + validQuery = true + index = i + } + } if !ok { w.WriteHeader(404) - } else if url.RawQuery != query { + } else if !validQuery { w.WriteHeader(404) } else { w.WriteHeader(200) - w.Write([]byte(payload)) + w.Write([]byte(payload[index])) } })) } @@ -89,7 +98,7 @@ func TestGitHubProviderOverrides(t *testing.T) { } func TestGitHubProviderGetEmailAddress(t *testing.T) { - b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`) + b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`}) defer b.Close() bURL, _ := url.Parse(b.URL) @@ -101,10 +110,28 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { assert.Equal(t, "michael.bland@gsa.gov", email) } +func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { + b := testGitHubBackend([]string{ + `[ {"email": "michael.bland@gsa.gov", "primary": true, "login":"testorg"} ]`, + `[ {"email": "michael.bland1@gsa.gov", "primary": true, "login":"testorg1"} ]`, + `[ ]`, + }) + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + p.Org = "testorg1" + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", email) +} + // Note that trying to trigger the "failed building request" case is not // practical, since the only way it can fail is if the URL fails to parse. func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { - b := testGitHubBackend("unused payload") + b := testGitHubBackend([]string{"unused payload"}) defer b.Close() bURL, _ := url.Parse(b.URL) @@ -120,7 +147,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { } func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { - b := testGitHubBackend("{\"foo\": \"bar\"}") + b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"}) defer b.Close() bURL, _ := url.Parse(b.URL) @@ -133,7 +160,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { } func TestGitHubProviderGetUserName(t *testing.T) { - b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`) + b := testGitHubBackend([]string{`{"email": "michael.bland@gsa.gov", "login": "mbland"}`}) defer b.Close() bURL, _ := url.Parse(b.URL)