Fix tests

This commit is contained in:
Lukasz Leszczuk 2019-08-21 09:20:55 +02:00
parent d7b8506ea2
commit 90208c7fe4
17 changed files with 160 additions and 186 deletions

View File

@ -335,11 +335,11 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState,
if s.Email == "" { if s.Email == "" {
userDetails, err := p.provider.GetUserDetails(s) userDetails, err := p.provider.GetUserDetails(s)
if err != nil { if err != nil {
return s, err return nil, err
} }
s.Email = userDetails["email"] s.Email = userDetails.Email
if uid, found := userDetails["uid"]; found { if userDetails.UID != "" {
s.ID = uid s.ID = userDetails.UID
} else { } else {
s.ID = "" s.ID = ""
} }

View File

@ -262,18 +262,10 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
} }
} }
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { func (tp *TestProvider) GetUserDetails(session *sessions.SessionState) (*providers.UserDetails, error) {
return tp.EmailAddress, nil return &providers.UserDetails{
} Email: tp.EmailAddress,
}, nil
func (tp *TestProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
email, err := tp.GetEmailAddress(s)
if err != nil {
return nil, err
}
userDetails["email"] = email
return userDetails, nil
} }
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {

View File

@ -119,64 +119,22 @@ func getUserIDFromJSON(json *simplejson.Json) (string, error) {
return uid, err return uid, err
} }
// GetEmailAddress returns the Account email address func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var email string
var err error var err error
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return nil, errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
req.Header = getAzureHeader(s.AccessToken) req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req) json, err := requests.Request(req)
if err != nil { if err != nil {
return "", err return nil, err
}
email, err = getEmailFromJSON(json)
if err == nil && email != "" {
return email, err
}
email, err = json.Get("userPrincipalName").String()
if err != nil {
logger.Printf("failed making request %s", err)
return "", err
}
if email == "" {
logger.Printf("failed to get email address")
return "", err
}
return email, err
}
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
var err error
if s.AccessToken == "" {
return userDetails, errors.New("missing access token")
}
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
if err != nil {
return userDetails, err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
return userDetails, err
} }
logger.Printf(" JSON: %v", json) logger.Printf(" JSON: %v", json)
@ -184,26 +142,26 @@ func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]str
logger.Printf("\t %20v : %v", key, value) logger.Printf("\t %20v : %v", key, value)
} }
email, err := getEmailFromJSON(json) email, err := getEmailFromJSON(json)
userDetails["email"] = email
if err != nil { if err != nil {
logger.Printf("[GetEmailAddress] failed making request: %s", err) logger.Printf("[GetUserDetails] failed making request: %s", err)
return userDetails, err return nil, err
} }
uid, err := getUserIDFromJSON(json) uid, err := getUserIDFromJSON(json)
userDetails["uid"] = uid
if err != nil { if err != nil {
logger.Printf("[GetEmailAddress] failed to get User ID: %s", err) logger.Printf("[GetUserDetails] failed to get User ID: %s", err)
} }
if email == "" { if email == "" {
logger.Printf("failed to get email address") logger.Printf("failed to get email address")
return userDetails, errors.New("Client email not found") return nil, errors.New("Client email not found")
} }
logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email) logger.Printf("[GetUserDetails] Chosen email address: '%s'", email)
return &UserDetails{
return userDetails, nil Email: email,
UID: uid,
}, nil
} }
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set) // Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)

View File

@ -121,7 +121,7 @@ func testAzureBackend(payload string) *httptest.Server {
})) }))
} }
func TestAzureProviderGetEmailAddress(t *testing.T) { func TestAzureProviderGetUserDetails(t *testing.T) {
b := testAzureBackend(`{ "mail": "user@windows.net" }`) b := testAzureBackend(`{ "mail": "user@windows.net" }`)
defer b.Close() defer b.Close()
@ -129,12 +129,12 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", details.Email)
} }
func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { func TestAzureProviderGetUserDetailsMailNull(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`) b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`)
defer b.Close() defer b.Close()
@ -142,12 +142,12 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", details.Email)
} }
func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { func TestAzureProviderGetUserDetailsGetUserPrincipalName(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`) b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`)
defer b.Close() defer b.Close()
@ -155,12 +155,12 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", details.Email)
} }
func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { func TestAzureProviderGetUserDetailsFailToGetEmailAddress(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`) b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`)
defer b.Close() defer b.Close()
@ -168,12 +168,12 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Nil(t, details)
} }
func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { func TestAzureProviderGetUserDetailsEmptyUserPrincipalName(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`) b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`)
defer b.Close() defer b.Close()
@ -181,12 +181,12 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, "Client email not found", err.Error())
assert.Equal(t, "", email) assert.Nil(t, details)
} }
func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { func TestAzureProviderGetUserDetailsIncorrectOtherMails(t *testing.T) {
b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`) b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`)
defer b.Close() defer b.Close()
@ -194,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Nil(t, details)
} }

