Fix tests
This commit is contained in:
parent
d7b8506ea2
commit
90208c7fe4
@ -335,11 +335,11 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState,
|
||||
if s.Email == "" {
|
||||
userDetails, err := p.provider.GetUserDetails(s)
|
||||
if err != nil {
|
||||
return s, err
|
||||
return nil, err
|
||||
}
|
||||
s.Email = userDetails["email"]
|
||||
if uid, found := userDetails["uid"]; found {
|
||||
s.ID = uid
|
||||
s.Email = userDetails.Email
|
||||
if userDetails.UID != "" {
|
||||
s.ID = userDetails.UID
|
||||
} else {
|
||||
s.ID = ""
|
||||
}
|
||||
|
@ -262,18 +262,10 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
|
||||
return tp.EmailAddress, nil
|
||||
}
|
||||
|
||||
func (tp *TestProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
||||
userDetails := map[string]string{}
|
||||
email, err := tp.GetEmailAddress(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userDetails["email"] = email
|
||||
return userDetails, nil
|
||||
func (tp *TestProvider) GetUserDetails(session *sessions.SessionState) (*providers.UserDetails, error) {
|
||||
return &providers.UserDetails{
|
||||
Email: tp.EmailAddress,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
||||
|
@ -119,64 +119,22 @@ func getUserIDFromJSON(json *simplejson.Json) (string, error) {
|
||||
return uid, err
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
var email string
|
||||
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
var err error
|
||||
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
return nil, errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
req.Header = getAzureHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
email, err = getEmailFromJSON(json)
|
||||
|
||||
if err == nil && email != "" {
|
||||
return email, err
|
||||
}
|
||||
|
||||
email, err = json.Get("userPrincipalName").String()
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if email == "" {
|
||||
logger.Printf("failed to get email address")
|
||||
return "", err
|
||||
}
|
||||
|
||||
return email, err
|
||||
}
|
||||
|
||||
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
||||
userDetails := map[string]string{}
|
||||
var err error
|
||||
|
||||
if s.AccessToken == "" {
|
||||
return userDetails, errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return userDetails, err
|
||||
}
|
||||
req.Header = getAzureHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
|
||||
if err != nil {
|
||||
return userDetails, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Printf(" JSON: %v", json)
|
||||
@ -184,26 +142,26 @@ func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]str
|
||||
logger.Printf("\t %20v : %v", key, value)
|
||||
}
|
||||
email, err := getEmailFromJSON(json)
|
||||
userDetails["email"] = email
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("[GetEmailAddress] failed making request: %s", err)
|
||||
return userDetails, err
|
||||
logger.Printf("[GetUserDetails] failed making request: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uid, err := getUserIDFromJSON(json)
|
||||
userDetails["uid"] = uid
|
||||
if err != nil {
|
||||
logger.Printf("[GetEmailAddress] failed to get User ID: %s", err)
|
||||
logger.Printf("[GetUserDetails] failed to get User ID: %s", err)
|
||||
}
|
||||
|
||||
if email == "" {
|
||||
logger.Printf("failed to get email address")
|
||||
return userDetails, errors.New("Client email not found")
|
||||
return nil, errors.New("Client email not found")
|
||||
}
|
||||
logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email)
|
||||
|
||||
return userDetails, nil
|
||||
logger.Printf("[GetUserDetails] Chosen email address: '%s'", email)
|
||||
return &UserDetails{
|
||||
Email: email,
|
||||
UID: uid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)
|
||||
|
@ -121,7 +121,7 @@ func testAzureBackend(payload string) *httptest.Server {
|
||||
}))
|
||||
}
|
||||
|
||||
func TestAzureProviderGetEmailAddress(t *testing.T) {
|
||||
func TestAzureProviderGetUserDetails(t *testing.T) {
|
||||
b := testAzureBackend(`{ "mail": "user@windows.net" }`)
|
||||
defer b.Close()
|
||||
|
||||
@ -129,12 +129,12 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
assert.Equal(t, "user@windows.net", details.Email)
|
||||
}
|
||||
|
||||
func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
||||
func TestAzureProviderGetUserDetailsMailNull(t *testing.T) {
|
||||
b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`)
|
||||
defer b.Close()
|
||||
|
||||
@ -142,12 +142,12 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
assert.Equal(t, "user@windows.net", details.Email)
|
||||
}
|
||||
|
||||
func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
||||
func TestAzureProviderGetUserDetailsGetUserPrincipalName(t *testing.T) {
|
||||
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`)
|
||||
defer b.Close()
|
||||
|
||||
@ -155,12 +155,12 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
assert.Equal(t, "user@windows.net", details.Email)
|
||||
}
|
||||
|
||||
func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
||||
func TestAzureProviderGetUserDetailsFailToGetEmailAddress(t *testing.T) {
|
||||
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`)
|
||||
defer b.Close()
|
||||
|
||||
@ -168,12 +168,12 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
||||
func TestAzureProviderGetUserDetailsEmptyUserPrincipalName(t *testing.T) {
|
||||
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`)
|
||||
defer b.Close()
|
||||
|
||||
@ -181,12 +181,12 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, "Client email not found", err.Error())
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
||||
func TestAzureProviderGetUserDetailsIncorrectOtherMails(t *testing.T) {
|
||||
b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`)
|
||||
defer b.Close()
|
||||
|
||||
@ -194,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
@ -64,8 +64,7 @@ func (p *BitbucketProvider) SetRepository(repository string) {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the email of the authenticated user
|
||||
func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
|
||||
func (p *BitbucketProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
var emails struct {
|
||||
Values []struct {
|
||||
Email string `json:"email"`
|
||||
@ -86,12 +85,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
err = requests.RequestJSON(req, &emails)
|
||||
if err != nil {
|
||||
logger.Printf("failed making request %s", err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.Team != "" {
|
||||
@ -102,12 +101,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
||||
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
err = requests.RequestJSON(req, &teams)
|
||||
if err != nil {
|
||||
logger.Printf("failed requesting teams membership %s", err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
var found = false
|
||||
for _, team := range teams.Values {
|
||||
@ -118,7 +117,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
||||
}
|
||||
if found != true {
|
||||
logger.Print("team membership test failed, access denied")
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -133,12 +132,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
||||
nil)
|
||||
if err != nil {
|
||||
logger.Printf("failed building request %s", err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
err = requests.RequestJSON(req, &repositories)
|
||||
if err != nil {
|
||||
logger.Printf("failed checking repository access %s", err)
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
var found = false
|
||||
for _, repository := range repositories.Values {
|
||||
@ -149,15 +148,17 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
||||
}
|
||||
if found != true {
|
||||
logger.Print("repository access test failed, access denied")
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
for _, email := range emails.Values {
|
||||
if email.Primary {
|
||||
return email.Email, nil
|
||||
return &UserDetails{
|
||||
Email: email.Email,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
|
@ -112,7 +112,7 @@ func TestBitbucketProviderOverrides(t *testing.T) {
|
||||
assert.Equal(t, "profile", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestBitbucketProviderGetEmailAddress(t *testing.T) {
|
||||
func TestBitbucketProviderGetUserDetails(t *testing.T) {
|
||||
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }")
|
||||
defer b.Close()
|
||||
|
||||
@ -120,12 +120,12 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) {
|
||||
p := testBitbucketProvider(bURL.Host, "", "")
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||
}
|
||||
|
||||
func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
|
||||
func TestBitbucketProviderGetUserDetailsAndGroup(t *testing.T) {
|
||||
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }")
|
||||
defer b.Close()
|
||||
|
||||
@ -133,14 +133,14 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
|
||||
p := testBitbucketProvider(bURL.Host, "bioinformatics", "")
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||
}
|
||||
|
||||
// Note that trying to trigger the "failed building request" case is not
|
||||
// practical, since the only way it can fail is if the URL fails to parse.
|
||||
func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
func TestBitbucketProviderGetUserDetailsFailedRequest(t *testing.T) {
|
||||
b := testBitbucketBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
@ -151,12 +151,12 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
func TestBitbucketProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testBitbucketBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
@ -164,7 +164,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T)
|
||||
p := testBitbucketProvider(bURL.Host, "", "")
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, "", email)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Nil(t, details)
|
||||
assert.Equal(t, nil, err)
|
||||
}
|
||||
|
@ -55,13 +55,13 @@ func getFacebookHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
return nil, errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
req.Header = getFacebookHeader(s.AccessToken)
|
||||
|
||||
@ -71,12 +71,14 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er
|
||||
var r result
|
||||
err = requests.RequestJSON(req, &r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
if r.Email == "" {
|
||||
return "", errors.New("no email")
|
||||
return nil, errors.New("no email")
|
||||
}
|
||||
return r.Email, nil
|
||||
return &UserDetails{
|
||||
Email: r.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
|
@ -201,8 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
|
||||
func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
var emails []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
@ -213,11 +212,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
|
||||
if p.Org != "" {
|
||||
if p.Team != "" {
|
||||
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -231,32 +230,34 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
|
||||
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
return nil, fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &emails); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
return nil, fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
|
||||
for _, email := range emails {
|
||||
if email.Primary && email.Verified {
|
||||
return email.Email, nil
|
||||
return &UserDetails{
|
||||
Email: email.Email,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetUserName returns the Account user name
|
||||
|
@ -97,7 +97,7 @@ func TestGitHubProviderOverrides(t *testing.T) {
|
||||
assert.Equal(t, "profile", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
||||
func TestGitHubProviderGetUserDetails(t *testing.T) {
|
||||
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`})
|
||||
defer b.Close()
|
||||
|
||||
@ -105,12 +105,12 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
||||
func TestGitHubProviderGetUserDetailsNotVerified(t *testing.T) {
|
||||
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`})
|
||||
defer b.Close()
|
||||
|
||||
@ -118,12 +118,12 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Empty(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
||||
func TestGitHubProviderGetUserDetailsWithOrg(t *testing.T) {
|
||||
b := testGitHubBackend([]string{
|
||||
`[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`,
|
||||
`[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`,
|
||||
@ -136,14 +136,14 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
||||
p.Org = "testorg1"
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||
}
|
||||
|
||||
// Note that trying to trigger the "failed building request" case is not
|
||||
// practical, since the only way it can fail is if the URL fails to parse.
|
||||
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
func TestGitHubProviderGetUserDetailsFailedRequest(t *testing.T) {
|
||||
b := testGitHubBackend([]string{"unused payload"})
|
||||
defer b.Close()
|
||||
|
||||
@ -154,12 +154,12 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
func TestGitHubProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"})
|
||||
defer b.Close()
|
||||
|
||||
@ -167,9 +167,9 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetUserName(t *testing.T) {
|
||||
|
@ -220,31 +220,32 @@ func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
// Retrieve user info
|
||||
userInfo, err := p.getUserInfo(s)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to retrieve user info: %v", err)
|
||||
return nil, fmt.Errorf("failed to retrieve user info: %v", err)
|
||||
}
|
||||
|
||||
// Check if email is verified
|
||||
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
|
||||
return "", fmt.Errorf("user email is not verified")
|
||||
return nil, fmt.Errorf("user email is not verified")
|
||||
}
|
||||
|
||||
// Check if email has valid domain
|
||||
err = p.verifyEmailDomain(userInfo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("email domain check failed: %v", err)
|
||||
return nil, fmt.Errorf("email domain check failed: %v", err)
|
||||
}
|
||||
|
||||
// Check group membership
|
||||
err = p.verifyGroupMembership(userInfo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("group membership check failed: %v", err)
|
||||
return nil, fmt.Errorf("group membership check failed: %v", err)
|
||||
}
|
||||
|
||||
return userInfo.Email, nil
|
||||
return &UserDetails{
|
||||
Email: userInfo.Email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetUserName returns the Account user name
|
||||
|
@ -63,7 +63,7 @@ func TestGitLabProviderBadToken(t *testing.T) {
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
|
||||
_, err := p.GetEmailAddress(session)
|
||||
_, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
@ -75,7 +75,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||
_, err := p.GetEmailAddress(session)
|
||||
_, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
@ -88,9 +88,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
|
||||
p.AllowUnverifiedEmail = true
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "foo@bar.com", email)
|
||||
assert.Equal(t, "foo@bar.com", details.Email)
|
||||
}
|
||||
|
||||
func TestGitLabProviderUsername(t *testing.T) {
|
||||
@ -117,9 +117,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
|
||||
p.Group = "foo"
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "foo@bar.com", email)
|
||||
assert.Equal(t, "foo@bar.com", details.Email)
|
||||
}
|
||||
|
||||
func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
|
||||
@ -132,7 +132,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
|
||||
p.Group = "baz"
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||
_, err := p.GetEmailAddress(session)
|
||||
_, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
||||
@ -146,9 +146,9 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) {
|
||||
p.EmailDomains = []string{"bar.com"}
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "foo@bar.com", email)
|
||||
assert.Equal(t, "foo@bar.com", details.Email)
|
||||
}
|
||||
|
||||
func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
|
||||
@ -161,6 +161,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
|
||||
p.EmailDomains = []string{"baz.com"}
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||
_, err := p.GetEmailAddress(session)
|
||||
_, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
}
|
||||
|
@ -233,6 +233,10 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
|
||||
return p.GroupValidator(email)
|
||||
}
|
||||
|
||||
func (p *GoogleProvider) ValidateGroupWithSession(s *sessions.SessionState) bool {
|
||||
return p.ValidateGroup(s.Email)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
|
@ -9,6 +9,20 @@ import (
|
||||
"github.com/pusher/oauth2_proxy/pkg/requests"
|
||||
)
|
||||
|
||||
type NotImplementedError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func NewNotImplementedError(message string) *NotImplementedError {
|
||||
return &NotImplementedError{
|
||||
message: "Not implemented: " + message,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *NotImplementedError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// stripToken is a helper function to obfuscate "access_token"
|
||||
// query parameters
|
||||
func stripToken(endpoint string) string {
|
||||
|
@ -51,26 +51,28 @@ func getLinkedInHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
return nil, errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
req.Header = getLinkedInHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
email, err := json.String()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
return email, nil
|
||||
return &UserDetails{
|
||||
Email: email,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
|
@ -91,7 +91,7 @@ func TestLinkedInProviderOverrides(t *testing.T) {
|
||||
assert.Equal(t, "profile", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
||||
func TestLinkedInProviderGetUserDetails(t *testing.T) {
|
||||
b := testLinkedInBackend(`"user@linkedin.com"`)
|
||||
defer b.Close()
|
||||
|
||||
@ -99,12 +99,12 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@linkedin.com", email)
|
||||
assert.Equal(t, "user@linkedin.com", details.Email)
|
||||
}
|
||||
|
||||
func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
func TestLinkedInProviderGetUserDetailsFailedRequest(t *testing.T) {
|
||||
b := testLinkedInBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
@ -115,12 +115,12 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
||||
func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
func TestLinkedInProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testLinkedInBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
@ -128,7 +128,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
details, err := p.GetUserDetails(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
assert.Nil(t, details)
|
||||
}
|
||||
|
@ -110,14 +110,8 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error)
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
||||
userDetails := map[string]string{}
|
||||
email, err := p.GetEmailAddress(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userDetails["email"] = email
|
||||
return userDetails, nil
|
||||
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||
return nil, NewNotImplementedError("")
|
||||
}
|
||||
|
||||
// GetUserName returns the Account username
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||
GetUserDetails(*sessions.SessionState) (map[string]string, error)
|
||||
GetUserDetails(*sessions.SessionState) (*UserDetails, error)
|
||||
GetUserName(*sessions.SessionState) (string, error)
|
||||
GetGroups(*sessions.SessionState, string) (map[string]string, error)
|
||||
Redeem(string, string) (*sessions.SessionState, error)
|
||||
@ -23,6 +23,11 @@ type Provider interface {
|
||||
CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error)
|
||||
}
|
||||
|
||||
type UserDetails struct {
|
||||
Email string
|
||||
UID string
|
||||
}
|
||||
|
||||
// New provides a new Provider based on the configured provider string
|
||||
func New(provider string, p *ProviderData) Provider {
|
||||
switch provider {
|
||||
|
Loading…
Reference in New Issue
Block a user