Merge pull request #466 from clobrano/github-use-login-as-user
GitHub use login as user
This commit is contained in:
commit
b0c1c85177
@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(s)
|
||||
}
|
||||
|
||||
if s.User == "" {
|
||||
s.User, err = p.provider.GetUserName(s)
|
||||
if err != nil && err.Error() == "not implemented" {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
} else {
|
||||
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &emails); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
||||
var user struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
endpoint := &url.URL{
|
||||
Scheme: p.ValidateURL.Scheme,
|
||||
Host: p.ValidateURL.Host,
|
||||
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", endpoint.String(), nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not create new GET request: %v", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
defer resp.Body.Close()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
return "", fmt.Errorf("got %d from %q %s",
|
||||
resp.StatusCode, endpoint.String(), body)
|
||||
}
|
||||
|
||||
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||
|
||||
if err := json.Unmarshal(body, &user); err != nil {
|
||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||
}
|
||||
|
||||
return user.Login, nil
|
||||
}
|
||||
|
146
providers/github_test.go
Normal file
146
providers/github_test.go
Normal file
@ -0,0 +1,146 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testGitHubProvider(hostname string) *GitHubProvider {
|
||||
p := NewGitHubProvider(
|
||||
&ProviderData{
|
||||
ProviderName: "",
|
||||
LoginURL: &url.URL{},
|
||||
RedeemURL: &url.URL{},
|
||||
ProfileURL: &url.URL{},
|
||||
ValidateURL: &url.URL{},
|
||||
Scope: ""})
|
||||
if hostname != "" {
|
||||
updateURL(p.Data().LoginURL, hostname)
|
||||
updateURL(p.Data().RedeemURL, hostname)
|
||||
updateURL(p.Data().ProfileURL, hostname)
|
||||
updateURL(p.Data().ValidateURL, hostname)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func testGitHubBackend(payload string) *httptest.Server {
|
||||
pathToQueryMap := map[string]string{
|
||||
"/user": "",
|
||||
"/user/emails": "",
|
||||
}
|
||||
|
||||
return httptest.NewServer(http.HandlerFunc(
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
url := r.URL
|
||||
query, ok := pathToQueryMap[url.Path]
|
||||
if !ok {
|
||||
w.WriteHeader(404)
|
||||
} else if url.RawQuery != query {
|
||||
w.WriteHeader(404)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte(payload))
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func TestGitHubProviderDefaults(t *testing.T) {
|
||||
p := testGitHubProvider("")
|
||||
assert.NotEqual(t, nil, p)
|
||||
assert.Equal(t, "GitHub", p.Data().ProviderName)
|
||||
assert.Equal(t, "https://github.com/login/oauth/authorize",
|
||||
p.Data().LoginURL.String())
|
||||
assert.Equal(t, "https://github.com/login/oauth/access_token",
|
||||
p.Data().RedeemURL.String())
|
||||
assert.Equal(t, "https://api.github.com/",
|
||||
p.Data().ValidateURL.String())
|
||||
assert.Equal(t, "user:email", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestGitHubProviderOverrides(t *testing.T) {
|
||||
p := NewGitHubProvider(
|
||||
&ProviderData{
|
||||
LoginURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
Path: "/login/oauth/authorize"},
|
||||
RedeemURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "example.com",
|
||||
Path: "/login/oauth/access_token"},
|
||||
ValidateURL: &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "api.example.com",
|
||||
Path: "/"},
|
||||
Scope: "profile"})
|
||||
assert.NotEqual(t, nil, p)
|
||||
assert.Equal(t, "GitHub", p.Data().ProviderName)
|
||||
assert.Equal(t, "https://example.com/login/oauth/authorize",
|
||||
p.Data().LoginURL.String())
|
||||
assert.Equal(t, "https://example.com/login/oauth/access_token",
|
||||
p.Data().RedeemURL.String())
|
||||
assert.Equal(t, "https://api.example.com/",
|
||||
p.Data().ValidateURL.String())
|
||||
assert.Equal(t, "profile", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
||||
b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`)
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
}
|
||||
|
||||
// Note that trying to trigger the "failed building request" case is not
|
||||
// practical, since the only way it can fail is if the URL fails to parse.
|
||||
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
b := testGitHubBackend("unused payload")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
b := testGitHubBackend("{\"foo\": \"bar\"}")
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
}
|
||||
|
||||
func TestGitHubProviderGetUserName(t *testing.T) {
|
||||
b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`)
|
||||
defer b.Close()
|
||||
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetUserName(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "mbland", email)
|
||||
}
|
@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// GetUserName returns the Account username
|
||||
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ValidateGroup validates that the provided email exists in the configured provider
|
||||
// email group(s).
|
||||
func (p *ProviderData) ValidateGroup(email string) bool {
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetEmailAddress(*SessionState) (string, error)
|
||||
GetUserName(*SessionState) (string, error)
|
||||
Redeem(string, string) (*SessionState, error)
|
||||
ValidateGroup(string) bool
|
||||
ValidateSessionState(*SessionState) bool
|
||||
|
@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool {
|
||||
}
|
||||
|
||||
func (s *SessionState) String() string {
|
||||
o := fmt.Sprintf("Session{%s", s.userOrEmail())
|
||||
o := fmt.Sprintf("Session{%s", s.accountInfo())
|
||||
if s.AccessToken != "" {
|
||||
o += " token:true"
|
||||
}
|
||||
@ -40,17 +40,13 @@ func (s *SessionState) String() string {
|
||||
|
||||
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
||||
if c == nil || s.AccessToken == "" {
|
||||
return s.userOrEmail(), nil
|
||||
return s.accountInfo(), nil
|
||||
}
|
||||
return s.EncryptedString(c)
|
||||
}
|
||||
|
||||
func (s *SessionState) userOrEmail() string {
|
||||
u := s.User
|
||||
if s.Email != "" {
|
||||
u = s.Email
|
||||
}
|
||||
return u
|
||||
func (s *SessionState) accountInfo() string {
|
||||
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
||||
}
|
||||
|
||||
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
||||
@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
||||
}
|
||||
a := s.AccessToken
|
||||
if a != "" {
|
||||
a, err = c.Encrypt(a)
|
||||
if err != nil {
|
||||
if a, err = c.Encrypt(a); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
r := s.RefreshToken
|
||||
if r != "" {
|
||||
r, err = c.Encrypt(r)
|
||||
if err != nil {
|
||||
if r, err = c.Encrypt(r); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil
|
||||
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
|
||||
}
|
||||
|
||||
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
|
||||
chunks := strings.Split(v, " ")
|
||||
if len(chunks) != 2 {
|
||||
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
|
||||
}
|
||||
|
||||
email := strings.TrimPrefix(chunks[0], "email:")
|
||||
user := strings.TrimPrefix(chunks[1], "user:")
|
||||
if user == "" {
|
||||
user = strings.Split(email, "@")[0]
|
||||
}
|
||||
|
||||
return &SessionState{User: user, Email: email}, nil
|
||||
}
|
||||
|
||||
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
||||
chunks := strings.Split(v, "|")
|
||||
if len(chunks) == 1 {
|
||||
if strings.Contains(chunks[0], "@") {
|
||||
u := strings.Split(v, "@")[0]
|
||||
return &SessionState{Email: v, User: u}, nil
|
||||
}
|
||||
return &SessionState{User: v}, nil
|
||||
if c == nil {
|
||||
return decodeSessionStatePlain(v)
|
||||
}
|
||||
|
||||
chunks := strings.Split(v, "|")
|
||||
if len(chunks) != 4 {
|
||||
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
|
||||
return
|
||||
}
|
||||
|
||||
s = &SessionState{}
|
||||
if c != nil && chunks[1] != "" {
|
||||
s.AccessToken, err = c.Decrypt(chunks[1])
|
||||
sessionState, err := decodeSessionStatePlain(chunks[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if c != nil && chunks[3] != "" {
|
||||
s.RefreshToken, err = c.Decrypt(chunks[3])
|
||||
if err != nil {
|
||||
|
||||
if chunks[1] != "" {
|
||||
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if u := chunks[0]; strings.Contains(u, "@") {
|
||||
s.Email = u
|
||||
s.User = strings.Split(u, "@")[0]
|
||||
} else {
|
||||
s.User = u
|
||||
}
|
||||
|
||||
ts, _ := strconv.Atoi(chunks[2])
|
||||
s.ExpiresOn = time.Unix(int64(ts), 0)
|
||||
return
|
||||
sessionState.ExpiresOn = time.Unix(int64(ts), 0)
|
||||
|
||||
if chunks[3] != "" {
|
||||
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return sessionState, nil
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -30,6 +31,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
@ -39,6 +41,43 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
c, err := cookie.NewCipher([]byte(secret))
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, 3, strings.Count(encoded, "|"))
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||
@ -46,7 +85,6 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
@ -55,25 +93,51 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.Email, encoded)
|
||||
expected := fmt.Sprintf("email:%s user:", s.Email)
|
||||
assert.Equal(t, expected, encoded)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, "", ss.AccessToken)
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
}
|
||||
|
||||
func TestSessionStateUserOrEmail(t *testing.T) {
|
||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
s := &SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
RefreshToken: "refresh4321",
|
||||
}
|
||||
encoded, err := s.EncodeSessionState(nil)
|
||||
assert.Equal(t, nil, err)
|
||||
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
||||
assert.Equal(t, expected, encoded)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
assert.Equal(t, "", ss.AccessToken)
|
||||
assert.Equal(t, "", ss.RefreshToken)
|
||||
}
|
||||
|
||||
func TestSessionStateAccountInfo(t *testing.T) {
|
||||
s := &SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
}
|
||||
assert.Equal(t, "user@domain.com", s.userOrEmail())
|
||||
expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
||||
assert.Equal(t, expected, s.accountInfo())
|
||||
|
||||
s.Email = ""
|
||||
assert.Equal(t, "just-user", s.userOrEmail())
|
||||
expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
||||
assert.Equal(t, expected, s.accountInfo())
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user