From 90208c7fe4f74c2da4bd59c13fe6aba9daaffdd3 Mon Sep 17 00:00:00 2001 From: Lukasz Leszczuk Date: Wed, 21 Aug 2019 09:20:55 +0200 Subject: [PATCH] Fix tests --- oauthproxy.go | 8 ++--- oauthproxy_test.go | 16 +++------ providers/azure.go | 68 +++++++---------------------------- providers/azure_test.go | 38 ++++++++++---------- providers/bitbucket.go | 25 ++++++------- providers/bitbucket_test.go | 24 ++++++------- providers/facebook.go | 14 ++++---- providers/github.go | 21 +++++------ providers/github_test.go | 30 ++++++++-------- providers/gitlab.go | 15 ++++---- providers/gitlab_test.go | 20 +++++------ providers/google.go | 4 +++ providers/internal_util.go | 14 ++++++++ providers/linkedin.go | 14 ++++---- providers/linkedin_test.go | 18 +++++----- providers/provider_default.go | 10 ++---- providers/providers.go | 7 +++- 17 files changed, 160 insertions(+), 186 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index f30cfd3..de2cd4d 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -335,11 +335,11 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, if s.Email == "" { userDetails, err := p.provider.GetUserDetails(s) if err != nil { - return s, err + return nil, err } - s.Email = userDetails["email"] - if uid, found := userDetails["uid"]; found { - s.ID = uid + s.Email = userDetails.Email + if userDetails.UID != "" { + s.ID = userDetails.UID } else { s.ID = "" } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index d51d242..d02ce43 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -262,18 +262,10 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { } } -func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { - return tp.EmailAddress, nil -} - -func (tp *TestProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) { - userDetails := map[string]string{} - email, err := tp.GetEmailAddress(s) - if err != nil { - return nil, err - } - userDetails["email"] = email - return userDetails, nil +func (tp *TestProvider) GetUserDetails(session *sessions.SessionState) (*providers.UserDetails, error) { + return &providers.UserDetails{ + Email: tp.EmailAddress, + }, nil } func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { diff --git a/providers/azure.go b/providers/azure.go index fdaf0a5..95244c8 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -119,64 +119,22 @@ func getUserIDFromJSON(json *simplejson.Json) (string, error) { return uid, err } -// GetEmailAddress returns the Account email address -func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { - var email string +func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { var err error if s.AccessToken == "" { - return "", errors.New("missing access token") + return nil, errors.New("missing access token") } req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) if err != nil { - return "", err + return nil, err } req.Header = getAzureHeader(s.AccessToken) json, err := requests.Request(req) if err != nil { - return "", err - } - - email, err = getEmailFromJSON(json) - - if err == nil && email != "" { - return email, err - } - - email, err = json.Get("userPrincipalName").String() - - if err != nil { - logger.Printf("failed making request %s", err) - return "", err - } - - if email == "" { - logger.Printf("failed to get email address") - return "", err - } - - return email, err -} - -func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) { - userDetails := map[string]string{} - var err error - - if s.AccessToken == "" { - return userDetails, errors.New("missing access token") - } - req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) - if err != nil { - return userDetails, err - } - req.Header = getAzureHeader(s.AccessToken) - - json, err := requests.Request(req) - - if err != nil { - return userDetails, err + return nil, err } logger.Printf(" JSON: %v", json) @@ -184,26 +142,26 @@ func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]str logger.Printf("\t %20v : %v", key, value) } email, err := getEmailFromJSON(json) - userDetails["email"] = email if err != nil { - logger.Printf("[GetEmailAddress] failed making request: %s", err) - return userDetails, err + logger.Printf("[GetUserDetails] failed making request: %s", err) + return nil, err } uid, err := getUserIDFromJSON(json) - userDetails["uid"] = uid if err != nil { - logger.Printf("[GetEmailAddress] failed to get User ID: %s", err) + logger.Printf("[GetUserDetails] failed to get User ID: %s", err) } if email == "" { logger.Printf("failed to get email address") - return userDetails, errors.New("Client email not found") + return nil, errors.New("Client email not found") } - logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email) - - return userDetails, nil + logger.Printf("[GetUserDetails] Chosen email address: '%s'", email) + return &UserDetails{ + Email: email, + UID: uid, + }, nil } // Get list of groups user belong to. Filter the desired names of groups (in case of huge group set) diff --git a/providers/azure_test.go b/providers/azure_test.go index 12b65a6..d1365a3 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -121,7 +121,7 @@ func testAzureBackend(payload string) *httptest.Server { })) } -func TestAzureProviderGetEmailAddress(t *testing.T) { +func TestAzureProviderGetUserDetails(t *testing.T) { b := testAzureBackend(`{ "mail": "user@windows.net" }`) defer b.Close() @@ -129,12 +129,12 @@ func TestAzureProviderGetEmailAddress(t *testing.T) { p := testAzureProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "user@windows.net", email) + assert.Equal(t, "user@windows.net", details.Email) } -func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { +func TestAzureProviderGetUserDetailsMailNull(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`) defer b.Close() @@ -142,12 +142,12 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { p := testAzureProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "user@windows.net", email) + assert.Equal(t, "user@windows.net", details.Email) } -func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { +func TestAzureProviderGetUserDetailsGetUserPrincipalName(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`) defer b.Close() @@ -155,12 +155,12 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { p := testAzureProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "user@windows.net", email) + assert.Equal(t, "user@windows.net", details.Email) } -func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { +func TestAzureProviderGetUserDetailsFailToGetEmailAddress(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`) defer b.Close() @@ -168,12 +168,12 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { p := testAzureProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, "type assertion to string failed", err.Error()) - assert.Equal(t, "", email) + assert.Nil(t, details) } -func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { +func TestAzureProviderGetUserDetailsEmptyUserPrincipalName(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`) defer b.Close() @@ -181,12 +181,12 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { p := testAzureProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) - assert.Equal(t, nil, err) - assert.Equal(t, "", email) + details, err := p.GetUserDetails(session) + assert.Equal(t, "Client email not found", err.Error()) + assert.Nil(t, details) } -func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { +func TestAzureProviderGetUserDetailsIncorrectOtherMails(t *testing.T) { b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`) defer b.Close() @@ -194,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { p := testAzureProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, "type assertion to string failed", err.Error()) - assert.Equal(t, "", email) + assert.Nil(t, details) } diff --git a/providers/bitbucket.go b/providers/bitbucket.go index 63c1d0f..c787795 100644 --- a/providers/bitbucket.go +++ b/providers/bitbucket.go @@ -64,8 +64,7 @@ func (p *BitbucketProvider) SetRepository(repository string) { } // GetEmailAddress returns the email of the authenticated user -func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { - +func (p *BitbucketProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { var emails struct { Values []struct { Email string `json:"email"` @@ -86,12 +85,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) if err != nil { logger.Printf("failed building request %s", err) - return "", err + return nil, err } err = requests.RequestJSON(req, &emails) if err != nil { logger.Printf("failed making request %s", err) - return "", err + return nil, err } if p.Team != "" { @@ -102,12 +101,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) if err != nil { logger.Printf("failed building request %s", err) - return "", err + return nil, err } err = requests.RequestJSON(req, &teams) if err != nil { logger.Printf("failed requesting teams membership %s", err) - return "", err + return nil, err } var found = false for _, team := range teams.Values { @@ -118,7 +117,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e } if found != true { logger.Print("team membership test failed, access denied") - return "", nil + return nil, nil } } @@ -133,12 +132,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e nil) if err != nil { logger.Printf("failed building request %s", err) - return "", err + return nil, err } err = requests.RequestJSON(req, &repositories) if err != nil { logger.Printf("failed checking repository access %s", err) - return "", err + return nil, err } var found = false for _, repository := range repositories.Values { @@ -149,15 +148,17 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e } if found != true { logger.Print("repository access test failed, access denied") - return "", nil + return nil, nil } } for _, email := range emails.Values { if email.Primary { - return email.Email, nil + return &UserDetails{ + Email: email.Email, + }, nil } } - return "", nil + return nil, nil } diff --git a/providers/bitbucket_test.go b/providers/bitbucket_test.go index 585603d..ce99e0c 100644 --- a/providers/bitbucket_test.go +++ b/providers/bitbucket_test.go @@ -112,7 +112,7 @@ func TestBitbucketProviderOverrides(t *testing.T) { assert.Equal(t, "profile", p.Data().Scope) } -func TestBitbucketProviderGetEmailAddress(t *testing.T) { +func TestBitbucketProviderGetUserDetails(t *testing.T) { b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }") defer b.Close() @@ -120,12 +120,12 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) { p := testBitbucketProvider(bURL.Host, "", "") session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland@gsa.gov", details.Email) } -func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) { +func TestBitbucketProviderGetUserDetailsAndGroup(t *testing.T) { b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }") defer b.Close() @@ -133,14 +133,14 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) { p := testBitbucketProvider(bURL.Host, "bioinformatics", "") session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland@gsa.gov", details.Email) } // Note that trying to trigger the "failed building request" case is not // practical, since the only way it can fail is if the URL fails to parse. -func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) { +func TestBitbucketProviderGetUserDetailsFailedRequest(t *testing.T) { b := testBitbucketBackend("unused payload") defer b.Close() @@ -151,12 +151,12 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) + assert.Nil(t, details) } -func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { +func TestBitbucketProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) { b := testBitbucketBackend("{\"foo\": \"bar\"}") defer b.Close() @@ -164,7 +164,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) p := testBitbucketProvider(bURL.Host, "", "") session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) - assert.Equal(t, "", email) + details, err := p.GetUserDetails(session) + assert.Nil(t, details) assert.Equal(t, nil, err) } diff --git a/providers/facebook.go b/providers/facebook.go index abd5382..7a768c2 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -55,13 +55,13 @@ func getFacebookHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { if s.AccessToken == "" { - return "", errors.New("missing access token") + return nil, errors.New("missing access token") } req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) if err != nil { - return "", err + return nil, err } req.Header = getFacebookHeader(s.AccessToken) @@ -71,12 +71,14 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er var r result err = requests.RequestJSON(req, &r) if err != nil { - return "", err + return nil, err } if r.Email == "" { - return "", errors.New("no email") + return nil, errors.New("no email") } - return r.Email, nil + return &UserDetails{ + Email: r.Email, + }, nil } // ValidateSessionState validates the AccessToken diff --git a/providers/github.go b/providers/github.go index ba58bb1..01c90b6 100644 --- a/providers/github.go +++ b/providers/github.go @@ -201,8 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { } // GetEmailAddress returns the Account email address -func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { - +func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { var emails []struct { Email string `json:"email"` Primary bool `json:"primary"` @@ -213,11 +212,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro if p.Org != "" { if p.Team != "" { if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { - return "", err + return nil, err } } else { if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { - return "", err + return nil, err } } } @@ -231,32 +230,34 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) resp, err := http.DefaultClient.Do(req) if err != nil { - return "", err + return nil, err } body, err := ioutil.ReadAll(resp.Body) resp.Body.Close() if err != nil { - return "", err + return nil, err } if resp.StatusCode != 200 { - return "", fmt.Errorf("got %d from %q %s", + return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) } logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) if err := json.Unmarshal(body, &emails); err != nil { - return "", fmt.Errorf("%s unmarshaling %s", err, body) + return nil, fmt.Errorf("%s unmarshaling %s", err, body) } for _, email := range emails { if email.Primary && email.Verified { - return email.Email, nil + return &UserDetails{ + Email: email.Email, + }, nil } } - return "", nil + return nil, nil } // GetUserName returns the Account user name diff --git a/providers/github_test.go b/providers/github_test.go index 2d45b84..8ffad4e 100644 --- a/providers/github_test.go +++ b/providers/github_test.go @@ -97,7 +97,7 @@ func TestGitHubProviderOverrides(t *testing.T) { assert.Equal(t, "profile", p.Data().Scope) } -func TestGitHubProviderGetEmailAddress(t *testing.T) { +func TestGitHubProviderGetUserDetails(t *testing.T) { b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}) defer b.Close() @@ -105,12 +105,12 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { p := testGitHubProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland@gsa.gov", details.Email) } -func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { +func TestGitHubProviderGetUserDetailsNotVerified(t *testing.T) { b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`}) defer b.Close() @@ -118,12 +118,12 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { p := testGitHubProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Empty(t, "", email) + assert.Nil(t, details) } -func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { +func TestGitHubProviderGetUserDetailsWithOrg(t *testing.T) { b := testGitHubBackend([]string{ `[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`, `[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`, @@ -136,14 +136,14 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { p.Org = "testorg1" session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "michael.bland@gsa.gov", email) + assert.Equal(t, "michael.bland@gsa.gov", details.Email) } // Note that trying to trigger the "failed building request" case is not // practical, since the only way it can fail is if the URL fails to parse. -func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { +func TestGitHubProviderGetUserDetailsFailedRequest(t *testing.T) { b := testGitHubBackend([]string{"unused payload"}) defer b.Close() @@ -154,12 +154,12 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) + assert.Nil(t, details) } -func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { +func TestGitHubProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) { b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"}) defer b.Close() @@ -167,9 +167,9 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { p := testGitHubProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) + assert.Nil(t, details) } func TestGitHubProviderGetUserName(t *testing.T) { diff --git a/providers/gitlab.go b/providers/gitlab.go index c32ebe8..155097e 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -220,31 +220,32 @@ func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool { } // GetEmailAddress returns the Account email address -func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { // Retrieve user info userInfo, err := p.getUserInfo(s) if err != nil { - return "", fmt.Errorf("failed to retrieve user info: %v", err) + return nil, fmt.Errorf("failed to retrieve user info: %v", err) } // Check if email is verified if !p.AllowUnverifiedEmail && !userInfo.EmailVerified { - return "", fmt.Errorf("user email is not verified") + return nil, fmt.Errorf("user email is not verified") } // Check if email has valid domain err = p.verifyEmailDomain(userInfo) if err != nil { - return "", fmt.Errorf("email domain check failed: %v", err) + return nil, fmt.Errorf("email domain check failed: %v", err) } // Check group membership err = p.verifyGroupMembership(userInfo) if err != nil { - return "", fmt.Errorf("group membership check failed: %v", err) + return nil, fmt.Errorf("group membership check failed: %v", err) } - - return userInfo.Email, nil + return &UserDetails{ + Email: userInfo.Email, + }, nil } // GetUserName returns the Account user name diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index f75c4bf..1e3f258 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -63,7 +63,7 @@ func TestGitLabProviderBadToken(t *testing.T) { p := testGitLabProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) } @@ -75,7 +75,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) { p := testGitLabProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) } @@ -88,9 +88,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) { p.AllowUnverifiedEmail = true session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "foo@bar.com", email) + assert.Equal(t, "foo@bar.com", details.Email) } func TestGitLabProviderUsername(t *testing.T) { @@ -117,9 +117,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) { p.Group = "foo" session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "foo@bar.com", email) + assert.Equal(t, "foo@bar.com", details.Email) } func TestGitLabProviderGroupMembershipMissing(t *testing.T) { @@ -132,7 +132,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) { p.Group = "baz" session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) } @@ -146,9 +146,9 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) { p.EmailDomains = []string{"bar.com"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "foo@bar.com", email) + assert.Equal(t, "foo@bar.com", details.Email) } func TestGitLabProviderEmailDomainInvalid(t *testing.T) { @@ -161,6 +161,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) { p.EmailDomains = []string{"baz.com"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"} - _, err := p.GetEmailAddress(session) + _, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) } diff --git a/providers/google.go b/providers/google.go index 4748631..d3e8d75 100644 --- a/providers/google.go +++ b/providers/google.go @@ -233,6 +233,10 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { return p.GroupValidator(email) } +func (p *GoogleProvider) ValidateGroupWithSession(s *sessions.SessionState) bool { + return p.ValidateGroup(s.Email) +} + // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { diff --git a/providers/internal_util.go b/providers/internal_util.go index fb33b31..d345a50 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -9,6 +9,20 @@ import ( "github.com/pusher/oauth2_proxy/pkg/requests" ) +type NotImplementedError struct { + message string +} + +func NewNotImplementedError(message string) *NotImplementedError { + return &NotImplementedError{ + message: "Not implemented: " + message, + } +} + +func (e *NotImplementedError) Error() string { + return e.message +} + // stripToken is a helper function to obfuscate "access_token" // query parameters func stripToken(endpoint string) string { diff --git a/providers/linkedin.go b/providers/linkedin.go index bca2936..2040454 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -51,26 +51,28 @@ func getLinkedInHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { +func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { if s.AccessToken == "" { - return "", errors.New("missing access token") + return nil, errors.New("missing access token") } req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) if err != nil { - return "", err + return nil, err } req.Header = getLinkedInHeader(s.AccessToken) json, err := requests.Request(req) if err != nil { - return "", err + return nil, err } email, err := json.String() if err != nil { - return "", err + return nil, err } - return email, nil + return &UserDetails{ + Email: email, + }, nil } // ValidateSessionState validates the AccessToken diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index 9910a71..b398d5b 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -91,7 +91,7 @@ func TestLinkedInProviderOverrides(t *testing.T) { assert.Equal(t, "profile", p.Data().Scope) } -func TestLinkedInProviderGetEmailAddress(t *testing.T) { +func TestLinkedInProviderGetUserDetails(t *testing.T) { b := testLinkedInBackend(`"user@linkedin.com"`) defer b.Close() @@ -99,12 +99,12 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { p := testLinkedInProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.Equal(t, nil, err) - assert.Equal(t, "user@linkedin.com", email) + assert.Equal(t, "user@linkedin.com", details.Email) } -func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { +func TestLinkedInProviderGetUserDetailsFailedRequest(t *testing.T) { b := testLinkedInBackend("unused payload") defer b.Close() @@ -115,12 +115,12 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. session := &sessions.SessionState{AccessToken: "unexpected_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) + assert.Nil(t, details) } -func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { +func TestLinkedInProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) { b := testLinkedInBackend("{\"foo\": \"bar\"}") defer b.Close() @@ -128,7 +128,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { p := testLinkedInProvider(bURL.Host) session := &sessions.SessionState{AccessToken: "imaginary_access_token"} - email, err := p.GetEmailAddress(session) + details, err := p.GetUserDetails(session) assert.NotEqual(t, nil, err) - assert.Equal(t, "", email) + assert.Nil(t, details) } diff --git a/providers/provider_default.go b/providers/provider_default.go index 17e32c4..6d84a43 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -110,14 +110,8 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) return "", errors.New("not implemented") } -func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) { - userDetails := map[string]string{} - email, err := p.GetEmailAddress(s) - if err != nil { - return nil, err - } - userDetails["email"] = email - return userDetails, nil +func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { + return nil, NewNotImplementedError("") } // GetUserName returns the Account username diff --git a/providers/providers.go b/providers/providers.go index 3faef61..79facf1 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -9,7 +9,7 @@ import ( type Provider interface { Data() *ProviderData GetEmailAddress(*sessions.SessionState) (string, error) - GetUserDetails(*sessions.SessionState) (map[string]string, error) + GetUserDetails(*sessions.SessionState) (*UserDetails, error) GetUserName(*sessions.SessionState) (string, error) GetGroups(*sessions.SessionState, string) (map[string]string, error) Redeem(string, string) (*sessions.SessionState, error) @@ -23,6 +23,11 @@ type Provider interface { CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error) } +type UserDetails struct { + Email string + UID string +} + // New provides a new Provider based on the configured provider string func New(provider string, p *ProviderData) Provider { switch provider {