Merge pull request #466 from clobrano/github-use-login-as-user

GitHub use login as user
This commit is contained in:
Heather Hendy 2017-11-20 12:48:14 -07:00 committed by GitHub
commit b0c1c85177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 315 additions and 45 deletions

View File

@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
if s.Email == "" { if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s) s.Email, err = p.provider.GetEmailAddress(s)
} }
if s.User == "" {
s.User, err = p.provider.GetUserName(s)
if err != nil && err.Error() == "not implemented" {
err = nil
}
}
return return
} }

View File

@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s", return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body) resp.StatusCode, endpoint.String(), body)
} else {
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
} }
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil { if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body) return "", fmt.Errorf("%s unmarshaling %s", err, body)
} }
@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
return "", nil return "", nil
} }
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
var user struct {
Login string `json:"login"`
Email string `json:"email"`
}
endpoint := &url.URL{
Scheme: p.ValidateURL.Scheme,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, err := http.NewRequest("GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
return user.Login, nil
}

146
providers/github_test.go Normal file
View File

@ -0,0 +1,146 @@
package providers
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func testGitHubProvider(hostname string) *GitHubProvider {
p := NewGitHubProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""})
if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
}
func testGitHubBackend(payload string) *httptest.Server {
pathToQueryMap := map[string]string{
"/user": "",
"/user/emails": "",
}
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
query, ok := pathToQueryMap[url.Path]
if !ok {
w.WriteHeader(404)
} else if url.RawQuery != query {
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}
func TestGitHubProviderDefaults(t *testing.T) {
p := testGitHubProvider("")
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://github.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://github.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.github.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "user:email", p.Data().Scope)
}
func TestGitHubProviderOverrides(t *testing.T) {
p := NewGitHubProvider(
&ProviderData{
LoginURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/authorize"},
RedeemURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/access_token"},
ValidateURL: &url.URL{
Scheme: "https",
Host: "api.example.com",
Path: "/"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://example.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.example.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}
func TestGitHubProviderGetEmailAddress(t *testing.T) {
b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`)
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitHubBackend("unused payload")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend("{\"foo\": \"bar\"}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestGitHubProviderGetUserName(t *testing.T) {
b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`)
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session)
assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email)
}

View File

@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}
// ValidateGroup validates that the provided email exists in the configured provider // ValidateGroup validates that the provided email exists in the configured provider
// email group(s). // email group(s).
func (p *ProviderData) ValidateGroup(email string) bool { func (p *ProviderData) ValidateGroup(email string) bool {

View File

@ -7,6 +7,7 @@ import (
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*SessionState) (string, error) GetEmailAddress(*SessionState) (string, error)
GetUserName(*SessionState) (string, error)
Redeem(string, string) (*SessionState, error) Redeem(string, string) (*SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool ValidateSessionState(*SessionState) bool

View File

@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool {
} }
func (s *SessionState) String() string { func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.userOrEmail()) o := fmt.Sprintf("Session{%s", s.accountInfo())
if s.AccessToken != "" { if s.AccessToken != "" {
o += " token:true" o += " token:true"
} }
@ -40,17 +40,13 @@ func (s *SessionState) String() string {
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" { if c == nil || s.AccessToken == "" {
return s.userOrEmail(), nil return s.accountInfo(), nil
} }
return s.EncryptedString(c) return s.EncryptedString(c)
} }
func (s *SessionState) userOrEmail() string { func (s *SessionState) accountInfo() string {
u := s.User return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
if s.Email != "" {
u = s.Email
}
return u
} }
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
} }
a := s.AccessToken a := s.AccessToken
if a != "" { if a != "" {
a, err = c.Encrypt(a) if a, err = c.Encrypt(a); err != nil {
if err != nil {
return "", err return "", err
} }
} }
r := s.RefreshToken r := s.RefreshToken
if r != "" { if r != "" {
r, err = c.Encrypt(r) if r, err = c.Encrypt(r); err != nil {
if err != nil {
return "", err return "", err
} }
} }
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
}
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
chunks := strings.Split(v, " ")
if len(chunks) != 2 {
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
}
email := strings.TrimPrefix(chunks[0], "email:")
user := strings.TrimPrefix(chunks[1], "user:")
if user == "" {
user = strings.Split(email, "@")[0]
}
return &SessionState{User: user, Email: email}, nil
} }
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
chunks := strings.Split(v, "|") if c == nil {
if len(chunks) == 1 { return decodeSessionStatePlain(v)
if strings.Contains(chunks[0], "@") {
u := strings.Split(v, "@")[0]
return &SessionState{Email: v, User: u}, nil
}
return &SessionState{User: v}, nil
} }
chunks := strings.Split(v, "|")
if len(chunks) != 4 { if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
return return
} }
s = &SessionState{} sessionState, err := decodeSessionStatePlain(chunks[0])
if c != nil && chunks[1] != "" { if err != nil {
s.AccessToken, err = c.Decrypt(chunks[1]) return nil, err
if err != nil { }
if chunks[1] != "" {
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
return nil, err return nil, err
} }
} }
if c != nil && chunks[3] != "" {
s.RefreshToken, err = c.Decrypt(chunks[3])
if err != nil {
return nil, err
}
}
if u := chunks[0]; strings.Contains(u, "@") {
s.Email = u
s.User = strings.Split(u, "@")[0]
} else {
s.User = u
}
ts, _ := strconv.Atoi(chunks[2]) ts, _ := strconv.Atoi(chunks[2])
s.ExpiresOn = time.Unix(int64(ts), 0) sessionState.ExpiresOn = time.Unix(int64(ts), 0)
return
if chunks[3] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
return nil, err
}
}
return sessionState, nil
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"fmt"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -30,6 +31,7 @@ func TestSessionStateSerialization(t *testing.T) {
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
@ -39,6 +41,43 @@ func TestSessionStateSerialization(t *testing.T) {
ss, err = DecodeSessionState(encoded, c2) ss, err = DecodeSessionState(encoded, c2)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
func TestSessionStateSerializationWithUser(t *testing.T) {
c, err := cookie.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err)
s := &SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.AccessToken, ss.AccessToken)
@ -46,7 +85,6 @@ func TestSessionStateSerialization(t *testing.T) {
} }
func TestSessionStateSerializationNoCipher(t *testing.T) { func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{ s := &SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
@ -55,25 +93,51 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(nil) encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, s.Email, encoded) expected := fmt.Sprintf("email:%s user:", s.Email)
assert.Equal(t, expected, encoded)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken) assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken) assert.Equal(t, "", ss.RefreshToken)
} }
func TestSessionStateUserOrEmail(t *testing.T) { func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
s := &SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err)
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
assert.Equal(t, expected, encoded)
// only email should have been serialized
ss, err := DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken)
}
func TestSessionStateAccountInfo(t *testing.T) {
s := &SessionState{ s := &SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
} }
assert.Equal(t, "user@domain.com", s.userOrEmail()) expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
s.Email = "" s.Email = ""
assert.Equal(t, "just-user", s.userOrEmail()) expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
} }
func TestExpired(t *testing.T) { func TestExpired(t *testing.T) {