View File

@ -64,8 +64,7 @@ func (p *BitbucketProvider) SetRepository(repository string) {
} }
// GetEmailAddress returns the email of the authenticated user // GetEmailAddress returns the email of the authenticated user
func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *BitbucketProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
var emails struct { var emails struct {
Values []struct { Values []struct {
Email string `json:"email"` Email string `json:"email"`
@ -86,12 +85,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)
return "", err return nil, err
} }
err = requests.RequestJSON(req, &emails) err = requests.RequestJSON(req, &emails)
if err != nil { if err != nil {
logger.Printf("failed making request %s", err) logger.Printf("failed making request %s", err)
return "", err return nil, err
} }
if p.Team != "" { if p.Team != "" {
@ -102,12 +101,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil) teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)
return "", err return nil, err
} }
err = requests.RequestJSON(req, &teams) err = requests.RequestJSON(req, &teams)
if err != nil { if err != nil {
logger.Printf("failed requesting teams membership %s", err) logger.Printf("failed requesting teams membership %s", err)
return "", err return nil, err
} }
var found = false var found = false
for _, team := range teams.Values { for _, team := range teams.Values {
@ -118,7 +117,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
} }
if found != true { if found != true {
logger.Print("team membership test failed, access denied") logger.Print("team membership test failed, access denied")
return "", nil return nil, nil
} }
} }
@ -133,12 +132,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
nil) nil)
if err != nil { if err != nil {
logger.Printf("failed building request %s", err) logger.Printf("failed building request %s", err)
return "", err return nil, err
} }
err = requests.RequestJSON(req, &repositories) err = requests.RequestJSON(req, &repositories)
if err != nil { if err != nil {
logger.Printf("failed checking repository access %s", err) logger.Printf("failed checking repository access %s", err)
return "", err return nil, err
} }
var found = false var found = false
for _, repository := range repositories.Values { for _, repository := range repositories.Values {
@ -149,15 +148,17 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
} }
if found != true { if found != true {
logger.Print("repository access test failed, access denied") logger.Print("repository access test failed, access denied")
return "", nil return nil, nil
} }
} }
for _, email := range emails.Values { for _, email := range emails.Values {
if email.Primary { if email.Primary {
return email.Email, nil return &UserDetails{
Email: email.Email,
}, nil
} }
} }
return "", nil return nil, nil
} }

View File

@ -112,7 +112,7 @@ func TestBitbucketProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }
func TestBitbucketProviderGetEmailAddress(t *testing.T) { func TestBitbucketProviderGetUserDetails(t *testing.T) {
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }") b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }")
defer b.Close() defer b.Close()
@ -120,12 +120,12 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) {
p := testBitbucketProvider(bURL.Host, "", "") p := testBitbucketProvider(bURL.Host, "", "")
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", details.Email)
} }
func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) { func TestBitbucketProviderGetUserDetailsAndGroup(t *testing.T) {
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }") b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }")
defer b.Close() defer b.Close()
@ -133,14 +133,14 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
p := testBitbucketProvider(bURL.Host, "bioinformatics", "") p := testBitbucketProvider(bURL.Host, "bioinformatics", "")
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", details.Email)
} }
// Note that trying to trigger the "failed building request" case is not // Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse. // practical, since the only way it can fail is if the URL fails to parse.
func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) { func TestBitbucketProviderGetUserDetailsFailedRequest(t *testing.T) {
b := testBitbucketBackend("unused payload") b := testBitbucketBackend("unused payload")
defer b.Close() defer b.Close()
@ -151,12 +151,12 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Nil(t, details)
} }
func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { func TestBitbucketProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
b := testBitbucketBackend("{\"foo\": \"bar\"}") b := testBitbucketBackend("{\"foo\": \"bar\"}")
defer b.Close() defer b.Close()
@ -164,7 +164,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T)
p := testBitbucketProvider(bURL.Host, "", "") p := testBitbucketProvider(bURL.Host, "", "")
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, "", email) assert.Nil(t, details)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
} }

View File

