Fix tests

This commit is contained in:
Lukasz Leszczuk 2019-08-21 09:20:55 +02:00
parent d7b8506ea2
commit 90208c7fe4
17 changed files with 160 additions and 186 deletions

View File

@ -335,11 +335,11 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState,
if s.Email == "" {
userDetails, err := p.provider.GetUserDetails(s)
if err != nil {
return s, err
return nil, err
}
s.Email = userDetails["email"]
if uid, found := userDetails["uid"]; found {
s.ID = uid
s.Email = userDetails.Email
if userDetails.UID != "" {
s.ID = userDetails.UID
} else {
s.ID = ""
}

View File

@ -262,18 +262,10 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
}
}
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
return tp.EmailAddress, nil
}
func (tp *TestProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
email, err := tp.GetEmailAddress(s)
if err != nil {
return nil, err
}
userDetails["email"] = email
return userDetails, nil
func (tp *TestProvider) GetUserDetails(session *sessions.SessionState) (*providers.UserDetails, error) {
return &providers.UserDetails{
Email: tp.EmailAddress,
}, nil
}
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {

View File

@ -119,64 +119,22 @@ func getUserIDFromJSON(json *simplejson.Json) (string, error) {
return uid, err
}
// GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var email string
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
var err error
if s.AccessToken == "" {
return "", errors.New("missing access token")
return nil, errors.New("missing access token")
}
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
if err != nil {
return "", err
return nil, err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
return "", err
}
email, err = getEmailFromJSON(json)
if err == nil && email != "" {
return email, err
}
email, err = json.Get("userPrincipalName").String()
if err != nil {
logger.Printf("failed making request %s", err)
return "", err
}
if email == "" {
logger.Printf("failed to get email address")
return "", err
}
return email, err
}
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
var err error
if s.AccessToken == "" {
return userDetails, errors.New("missing access token")
}
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
if err != nil {
return userDetails, err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
return userDetails, err
return nil, err
}
logger.Printf(" JSON: %v", json)
@ -184,26 +142,26 @@ func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]str
logger.Printf("\t %20v : %v", key, value)
}
email, err := getEmailFromJSON(json)
userDetails["email"] = email
if err != nil {
logger.Printf("[GetEmailAddress] failed making request: %s", err)
return userDetails, err
logger.Printf("[GetUserDetails] failed making request: %s", err)
return nil, err
}
uid, err := getUserIDFromJSON(json)
userDetails["uid"] = uid
if err != nil {
logger.Printf("[GetEmailAddress] failed to get User ID: %s", err)
logger.Printf("[GetUserDetails] failed to get User ID: %s", err)
}
if email == "" {
logger.Printf("failed to get email address")
return userDetails, errors.New("Client email not found")
return nil, errors.New("Client email not found")
}
logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email)
return userDetails, nil
logger.Printf("[GetUserDetails] Chosen email address: '%s'", email)
return &UserDetails{
Email: email,
UID: uid,
}, nil
}
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)

View File

