Fix tests
This commit is contained in:
parent
d7b8506ea2
commit
90208c7fe4
@ -335,11 +335,11 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState,
|
|||||||
if s.Email == "" {
|
if s.Email == "" {
|
||||||
userDetails, err := p.provider.GetUserDetails(s)
|
userDetails, err := p.provider.GetUserDetails(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return s, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.Email = userDetails["email"]
|
s.Email = userDetails.Email
|
||||||
if uid, found := userDetails["uid"]; found {
|
if userDetails.UID != "" {
|
||||||
s.ID = uid
|
s.ID = userDetails.UID
|
||||||
} else {
|
} else {
|
||||||
s.ID = ""
|
s.ID = ""
|
||||||
}
|
}
|
||||||
|
@ -262,18 +262,10 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
|
func (tp *TestProvider) GetUserDetails(session *sessions.SessionState) (*providers.UserDetails, error) {
|
||||||
return tp.EmailAddress, nil
|
return &providers.UserDetails{
|
||||||
}
|
Email: tp.EmailAddress,
|
||||||
|
}, nil
|
||||||
func (tp *TestProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
|
||||||
userDetails := map[string]string{}
|
|
||||||
email, err := tp.GetEmailAddress(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
userDetails["email"] = email
|
|
||||||
return userDetails, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
||||||
|
@ -119,64 +119,22 @@ func getUserIDFromJSON(json *simplejson.Json) (string, error) {
|
|||||||
return uid, err
|
return uid, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
|
||||||
var email string
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return nil, errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
|
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header = getAzureHeader(s.AccessToken)
|
req.Header = getAzureHeader(s.AccessToken)
|
||||||
|
|
||||||
json, err := requests.Request(req)
|
json, err := requests.Request(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
|
||||||
|
|
||||||
email, err = getEmailFromJSON(json)
|
|
||||||
|
|
||||||
if err == nil && email != "" {
|
|
||||||
return email, err
|
|
||||||
}
|
|
||||||
|
|
||||||
email, err = json.Get("userPrincipalName").String()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.Printf("failed making request %s", err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if email == "" {
|
|
||||||
logger.Printf("failed to get email address")
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return email, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
|
||||||
userDetails := map[string]string{}
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if s.AccessToken == "" {
|
|
||||||
return userDetails, errors.New("missing access token")
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
|
|
||||||
if err != nil {
|
|
||||||
return userDetails, err
|
|
||||||
}
|
|
||||||
req.Header = getAzureHeader(s.AccessToken)
|
|
||||||
|
|
||||||
json, err := requests.Request(req)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return userDetails, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf(" JSON: %v", json)
|
logger.Printf(" JSON: %v", json)
|
||||||
@ -184,26 +142,26 @@ func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]str
|
|||||||
logger.Printf("\t %20v : %v", key, value)
|
logger.Printf("\t %20v : %v", key, value)
|
||||||
}
|
}
|
||||||
email, err := getEmailFromJSON(json)
|
email, err := getEmailFromJSON(json)
|
||||||
userDetails["email"] = email
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("[GetEmailAddress] failed making request: %s", err)
|
logger.Printf("[GetUserDetails] failed making request: %s", err)
|
||||||
return userDetails, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
uid, err := getUserIDFromJSON(json)
|
uid, err := getUserIDFromJSON(json)
|
||||||
userDetails["uid"] = uid
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("[GetEmailAddress] failed to get User ID: %s", err)
|
logger.Printf("[GetUserDetails] failed to get User ID: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if email == "" {
|
if email == "" {
|
||||||
logger.Printf("failed to get email address")
|
logger.Printf("failed to get email address")
|
||||||
return userDetails, errors.New("Client email not found")
|
return nil, errors.New("Client email not found")
|
||||||
}
|
}
|
||||||
logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email)
|
logger.Printf("[GetUserDetails] Chosen email address: '%s'", email)
|
||||||
|
return &UserDetails{
|
||||||
return userDetails, nil
|
Email: email,
|
||||||
|
UID: uid,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)
|
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)
|
||||||
|
@ -121,7 +121,7 @@ func testAzureBackend(payload string) *httptest.Server {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderGetEmailAddress(t *testing.T) {
|
func TestAzureProviderGetUserDetails(t *testing.T) {
|
||||||
b := testAzureBackend(`{ "mail": "user@windows.net" }`)
|
b := testAzureBackend(`{ "mail": "user@windows.net" }`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -129,12 +129,12 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
|
|||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
func TestAzureProviderGetUserDetailsMailNull(t *testing.T) {
|
||||||
b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`)
|
b := testAzureBackend(`{ "mail": null, "otherMails": ["user@windows.net", "altuser@windows.net"] }`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -142,12 +142,12 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
|||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
func TestAzureProviderGetUserDetailsGetUserPrincipalName(t *testing.T) {
|
||||||
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`)
|
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "user@windows.net" }`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -155,12 +155,12 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
|||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
func TestAzureProviderGetUserDetailsFailToGetEmailAddress(t *testing.T) {
|
||||||
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`)
|
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": null }`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -168,12 +168,12 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
|||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
func TestAzureProviderGetUserDetailsEmptyUserPrincipalName(t *testing.T) {
|
||||||
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`)
|
b := testAzureBackend(`{ "mail": null, "otherMails": [], "userPrincipalName": "" }`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -181,12 +181,12 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
|||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, "Client email not found", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
func TestAzureProviderGetUserDetailsIncorrectOtherMails(t *testing.T) {
|
||||||
b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`)
|
b := testAzureBackend(`{ "mail": null, "otherMails": "", "userPrincipalName": null }`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -194,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
|||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
@ -64,8 +64,7 @@ func (p *BitbucketProvider) SetRepository(repository string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the email of the authenticated user
|
// GetEmailAddress returns the email of the authenticated user
|
||||||
func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
func (p *BitbucketProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
|
|
||||||
var emails struct {
|
var emails struct {
|
||||||
Values []struct {
|
Values []struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@ -86,12 +85,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
|||||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
logger.Printf("failed building request %s", err)
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = requests.RequestJSON(req, &emails)
|
err = requests.RequestJSON(req, &emails)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed making request %s", err)
|
logger.Printf("failed making request %s", err)
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.Team != "" {
|
if p.Team != "" {
|
||||||
@ -102,12 +101,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
|||||||
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
|
teamURL.String()+"?role=member&access_token="+s.AccessToken, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
logger.Printf("failed building request %s", err)
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = requests.RequestJSON(req, &teams)
|
err = requests.RequestJSON(req, &teams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed requesting teams membership %s", err)
|
logger.Printf("failed requesting teams membership %s", err)
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
var found = false
|
var found = false
|
||||||
for _, team := range teams.Values {
|
for _, team := range teams.Values {
|
||||||
@ -118,7 +117,7 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
|||||||
}
|
}
|
||||||
if found != true {
|
if found != true {
|
||||||
logger.Print("team membership test failed, access denied")
|
logger.Print("team membership test failed, access denied")
|
||||||
return "", nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,12 +132,12 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
|||||||
nil)
|
nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed building request %s", err)
|
logger.Printf("failed building request %s", err)
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = requests.RequestJSON(req, &repositories)
|
err = requests.RequestJSON(req, &repositories)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("failed checking repository access %s", err)
|
logger.Printf("failed checking repository access %s", err)
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
var found = false
|
var found = false
|
||||||
for _, repository := range repositories.Values {
|
for _, repository := range repositories.Values {
|
||||||
@ -149,15 +148,17 @@ func (p *BitbucketProvider) GetEmailAddress(s *sessions.SessionState) (string, e
|
|||||||
}
|
}
|
||||||
if found != true {
|
if found != true {
|
||||||
logger.Print("repository access test failed, access denied")
|
logger.Print("repository access test failed, access denied")
|
||||||
return "", nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, email := range emails.Values {
|
for _, email := range emails.Values {
|
||||||
if email.Primary {
|
if email.Primary {
|
||||||
return email.Email, nil
|
return &UserDetails{
|
||||||
|
Email: email.Email,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
@ -112,7 +112,7 @@ func TestBitbucketProviderOverrides(t *testing.T) {
|
|||||||
assert.Equal(t, "profile", p.Data().Scope)
|
assert.Equal(t, "profile", p.Data().Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitbucketProviderGetEmailAddress(t *testing.T) {
|
func TestBitbucketProviderGetUserDetails(t *testing.T) {
|
||||||
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }")
|
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true } ] }")
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -120,12 +120,12 @@ func TestBitbucketProviderGetEmailAddress(t *testing.T) {
|
|||||||
p := testBitbucketProvider(bURL.Host, "", "")
|
p := testBitbucketProvider(bURL.Host, "", "")
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
|
func TestBitbucketProviderGetUserDetailsAndGroup(t *testing.T) {
|
||||||
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }")
|
b := testBitbucketBackend("{\"values\": [ { \"email\": \"michael.bland@gsa.gov\", \"is_primary\": true, \"username\": \"bioinformatics\" } ] }")
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -133,14 +133,14 @@ func TestBitbucketProviderGetEmailAddressAndGroup(t *testing.T) {
|
|||||||
p := testBitbucketProvider(bURL.Host, "bioinformatics", "")
|
p := testBitbucketProvider(bURL.Host, "bioinformatics", "")
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that trying to trigger the "failed building request" case is not
|
// Note that trying to trigger the "failed building request" case is not
|
||||||
// practical, since the only way it can fail is if the URL fails to parse.
|
// practical, since the only way it can fail is if the URL fails to parse.
|
||||||
func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
|
func TestBitbucketProviderGetUserDetailsFailedRequest(t *testing.T) {
|
||||||
b := testBitbucketBackend("unused payload")
|
b := testBitbucketBackend("unused payload")
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -151,12 +151,12 @@ func TestBitbucketProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
func TestBitbucketProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
|
||||||
b := testBitbucketBackend("{\"foo\": \"bar\"}")
|
b := testBitbucketBackend("{\"foo\": \"bar\"}")
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -164,7 +164,7 @@ func TestBitbucketProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T)
|
|||||||
p := testBitbucketProvider(bURL.Host, "", "")
|
p := testBitbucketProvider(bURL.Host, "", "")
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
}
|
}
|
||||||
|
@ -55,13 +55,13 @@ func getFacebookHeader(accessToken string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return nil, errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil)
|
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?fields=name,email", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header = getFacebookHeader(s.AccessToken)
|
req.Header = getFacebookHeader(s.AccessToken)
|
||||||
|
|
||||||
@ -71,12 +71,14 @@ func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, er
|
|||||||
var r result
|
var r result
|
||||||
err = requests.RequestJSON(req, &r)
|
err = requests.RequestJSON(req, &r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
if r.Email == "" {
|
if r.Email == "" {
|
||||||
return "", errors.New("no email")
|
return nil, errors.New("no email")
|
||||||
}
|
}
|
||||||
return r.Email, nil
|
return &UserDetails{
|
||||||
|
Email: r.Email,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
|
@ -201,8 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
|
|
||||||
var emails []struct {
|
var emails []struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
@ -213,11 +212,11 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
|
|||||||
if p.Org != "" {
|
if p.Org != "" {
|
||||||
if p.Team != "" {
|
if p.Team != "" {
|
||||||
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
|
if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
|
if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -231,32 +230,34 @@ func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, erro
|
|||||||
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
||||||
resp, err := http.DefaultClient.Do(req)
|
resp, err := http.DefaultClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
body, err := ioutil.ReadAll(resp.Body)
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return "", fmt.Errorf("got %d from %q %s",
|
return nil, fmt.Errorf("got %d from %q %s",
|
||||||
resp.StatusCode, endpoint.String(), body)
|
resp.StatusCode, endpoint.String(), body)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
logger.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &emails); err != nil {
|
if err := json.Unmarshal(body, &emails); err != nil {
|
||||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
return nil, fmt.Errorf("%s unmarshaling %s", err, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, email := range emails {
|
for _, email := range emails {
|
||||||
if email.Primary && email.Verified {
|
if email.Primary && email.Verified {
|
||||||
return email.Email, nil
|
return &UserDetails{
|
||||||
|
Email: email.Email,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account user name
|
// GetUserName returns the Account user name
|
||||||
|
@ -97,7 +97,7 @@ func TestGitHubProviderOverrides(t *testing.T) {
|
|||||||
assert.Equal(t, "profile", p.Data().Scope)
|
assert.Equal(t, "profile", p.Data().Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
func TestGitHubProviderGetUserDetails(t *testing.T) {
|
||||||
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`})
|
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": true, "primary": true} ]`})
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -105,12 +105,12 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
|||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
func TestGitHubProviderGetUserDetailsNotVerified(t *testing.T) {
|
||||||
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`})
|
b := testGitHubBackend([]string{`[ {"email": "michael.bland@gsa.gov", "verified": false, "primary": true} ]`})
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -118,12 +118,12 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
|||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Empty(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
func TestGitHubProviderGetUserDetailsWithOrg(t *testing.T) {
|
||||||
b := testGitHubBackend([]string{
|
b := testGitHubBackend([]string{
|
||||||
`[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`,
|
`[ {"email": "michael.bland@gsa.gov", "primary": true, "verified": true, "login":"testorg"} ]`,
|
||||||
`[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`,
|
`[ {"email": "michael.bland1@gsa.gov", "primary": true, "verified": true, "login":"testorg1"} ]`,
|
||||||
@ -136,14 +136,14 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
|||||||
p.Org = "testorg1"
|
p.Org = "testorg1"
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that trying to trigger the "failed building request" case is not
|
// Note that trying to trigger the "failed building request" case is not
|
||||||
// practical, since the only way it can fail is if the URL fails to parse.
|
// practical, since the only way it can fail is if the URL fails to parse.
|
||||||
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
func TestGitHubProviderGetUserDetailsFailedRequest(t *testing.T) {
|
||||||
b := testGitHubBackend([]string{"unused payload"})
|
b := testGitHubBackend([]string{"unused payload"})
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -154,12 +154,12 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
func TestGitHubProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
|
||||||
b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"})
|
b := testGitHubBackend([]string{"{\"foo\": \"bar\"}"})
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -167,9 +167,9 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitHubProviderGetUserName(t *testing.T) {
|
func TestGitHubProviderGetUserName(t *testing.T) {
|
||||||
|
@ -220,31 +220,32 @@ func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
// Retrieve user info
|
// Retrieve user info
|
||||||
userInfo, err := p.getUserInfo(s)
|
userInfo, err := p.getUserInfo(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to retrieve user info: %v", err)
|
return nil, fmt.Errorf("failed to retrieve user info: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if email is verified
|
// Check if email is verified
|
||||||
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
|
if !p.AllowUnverifiedEmail && !userInfo.EmailVerified {
|
||||||
return "", fmt.Errorf("user email is not verified")
|
return nil, fmt.Errorf("user email is not verified")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if email has valid domain
|
// Check if email has valid domain
|
||||||
err = p.verifyEmailDomain(userInfo)
|
err = p.verifyEmailDomain(userInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("email domain check failed: %v", err)
|
return nil, fmt.Errorf("email domain check failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check group membership
|
// Check group membership
|
||||||
err = p.verifyGroupMembership(userInfo)
|
err = p.verifyGroupMembership(userInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("group membership check failed: %v", err)
|
return nil, fmt.Errorf("group membership check failed: %v", err)
|
||||||
}
|
}
|
||||||
|
return &UserDetails{
|
||||||
return userInfo.Email, nil
|
Email: userInfo.Email,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account user name
|
// GetUserName returns the Account user name
|
||||||
|
@ -63,7 +63,7 @@ func TestGitLabProviderBadToken(t *testing.T) {
|
|||||||
p := testGitLabProvider(bURL.Host)
|
p := testGitLabProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_gitlab_access_token"}
|
||||||
_, err := p.GetEmailAddress(session)
|
_, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,7 +75,7 @@ func TestGitLabProviderUnverifiedEmailDenied(t *testing.T) {
|
|||||||
p := testGitLabProvider(bURL.Host)
|
p := testGitLabProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||||
_, err := p.GetEmailAddress(session)
|
_, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,9 +88,9 @@ func TestGitLabProviderUnverifiedEmailAllowed(t *testing.T) {
|
|||||||
p.AllowUnverifiedEmail = true
|
p.AllowUnverifiedEmail = true
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "foo@bar.com", email)
|
assert.Equal(t, "foo@bar.com", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitLabProviderUsername(t *testing.T) {
|
func TestGitLabProviderUsername(t *testing.T) {
|
||||||
@ -117,9 +117,9 @@ func TestGitLabProviderGroupMembershipValid(t *testing.T) {
|
|||||||
p.Group = "foo"
|
p.Group = "foo"
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "foo@bar.com", email)
|
assert.Equal(t, "foo@bar.com", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
|
func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
|
||||||
@ -132,7 +132,7 @@ func TestGitLabProviderGroupMembershipMissing(t *testing.T) {
|
|||||||
p.Group = "baz"
|
p.Group = "baz"
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||||
_, err := p.GetEmailAddress(session)
|
_, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,9 +146,9 @@ func TestGitLabProviderEmailDomainValid(t *testing.T) {
|
|||||||
p.EmailDomains = []string{"bar.com"}
|
p.EmailDomains = []string{"bar.com"}
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "foo@bar.com", email)
|
assert.Equal(t, "foo@bar.com", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
|
func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
|
||||||
@ -161,6 +161,6 @@ func TestGitLabProviderEmailDomainInvalid(t *testing.T) {
|
|||||||
p.EmailDomains = []string{"baz.com"}
|
p.EmailDomains = []string{"baz.com"}
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
session := &sessions.SessionState{AccessToken: "gitlab_access_token"}
|
||||||
_, err := p.GetEmailAddress(session)
|
_, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
@ -233,6 +233,10 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
|
|||||||
return p.GroupValidator(email)
|
return p.GroupValidator(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *GoogleProvider) ValidateGroupWithSession(s *sessions.SessionState) bool {
|
||||||
|
return p.ValidateGroup(s.Email)
|
||||||
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
|
@ -9,6 +9,20 @@ import (
|
|||||||
"github.com/pusher/oauth2_proxy/pkg/requests"
|
"github.com/pusher/oauth2_proxy/pkg/requests"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type NotImplementedError struct {
|
||||||
|
message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewNotImplementedError(message string) *NotImplementedError {
|
||||||
|
return &NotImplementedError{
|
||||||
|
message: "Not implemented: " + message,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NotImplementedError) Error() string {
|
||||||
|
return e.message
|
||||||
|
}
|
||||||
|
|
||||||
// stripToken is a helper function to obfuscate "access_token"
|
// stripToken is a helper function to obfuscate "access_token"
|
||||||
// query parameters
|
// query parameters
|
||||||
func stripToken(endpoint string) string {
|
func stripToken(endpoint string) string {
|
||||||
|
@ -51,26 +51,28 @@ func getLinkedInHeader(accessToken string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return nil, errors.New("missing access token")
|
||||||
}
|
}
|
||||||
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil)
|
req, err := http.NewRequest("GET", p.ProfileURL.String()+"?format=json", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header = getLinkedInHeader(s.AccessToken)
|
req.Header = getLinkedInHeader(s.AccessToken)
|
||||||
|
|
||||||
json, err := requests.Request(req)
|
json, err := requests.Request(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
email, err := json.String()
|
email, err := json.String()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
return email, nil
|
return &UserDetails{
|
||||||
|
Email: email,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
|
@ -91,7 +91,7 @@ func TestLinkedInProviderOverrides(t *testing.T) {
|
|||||||
assert.Equal(t, "profile", p.Data().Scope)
|
assert.Equal(t, "profile", p.Data().Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
func TestLinkedInProviderGetUserDetails(t *testing.T) {
|
||||||
b := testLinkedInBackend(`"user@linkedin.com"`)
|
b := testLinkedInBackend(`"user@linkedin.com"`)
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -99,12 +99,12 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
|||||||
p := testLinkedInProvider(bURL.Host)
|
p := testLinkedInProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@linkedin.com", email)
|
assert.Equal(t, "user@linkedin.com", details.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
func TestLinkedInProviderGetUserDetailsFailedRequest(t *testing.T) {
|
||||||
b := testLinkedInBackend("unused payload")
|
b := testLinkedInBackend("unused payload")
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -115,12 +115,12 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
func TestLinkedInProviderGetUserDetailsEmailNotPresentInPayload(t *testing.T) {
|
||||||
b := testLinkedInBackend("{\"foo\": \"bar\"}")
|
b := testLinkedInBackend("{\"foo\": \"bar\"}")
|
||||||
defer b.Close()
|
defer b.Close()
|
||||||
|
|
||||||
@ -128,7 +128,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
p := testLinkedInProvider(bURL.Host)
|
p := testLinkedInProvider(bURL.Host)
|
||||||
|
|
||||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
details, err := p.GetUserDetails(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Nil(t, details)
|
||||||
}
|
}
|
||||||
|
@ -110,14 +110,8 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error)
|
|||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
|
||||||
userDetails := map[string]string{}
|
return nil, NewNotImplementedError("")
|
||||||
email, err := p.GetEmailAddress(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
userDetails["email"] = email
|
|
||||||
return userDetails, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account username
|
// GetUserName returns the Account username
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
type Provider interface {
|
type Provider interface {
|
||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
GetEmailAddress(*sessions.SessionState) (string, error)
|
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||||
GetUserDetails(*sessions.SessionState) (map[string]string, error)
|
GetUserDetails(*sessions.SessionState) (*UserDetails, error)
|
||||||
GetUserName(*sessions.SessionState) (string, error)
|
GetUserName(*sessions.SessionState) (string, error)
|
||||||
GetGroups(*sessions.SessionState, string) (map[string]string, error)
|
GetGroups(*sessions.SessionState, string) (map[string]string, error)
|
||||||
Redeem(string, string) (*sessions.SessionState, error)
|
Redeem(string, string) (*sessions.SessionState, error)
|
||||||
@ -23,6 +23,11 @@ type Provider interface {
|
|||||||
CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error)
|
CookieForSession(*sessions.SessionState, *encryption.Cipher) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserDetails struct {
|
||||||
|
Email string
|
||||||
|
UID string
|
||||||
|
}
|
||||||
|
|
||||||
// New provides a new Provider based on the configured provider string
|
// New provides a new Provider based on the configured provider string
|
||||||
func New(provider string, p *ProviderData) Provider {
|
func New(provider string, p *ProviderData) Provider {
|
||||||
switch provider {
|
switch provider {
|
||||||
|
Loading…
Reference in New Issue
Block a user