@ -55,13 +55,13 @@ func getFacebookHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return nil, errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil) req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
req.Header = getFacebookHeader(s.AccessToken) req.Header = getFacebookHeader(s.AccessToken)
@ -71,12 +71,14 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er
var r result var r result
err = requests.RequestJSON(req, &r) err = requests.RequestJSON(req, &r)
if err != nil { if err != nil {
return "", err return nil, err
} }
if r.Email == "" { if r.Email == "" {
return "", errors.New("no email") return nil, errors.New("no email")
} }
return r.Email, nil return &UserDetails{
Email: r.Email,
}, nil
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken

View File

@ -201,8 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`
Primary bool `json:"primary"` Primary bool `json:"primary"`
@ -213,11 +212,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
if p.Org != "" { if p.Org != "" {
if p.Team != "" { if p.Team != "" {
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok { if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
return "", err return nil, err
} }
} else { } else {
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok { if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
return "", err return nil, err
} }
} }
} }
@ -231,32 +230,34 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
return "", err return nil, err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s", return nil, fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body) resp.StatusCode, endpoint.String(), body)
} }
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil { if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body) return nil, fmt.Errorf("%s unmarshaling %s", err, body)
} }
for _, email := range emails { for _, email := range emails {
if email.Primary && email.Verified { if email.Primary && email.Verified {
return email.Email, nil return &UserDetails{
Email: email.Email,
}, nil
} }
} }
return "", nil return nil, nil
} }
// GetUserName returns the Account user name // GetUserName returns the Account user name

View File

@ -97,7 +97,7 @@ func TestGitHubProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }
func TestGitHubProviderGetEmailAddress(t *testing.T) { func TestGitHubProviderGetUserDetails(t *testing.T) {
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`}) b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`})
defer b.Close() defer b.Close()
@ -105,12 +105,12 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", details.Email)
} }
func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { func TestGitHubProviderGetUserDetailsNotVerified(t *testing.T) {
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`}) b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`})
defer b.Close() defer b.Close()
@ -118,12 +118,12 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Empty(t, "", email) assert.Nil(t, details)
} }
func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { func TestGitHubProviderGetUserDetailsWithOrg(t *testing.T) {
b := testGitHubBackend([]string{ b := testGitHubBackend([]string{
`[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`, `[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`,
`[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`, `[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`,
@ -136,14 +136,14 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p.Org = "testorg1" p.Org = "testorg1"
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", details.Email)
} }
// Note that trying to trigger the "failed building request" case is not // Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse. // practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { func TestGitHubProviderGetUserDetailsFailedRequest(t *testing.T) {
b := testGitHubBackend([]string{"unused payload"}) b := testGitHubBackend([]string{"unused payload"})
defer b.Close() defer b.Close()
@ -154,12 +154,12 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Nil(t, details)
} }
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { func TestGitHubProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"}) b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"})
defer b.Close() defer b.Close()
@ -167,9 +167,9 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Nil(t, details)
} }
func TestGitHubProviderGetUserName(t *testing.T) { func TestGitHubProviderGetUserName(t *testing.T) {

View File

@ -220,31 +220,32 @@ func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
// Retrieve user info // Retrieve user info
userInfo, err := p.getUserInfo(s) userInfo, err := p.getUserInfo(s)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to retrieve user info: %v", err) return nil, fmt.Errorf("failed to retrieve user info: %v", err)
} }
// Check if email is verified // Check if email is verified
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified { if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
return "", fmt.Errorf("user email is not verified") return nil, fmt.Errorf("user email is not verified")
} }
// Check if email has valid domain // Check if email has valid domain
err = p.verifyEmailDomain(userInfo) err = p.verifyEmailDomain(userInfo)
if err != nil { if err != nil {
return "", fmt.Errorf("email domain check failed: %v", err) return nil, fmt.Errorf("email domain check failed: %v", err)
} }
// Check group membership // Check group membership
err = p.verifyGroupMembership(userInfo) err = p.verifyGroupMembership(userInfo)
if err != nil { if err != nil {
return "", fmt.Errorf("group membership check failed: %v", err) return nil, fmt.Errorf("group membership check failed: %v", err)
} }
return &UserDetails{
return userInfo.Email, nil Email: userInfo.Email,
}, nil
} }
// GetUserName returns the Account user name // GetUserName returns the Account user name

View File

@ -63,7 +63,7 @@ func TestGitLabProviderBadToken(t *testing.T) {
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
@ -75,7 +75,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
@ -88,9 +88,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
p.AllowUnverifiedEmail = true p.AllowUnverifiedEmail = true
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", details.Email)
} }
func TestGitLabProviderUsername(t *testing.T) { func TestGitLabProviderUsername(t *testing.T) {
@ -117,9 +117,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
p.Group = "foo" p.Group = "foo"
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", details.Email)
} }
func TestGitLabProviderGroupMembershipMissing(t *testing.T) { func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
@ -132,7 +132,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
p.Group = "baz" p.Group = "baz"
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }
@ -146,9 +146,9 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) {
p.EmailDomains = []string{"bar.com"} p.EmailDomains = []string{"bar.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "foo@bar.com", email) assert.Equal(t, "foo@bar.com", details.Email)
} }
func TestGitLabProviderEmailDomainInvalid(t *testing.T) { func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
@ -161,6 +161,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
p.EmailDomains = []string{"baz.com"} p.EmailDomains = []string{"baz.com"}
session := &sessions.SessionState{AccessToken: "gitlab_access_token"} session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
_, err := p.GetEmailAddress(session) _, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
} }

View File

@ -233,6 +233,10 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
return p.GroupValidator(email) return p.GroupValidator(email)
} }
func (p *GoogleProvider) ValidateGroupWithSession(s *sessions.SessionState) bool {
return p.ValidateGroup(s.Email)
}
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {

View File

@ -9,6 +9,20 @@ import (
"github.com/pusher/oauth2_proxy/pkg/requests" "github.com/pusher/oauth2_proxy/pkg/requests"
) )
type NotImplementedError struct {
message string
}
func NewNotImplementedError(message string) *NotImplementedError {
return &NotImplementedError{
message: "Not implemented: " + message,
}
}
func (e *NotImplementedError) Error() string {
return e.message
}
// stripToken is a helper function to obfuscate "access_token" // stripToken is a helper function to obfuscate "access_token"
// query parameters // query parameters
func stripToken(endpoint string) string { func stripToken(endpoint string) string {

View File

@ -51,26 +51,28 @@ func getLinkedInHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return nil, errors.New("missing access token")
} }
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil) req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil)
if err != nil { if err != nil {
return "", err return nil, err
} }
req.Header = getLinkedInHeader(s.AccessToken) req.Header = getLinkedInHeader(s.AccessToken)
json, err := requests.Request(req) json, err := requests.Request(req)
if err != nil { if err != nil {
return "", err return nil, err
} }
email, err := json.String() email, err := json.String()
if err != nil { if err != nil {
return "", err return nil, err
} }
return email, nil return &UserDetails{
Email: email,
}, nil
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken

View File

@ -91,7 +91,7 @@ func TestLinkedInProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }
func TestLinkedInProviderGetEmailAddress(t *testing.T) { func TestLinkedInProviderGetUserDetails(t *testing.T) {
b := testLinkedInBackend(`"user@linkedin.com"`) b := testLinkedInBackend(`"user@linkedin.com"`)
defer b.Close() defer b.Close()
@ -99,12 +99,12 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email) assert.Equal(t, "user@linkedin.com", details.Email)
} }
func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { func TestLinkedInProviderGetUserDetailsFailedRequest(t *testing.T) {
b := testLinkedInBackend("unused payload") b := testLinkedInBackend("unused payload")
defer b.Close() defer b.Close()
@ -115,12 +115,12 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &sessions.SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Nil(t, details)
} }
func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { func TestLinkedInProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
b := testLinkedInBackend("{\"foo\": \"bar\"}") b := testLinkedInBackend("{\"foo\": \"bar\"}")
defer b.Close() defer b.Close()
@ -128,7 +128,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := &sessions.SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) details, err := p.GetUserDetails(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Nil(t, details)
} }

View File

@ -110,14 +110,8 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error)
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) { func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
userDetails := map[string]string{} return nil, NewNotImplementedError("")
email, err := p.GetEmailAddress(s)
if err != nil {
return nil, err
}
userDetails["email"] = email
return userDetails, nil
} }
// GetUserName returns the Account username // GetUserName returns the Account username

View File

@ -9,7 +9,7 @@ import (
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*sessions.SessionState) (string, error) GetEmailAddress(*sessions.SessionState) (string, error)
GetUserDetails(*sessions.SessionState) (map[string]string, error) GetUserDetails(*sessions.SessionState) (*UserDetails, error)
GetUserName(*sessions.SessionState) (string, error) GetUserName(*sessions.SessionState) (string, error)
GetGroups(*sessions.SessionState, string) (map[string]string, error) GetGroups(*sessions.SessionState, string) (map[string]string, error)
Redeem(string, string) (*sessions.SessionState, error) Redeem(string, string) (*sessions.SessionState, error)
@ -23,6 +23,11 @@ type Provider interface {
CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error) CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error)
} }
type UserDetails struct {
Email string
UID string
}
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string
func New(provider string, p *ProviderData) Provider { func New(provider string, p *ProviderData) Provider {
switch provider { switch provider {