@ -121,7 +121,7 @@ func testAzureBackend(payload string) *httptest.Server {
}))
}
func TestAzureProviderGetEmailAddress(t *testing.T) {
func TestAzureProviderGetUserDetails(t *testing.T) {
b := testAzureBackend(`{ "mail": "user@windows.net" }`)
defer b.Close()
@ -129,12 +129,12 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email)
assert.Equal(t, "user@windows.net", details.Email)
}
func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
func TestAzureProviderGetUserDetailsMailNull(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`)
defer b.Close()
@ -142,12 +142,12 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email)
assert.Equal(t, "user@windows.net", details.Email)
}
func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
func TestAzureProviderGetUserDetailsGetUserPrincipalName(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`)
defer b.Close()
@ -155,12 +155,12 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email)
assert.Equal(t, "user@windows.net", details.Email)
}
func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
func TestAzureProviderGetUserDetailsFailToGetEmailAddress(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`)
defer b.Close()
@ -168,12 +168,12 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email)
assert.Nil(t, details)
}
func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
func TestAzureProviderGetUserDetailsEmptyUserPrincipalName(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`)
defer b.Close()
@ -181,12 +181,12 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "", email)
details, err := p.GetUserDetails(session)
assert.Equal(t, "Client email not found", err.Error())
assert.Nil(t, details)
}
func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
func TestAzureProviderGetUserDetailsIncorrectOtherMails(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`)
defer b.Close()
@ -194,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email)
assert.Nil(t, details)
}

View File

@ -64,8 +64,7 @@ func (p *BitbucketProvider) SetRepository(repository string) {
}
// GetEmailAddress returns the email of the authenticated user
func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
func (p *BitbucketProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
var emails struct {
Values []struct {
Email string `json:"email"`
@ -86,12 +85,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
return nil, err
}
err = requests.RequestJSON(req, &emails)
if err != nil {
logger.Printf("failed making request %s", err)
return "", err
return nil, err
}
if p.Team != "" {
@ -102,12 +101,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
return nil, err
}
err = requests.RequestJSON(req, &teams)
if err != nil {
logger.Printf("failed requesting teams membership %s", err)
return "", err
return nil, err
}
var found = false
for _, team := range teams.Values {
@ -118,7 +117,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
}
if found != true {
logger.Print("team membership test failed, access denied")
return "", nil
return nil, nil
}
}
@ -133,12 +132,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
nil)
if err != nil {
logger.Printf("failed building request %s", err)
return "", err
return nil, err
}
err = requests.RequestJSON(req, &repositories)
if err != nil {
logger.Printf("failed checking repository access %s", err)
return "", err
return nil, err
}
var found = false
for _, repository := range repositories.Values {
@ -149,15 +148,17 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
}
if found != true {
logger.Print("repository access test failed, access denied")
return "", nil
return nil, nil
}
}
for _, email := range emails.Values {
if email.Primary {
return email.Email, nil
return &UserDetails{
Email: email.Email,
}, nil
}
}
return "", nil
return nil, nil
}

View File

@ -112,7 +112,7 @@ func TestBitbucketProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope)
}
func TestBitbucketProviderGetEmailAddress(t *testing.T) {
func TestBitbucketProviderGetUserDetails(t *testing.T) {
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }")
defer b.Close()
@ -120,12 +120,12 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) {
p := testBitbucketProvider(bURL.Host, "", "")
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
}
func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
func TestBitbucketProviderGetUserDetailsAndGroup(t *testing.T) {
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }")
defer b.Close()
@ -133,14 +133,14 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
p := testBitbucketProvider(bURL.Host, "bioinformatics", "")
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
func TestBitbucketProviderGetUserDetailsFailedRequest(t *testing.T) {
b := testBitbucketBackend("unused payload")
defer b.Close()
@ -151,12 +151,12 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
assert.Nil(t, details)
}
func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
func TestBitbucketProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
b := testBitbucketBackend("{\"foo\": \"bar\"}")
defer b.Close()
@ -164,7 +164,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T)
p := testBitbucketProvider(bURL.Host, "", "")
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, "", email)
details, err := p.GetUserDetails(session)
assert.Nil(t, details)
assert.Equal(t, nil, err)
}

View File

@ -55,13 +55,13 @@ func getFacebookHeader(accessToken string) http.Header {
}
// GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
if s.AccessToken == "" {
return "", errors.New("missing access token")
return nil, errors.New("missing access token")
}
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil)
if err != nil {
return "", err
return nil, err
}
req.Header = getFacebookHeader(s.AccessToken)
@ -71,12 +71,14 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er
var r result
err = requests.RequestJSON(req, &r)
if err != nil {
return "", err
return nil, err
}
if r.Email == "" {
return "", errors.New("no email")
return nil, errors.New("no email")
}
return r.Email, nil
return &UserDetails{
Email: r.Email,
}, nil
}
// ValidateSessionState validates the AccessToken

View File

@ -201,8 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
}
// GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
var emails []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
@ -213,11 +212,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
if p.Org != "" {
if p.Team != "" {
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
return "", err
return nil, err
}
} else {
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
return "", err
return nil, err
}
}
}
@ -231,32 +230,34 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
return nil, err
}
body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return "", err
return nil, err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
return nil, fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
return nil, fmt.Errorf("%s unmarshaling %s", err, body)
}
for _, email := range emails {
if email.Primary && email.Verified {
return email.Email, nil
return &UserDetails{
Email: email.Email,
}, nil
}
}
return "", nil
return nil, nil
}
// GetUserName returns the Account user name

View File

@ -97,7 +97,7 @@ func TestGitHubProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope)
}
func TestGitHubProviderGetEmailAddress(t *testing.T) {
func TestGitHubProviderGetUserDetails(t *testing.T) {
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`})
defer b.Close()
@ -105,12 +105,12 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
p := testGitHubProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
}
func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
func TestGitHubProviderGetUserDetailsNotVerified(t *testing.T) {
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`})
defer b.Close()
@ -118,12 +118,12 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
p := testGitHubProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Empty(t, "", email)
assert.Nil(t, details)
}
func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
func TestGitHubProviderGetUserDetailsWithOrg(t *testing.T) {
b := testGitHubBackend([]string{
`[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`,
`[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`,
@ -136,14 +136,14 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p.Org = "testorg1"
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
func TestGitHubProviderGetUserDetailsFailedRequest(t *testing.T) {
b := testGitHubBackend([]string{"unused payload"})
defer b.Close()
@ -154,12 +154,12 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
assert.Nil(t, details)
}
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
func TestGitHubProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"})
defer b.Close()
@ -167,9 +167,9 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testGitHubProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
assert.Nil(t, details)
}
func TestGitHubProviderGetUserName(t *testing.T) {

View File

@ -220,31 +220,32 @@ func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool {
}
// GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
// Retrieve user info
userInfo, err := p.getUserInfo(s)
if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err)
return nil, fmt.Errorf("failed to retrieve user info: %v", err)
}
// Check if email is verified
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
return "", fmt.Errorf("user email is not verified")
return nil, fmt.Errorf("user email is not verified")
}
// Check if email has valid domain
err = p.verifyEmailDomain(userInfo)
if err != nil {
return "", fmt.Errorf("email domain check failed: %v", err)
return nil, fmt.Errorf("email domain check failed: %v", err)
}
// Check group membership
err = p.verifyGroupMembership(userInfo)
if err != nil {
return "", fmt.Errorf("group membership check failed: %v", err)
return nil, fmt.Errorf("group membership check failed: %v", err)
}
return userInfo.Email, nil
return &UserDetails{
Email: userInfo.Email,
}, nil
}
// GetUserName returns the Account user name

View File

@ -63,7 +63,7 @@ func TestGitLabProviderBadToken(t *testing.T) {
p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
_, err := p.GetEmailAddress(session)
_, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
}
@ -75,7 +75,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session)
_, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
}
@ -88,9 +88,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
assert.Equal(t, "foo@bar.com", details.Email)
}
func TestGitLabProviderUsername(t *testing.T) {
@ -117,9 +117,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
p.Group = "foo"
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
assert.Equal(t, "foo@bar.com", details.Email)
}
func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
@ -132,7 +132,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
p.Group = "baz"
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session)
_, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
}
@ -146,9 +146,9 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) {
p.EmailDomains = []string{"bar.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email)
assert.Equal(t, "foo@bar.com", details.Email)
}
func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
@ -161,6 +161,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
p.EmailDomains = []string{"baz.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session)
_, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
}

View File

@ -233,6 +233,10 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
return p.GroupValidator(email)
}
func (p *GoogleProvider) ValidateGroupWithSession(s *sessions.SessionState) bool {
return p.ValidateGroup(s.Email)
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {

View File

@ -9,6 +9,20 @@ import (
"github.com/pusher/oauth2_proxy/pkg/requests"
)
type NotImplementedError struct {
message string
}
func NewNotImplementedError(message string) *NotImplementedError {
return &NotImplementedError{
message: "Not implemented: " + message,
}
}
func (e *NotImplementedError) Error() string {
return e.message
}
// stripToken is a helper function to obfuscate "access_token"
// query parameters
func stripToken(endpoint string) string {

View File

@ -51,26 +51,28 @@ func getLinkedInHeader(accessToken string) http.Header {
}
// GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
if s.AccessToken == "" {
return "", errors.New("missing access token")
return nil, errors.New("missing access token")
}
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil)
if err != nil {
return "", err
return nil, err
}
req.Header = getLinkedInHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
return "", err
return nil, err
}
email, err := json.String()
if err != nil {
return "", err
return nil, err
}
return email, nil
return &UserDetails{
Email: email,
}, nil
}
// ValidateSessionState validates the AccessToken

View File

@ -91,7 +91,7 @@ func TestLinkedInProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope)
}
func TestLinkedInProviderGetEmailAddress(t *testing.T) {
func TestLinkedInProviderGetUserDetails(t *testing.T) {
b := testLinkedInBackend(`"user@linkedin.com"`)
defer b.Close()
@ -99,12 +99,12 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
p := testLinkedInProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email)
assert.Equal(t, "user@linkedin.com", details.Email)
}
func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
func TestLinkedInProviderGetUserDetailsFailedRequest(t *testing.T) {
b := testLinkedInBackend("unused payload")
defer b.Close()
@ -115,12 +115,12 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
assert.Nil(t, details)
}
func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
func TestLinkedInProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
b := testLinkedInBackend("{\"foo\": \"bar\"}")
defer b.Close()
@ -128,7 +128,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testLinkedInProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
assert.Nil(t, details)
}

View File

@ -110,14 +110,8 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error)
return "", errors.New("not implemented")
}
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
email, err := p.GetEmailAddress(s)
if err != nil {
return nil, err
}
userDetails["email"] = email
return userDetails, nil
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
return nil, NewNotImplementedError("")
}
// GetUserName returns the Account username

View File

@ -9,7 +9,7 @@ import (
type Provider interface {
Data() *ProviderData
GetEmailAddress(*sessions.SessionState) (string, error)
GetUserDetails(*sessions.SessionState) (map[string]string, error)
GetUserDetails(*sessions.SessionState) (*UserDetails, error)
GetUserName(*sessions.SessionState) (string, error)
GetGroups(*sessions.SessionState, string) (map[string]string, error)
Redeem(string, string) (*sessions.SessionState, error)
@ -23,6 +23,11 @@ type Provider interface {
CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error)
}
type UserDetails struct {
Email string
UID string
}
// New provides a new Provider based on the configured provider string
func New(provider string, p *ProviderData) Provider {
switch provider {