Github provider: use login as user

- Save both user and email in session state:
    Encoding/decoding methods save both email and user
    field in session state, for use cases when User is not derived from
    email's local-parth, like for GitHub provider.

    For retrocompatibility, if no user is obtained by the provider,
    (e.g. User is an empty string) the encoding/decoding methods fall back
    to the previous behavior and use the email's local-part

    Updated also related tests and added two more tests to show behavior
    when session contains a non-empty user value.

- Added first basic GitHub provider tests

- Added GetUserName method to Provider interface
    The new GetUserName method is intended to return the User
    value when this is not the email's local-part.

    Added also the default implementation to provider_default.go

- Added call to GetUserName in redeemCode

    the new GetUserName method is used in redeemCode
    to get SessionState User value.

    For backward compatibility, if GetUserName error is
    "not implemented", the error is ignored.

- Added GetUserName method and tests to github provider.
This commit is contained in:
Carlo Lobrano 2017-09-26 23:31:27 +02:00
parent 6ddbb2c572
commit 731fa9f8e0
7 changed files with 315 additions and 45 deletions

View File

@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
if s.Email == "" { if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s) s.Email, err = p.provider.GetEmailAddress(s)
} }
if s.User == "" {
s.User, err = p.provider.GetUserName(s)
if err != nil && err.Error() == "not implemented" {
err = nil
}
}
return return
} }

View File

@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s", return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body) resp.StatusCode, endpoint.String(), body)
} else {
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
} }
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &emails); err != nil { if err := json.Unmarshal(body, &emails); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body) return "", fmt.Errorf("%s unmarshaling %s", err, body)
} }
@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
return "", nil return "", nil
} }
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
var user struct {
Login string `json:"login"`
Email string `json:"email"`
}
endpoint := &url.URL{
Scheme: p.ValidateURL.Scheme,
Host: p.ValidateURL.Host,
Path: path.Join(p.ValidateURL.Path, "/user"),
}
req, err := http.NewRequest("GET", endpoint.String(), nil)
if err != nil {
return "", fmt.Errorf("could not create new GET request: %v", err)
}
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", err
}
body, err := ioutil.ReadAll(resp.Body)
defer resp.Body.Close()
if err != nil {
return "", err
}
if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s",
resp.StatusCode, endpoint.String(), body)
}
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
if err := json.Unmarshal(body, &user); err != nil {
return "", fmt.Errorf("%s unmarshaling %s", err, body)
}
return user.Login, nil
}

146
providers/github_test.go Normal file
View File

@ -0,0 +1,146 @@
package providers
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
func testGitHubProvider(hostname string) *GitHubProvider {
p := NewGitHubProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProfileURL: &url.URL{},
ValidateURL: &url.URL{},
Scope: ""})
if hostname != "" {
updateURL(p.Data().LoginURL, hostname)
updateURL(p.Data().RedeemURL, hostname)
updateURL(p.Data().ProfileURL, hostname)
updateURL(p.Data().ValidateURL, hostname)
}
return p
}
func testGitHubBackend(payload string) *httptest.Server {
pathToQueryMap := map[string]string{
"/user": "",
"/user/emails": "",
}
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
query, ok := pathToQueryMap[url.Path]
if !ok {
w.WriteHeader(404)
} else if url.RawQuery != query {
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}
func TestGitHubProviderDefaults(t *testing.T) {
p := testGitHubProvider("")
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://github.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://github.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.github.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "user:email", p.Data().Scope)
}
func TestGitHubProviderOverrides(t *testing.T) {
p := NewGitHubProvider(
&ProviderData{
LoginURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/authorize"},
RedeemURL: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/login/oauth/access_token"},
ValidateURL: &url.URL{
Scheme: "https",
Host: "api.example.com",
Path: "/"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "GitHub", p.Data().ProviderName)
assert.Equal(t, "https://example.com/login/oauth/authorize",
p.Data().LoginURL.String())
assert.Equal(t, "https://example.com/login/oauth/access_token",
p.Data().RedeemURL.String())
assert.Equal(t, "https://api.example.com/",
p.Data().ValidateURL.String())
assert.Equal(t, "profile", p.Data().Scope)
}
func TestGitHubProviderGetEmailAddress(t *testing.T) {
b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`)
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
}
// Note that trying to trigger the "failed building request" case is not
// practical, since the only way it can fail is if the URL fails to parse.
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
b := testGitHubBackend("unused payload")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b := testGitHubBackend("{\"foo\": \"bar\"}")
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
}
func TestGitHubProviderGetUserName(t *testing.T) {
b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`)
defer b.Close()
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session)
assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email)
}

View File

@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}
// ValidateGroup validates that the provided email exists in the configured provider // ValidateGroup validates that the provided email exists in the configured provider
// email group(s). // email group(s).
func (p *ProviderData) ValidateGroup(email string) bool { func (p *ProviderData) ValidateGroup(email string) bool {

View File

@ -7,6 +7,7 @@ import (
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*SessionState) (string, error) GetEmailAddress(*SessionState) (string, error)
GetUserName(*SessionState) (string, error)
Redeem(string, string) (*SessionState, error) Redeem(string, string) (*SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool ValidateSessionState(*SessionState) bool

View File

@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool {
} }
func (s *SessionState) String() string { func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.userOrEmail()) o := fmt.Sprintf("Session{%s", s.accountInfo())
if s.AccessToken != "" { if s.AccessToken != "" {
o += " token:true" o += " token:true"
} }
@ -40,17 +40,13 @@ func (s *SessionState) String() string {
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" { if c == nil || s.AccessToken == "" {
return s.userOrEmail(), nil return s.accountInfo(), nil
} }
return s.EncryptedString(c) return s.EncryptedString(c)
} }
func (s *SessionState) userOrEmail() string { func (s *SessionState) accountInfo() string {
u := s.User return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
if s.Email != "" {
u = s.Email
}
return u
} }
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
} }
a := s.AccessToken a := s.AccessToken
if a != "" { if a != "" {
a, err = c.Encrypt(a) if a, err = c.Encrypt(a); err != nil {
if err != nil {
return "", err return "", err
} }
} }
r := s.RefreshToken r := s.RefreshToken
if r != "" { if r != "" {
r, err = c.Encrypt(r) if r, err = c.Encrypt(r); err != nil {
if err != nil {
return "", err return "", err
} }
} }
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
}
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
chunks := strings.Split(v, " ")
if len(chunks) != 2 {
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
}
email := strings.TrimPrefix(chunks[0], "email:")
user := strings.TrimPrefix(chunks[1], "user:")
if user == "" {
user = strings.Split(email, "@")[0]
}
return &SessionState{User: user, Email: email}, nil
} }
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
chunks := strings.Split(v, "|") if c == nil {
if len(chunks) == 1 { return decodeSessionStatePlain(v)
if strings.Contains(chunks[0], "@") {
u := strings.Split(v, "@")[0]
return &SessionState{Email: v, User: u}, nil
}
return &SessionState{User: v}, nil
} }
chunks := strings.Split(v, "|")
if len(chunks) != 4 { if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
return return
} }
s = &SessionState{} sessionState, err := decodeSessionStatePlain(chunks[0])
if c != nil && chunks[1] != "" { if err != nil {
s.AccessToken, err = c.Decrypt(chunks[1]) return nil, err
if err != nil { }
if chunks[1] != "" {
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
return nil, err return nil, err
} }
} }
if c != nil && chunks[3] != "" {
s.RefreshToken, err = c.Decrypt(chunks[3])
if err != nil {
return nil, err
}
}
if u := chunks[0]; strings.Contains(u, "@") {
s.Email = u
s.User = strings.Split(u, "@")[0]
} else {
s.User = u
}
ts, _ := strconv.Atoi(chunks[2]) ts, _ := strconv.Atoi(chunks[2])
s.ExpiresOn = time.Unix(int64(ts), 0) sessionState.ExpiresOn = time.Unix(int64(ts), 0)
return
if chunks[3] != "" {
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
return nil, err
}
}
return sessionState, nil
} }

View File

@ -1,6 +1,7 @@
package providers package providers
import ( import (
"fmt"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -30,6 +31,7 @@ func TestSessionStateSerialization(t *testing.T) {
ss, err := DecodeSessionState(encoded, c) ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
@ -39,6 +41,43 @@ func TestSessionStateSerialization(t *testing.T) {
ss, err = DecodeSessionState(encoded, c2) ss, err = DecodeSessionState(encoded, c2)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
func TestSessionStateSerializationWithUser(t *testing.T) {
c, err := cookie.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err)
s := &SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken) assert.NotEqual(t, s.AccessToken, ss.AccessToken)
@ -46,7 +85,6 @@ func TestSessionStateSerialization(t *testing.T) {
} }
func TestSessionStateSerializationNoCipher(t *testing.T) { func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{ s := &SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
@ -55,25 +93,51 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
} }
encoded, err := s.EncodeSessionState(nil) encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, s.Email, encoded) expected := fmt.Sprintf("email:%s user:", s.Email)
assert.Equal(t, expected, encoded)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken) assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken) assert.Equal(t, "", ss.RefreshToken)
} }
func TestSessionStateUserOrEmail(t *testing.T) { func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
s := &SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err)
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
assert.Equal(t, expected, encoded)
// only email should have been serialized
ss, err := DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken)
}
func TestSessionStateAccountInfo(t *testing.T) {
s := &SessionState{ s := &SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
} }
assert.Equal(t, "user@domain.com", s.userOrEmail()) expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
s.Email = "" s.Email = ""
assert.Equal(t, "just-user", s.userOrEmail()) expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
assert.Equal(t, expected, s.accountInfo())
} }
func TestExpired(t *testing.T) { func TestExpired(t *testing.T) {