From 731fa9f8e00b294ff1bf4a687e63d2ecdd9a4a50 Mon Sep 17 00:00:00 2001 From: Carlo Lobrano Date: Tue, 26 Sep 2017 23:31:27 +0200 Subject: [PATCH] Github provider: use login as user - Save both user and email in session state: Encoding/decoding methods save both email and user field in session state, for use cases when User is not derived from email's local-parth, like for GitHub provider. For retrocompatibility, if no user is obtained by the provider, (e.g. User is an empty string) the encoding/decoding methods fall back to the previous behavior and use the email's local-part Updated also related tests and added two more tests to show behavior when session contains a non-empty user value. - Added first basic GitHub provider tests - Added GetUserName method to Provider interface The new GetUserName method is intended to return the User value when this is not the email's local-part. Added also the default implementation to provider_default.go - Added call to GetUserName in redeemCode the new GetUserName method is used in redeemCode to get SessionState User value. For backward compatibility, if GetUserName error is "not implemented", the error is ignored. - Added GetUserName method and tests to github provider. --- oauthproxy.go | 7 ++ providers/github.go | 47 +++++++++- providers/github_test.go | 146 ++++++++++++++++++++++++++++++++ providers/provider_default.go | 5 ++ providers/providers.go | 1 + providers/session_state.go | 80 ++++++++--------- providers/session_state_test.go | 74 ++++++++++++++-- 7 files changed, 315 insertions(+), 45 deletions(-) create mode 100644 providers/github_test.go diff --git a/oauthproxy.go b/oauthproxy.go index bc69369..f94aa6e 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e if s.Email == "" { s.Email, err = p.provider.GetEmailAddress(s) } + + if s.User == "" { + s.User, err = p.provider.GetUserName(s) + if err != nil && err.Error() == "not implemented" { + err = nil + } + } return } diff --git a/providers/github.go b/providers/github.go index 512eed8..f3af86f 100644 --- a/providers/github.go +++ b/providers/github.go @@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { if resp.StatusCode != 200 { return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } else { - log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) } + log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + if err := json.Unmarshal(body, &emails); err != nil { return "", fmt.Errorf("%s unmarshaling %s", err, body) } @@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { return "", nil } + +func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { + var user struct { + Login string `json:"login"` + Email string `json:"email"` + } + + endpoint := &url.URL{ + Scheme: p.ValidateURL.Scheme, + Host: p.ValidateURL.Host, + Path: path.Join(p.ValidateURL.Path, "/user"), + } + + req, err := http.NewRequest("GET", endpoint.String(), nil) + if err != nil { + return "", fmt.Errorf("could not create new GET request: %v", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return "", err + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("got %d from %q %s", + resp.StatusCode, endpoint.String(), body) + } + + log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + + if err := json.Unmarshal(body, &user); err != nil { + return "", fmt.Errorf("%s unmarshaling %s", err, body) + } + + return user.Login, nil +} diff --git a/providers/github_test.go b/providers/github_test.go new file mode 100644 index 0000000..8080525 --- /dev/null +++ b/providers/github_test.go @@ -0,0 +1,146 @@ +package providers + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testGitHubProvider(hostname string) *GitHubProvider { + p := NewGitHubProvider( + &ProviderData{ + ProviderName: "", + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, + Scope: ""}) + if hostname != "" { + updateURL(p.Data().LoginURL, hostname) + updateURL(p.Data().RedeemURL, hostname) + updateURL(p.Data().ProfileURL, hostname) + updateURL(p.Data().ValidateURL, hostname) + } + return p +} + +func testGitHubBackend(payload string) *httptest.Server { + pathToQueryMap := map[string]string{ + "/user": "", + "/user/emails": "", + } + + return httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + url := r.URL + query, ok := pathToQueryMap[url.Path] + if !ok { + w.WriteHeader(404) + } else if url.RawQuery != query { + w.WriteHeader(404) + } else { + w.WriteHeader(200) + w.Write([]byte(payload)) + } + })) +} + +func TestGitHubProviderDefaults(t *testing.T) { + p := testGitHubProvider("") + assert.NotEqual(t, nil, p) + assert.Equal(t, "GitHub", p.Data().ProviderName) + assert.Equal(t, "https://github.com/login/oauth/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://github.com/login/oauth/access_token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://api.github.com/", + p.Data().ValidateURL.String()) + assert.Equal(t, "user:email", p.Data().Scope) +} + +func TestGitHubProviderOverrides(t *testing.T) { + p := NewGitHubProvider( + &ProviderData{ + LoginURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/login/oauth/authorize"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/login/oauth/access_token"}, + ValidateURL: &url.URL{ + Scheme: "https", + Host: "api.example.com", + Path: "/"}, + Scope: "profile"}) + assert.NotEqual(t, nil, p) + assert.Equal(t, "GitHub", p.Data().ProviderName) + assert.Equal(t, "https://example.com/login/oauth/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://example.com/login/oauth/access_token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://api.example.com/", + p.Data().ValidateURL.String()) + assert.Equal(t, "profile", p.Data().Scope) +} + +func TestGitHubProviderGetEmailAddress(t *testing.T) { + b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`) + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", email) +} + +// Note that trying to trigger the "failed building request" case is not +// practical, since the only way it can fail is if the URL fails to parse. +func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { + b := testGitHubBackend("unused payload") + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + // We'll trigger a request failure by using an unexpected access + // token. Alternatively, we could allow the parsing of the payload as + // JSON to fail. + session := &SessionState{AccessToken: "unexpected_access_token"} + email, err := p.GetEmailAddress(session) + assert.NotEqual(t, nil, err) + assert.Equal(t, "", email) +} + +func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { + b := testGitHubBackend("{\"foo\": \"bar\"}") + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) + assert.NotEqual(t, nil, err) + assert.Equal(t, "", email) +} + +func TestGitHubProviderGetUserName(t *testing.T) { + b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`) + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetUserName(session) + assert.Equal(t, nil, err) + assert.Equal(t, "mbland", email) +} diff --git a/providers/provider_default.go b/providers/provider_default.go index 1d1daea..355e6c3 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { return "", errors.New("not implemented") } +// GetUserName returns the Account username +func (p *ProviderData) GetUserName(s *SessionState) (string, error) { + return "", errors.New("not implemented") +} + // ValidateGroup validates that the provided email exists in the configured provider // email group(s). func (p *ProviderData) ValidateGroup(email string) bool { diff --git a/providers/providers.go b/providers/providers.go index 8a4e7ca..70e707b 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -7,6 +7,7 @@ import ( type Provider interface { Data() *ProviderData GetEmailAddress(*SessionState) (string, error) + GetUserName(*SessionState) (string, error) Redeem(string, string) (*SessionState, error) ValidateGroup(string) bool ValidateSessionState(*SessionState) bool diff --git a/providers/session_state.go b/providers/session_state.go index 214b5a4..805c702 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool { } func (s *SessionState) String() string { - o := fmt.Sprintf("Session{%s", s.userOrEmail()) + o := fmt.Sprintf("Session{%s", s.accountInfo()) if s.AccessToken != "" { o += " token:true" } @@ -40,17 +40,13 @@ func (s *SessionState) String() string { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { if c == nil || s.AccessToken == "" { - return s.userOrEmail(), nil + return s.accountInfo(), nil } return s.EncryptedString(c) } -func (s *SessionState) userOrEmail() string { - u := s.User - if s.Email != "" { - u = s.Email - } - return u +func (s *SessionState) accountInfo() string { + return fmt.Sprintf("email:%s user:%s", s.Email, s.User) } func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { @@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { } a := s.AccessToken if a != "" { - a, err = c.Encrypt(a) - if err != nil { + if a, err = c.Encrypt(a); err != nil { return "", err } } r := s.RefreshToken if r != "" { - r, err = c.Encrypt(r) - if err != nil { + if r, err = c.Encrypt(r); err != nil { return "", err } } - return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil + return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil +} + +func decodeSessionStatePlain(v string) (s *SessionState, err error) { + chunks := strings.Split(v, " ") + if len(chunks) != 2 { + return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) + } + + email := strings.TrimPrefix(chunks[0], "email:") + user := strings.TrimPrefix(chunks[1], "user:") + if user == "" { + user = strings.Split(email, "@")[0] + } + + return &SessionState{User: user, Email: email}, nil } func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { - chunks := strings.Split(v, "|") - if len(chunks) == 1 { - if strings.Contains(chunks[0], "@") { - u := strings.Split(v, "@")[0] - return &SessionState{Email: v, User: u}, nil - } - return &SessionState{User: v}, nil + if c == nil { + return decodeSessionStatePlain(v) } + chunks := strings.Split(v, "|") if len(chunks) != 4 { err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) return } - s = &SessionState{} - if c != nil && chunks[1] != "" { - s.AccessToken, err = c.Decrypt(chunks[1]) - if err != nil { + sessionState, err := decodeSessionStatePlain(chunks[0]) + if err != nil { + return nil, err + } + + if chunks[1] != "" { + if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { return nil, err } } - if c != nil && chunks[3] != "" { - s.RefreshToken, err = c.Decrypt(chunks[3]) - if err != nil { - return nil, err - } - } - if u := chunks[0]; strings.Contains(u, "@") { - s.Email = u - s.User = strings.Split(u, "@")[0] - } else { - s.User = u - } + ts, _ := strconv.Atoi(chunks[2]) - s.ExpiresOn = time.Unix(int64(ts), 0) - return + sessionState.ExpiresOn = time.Unix(int64(ts), 0) + + if chunks[3] != "" { + if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { + return nil, err + } + } + + return sessionState, nil } diff --git a/providers/session_state_test.go b/providers/session_state_test.go index 0cf6d3e..d3cc8f8 100644 --- a/providers/session_state_test.go +++ b/providers/session_state_test.go @@ -1,6 +1,7 @@ package providers import ( + "fmt" "strings" "testing" "time" @@ -30,6 +31,7 @@ func TestSessionStateSerialization(t *testing.T) { ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) + assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) @@ -39,6 +41,43 @@ func TestSessionStateSerialization(t *testing.T) { ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) + assert.Equal(t, "user", ss.User) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + assert.NotEqual(t, s.AccessToken, ss.AccessToken) + assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) +} + +func TestSessionStateSerializationWithUser(t *testing.T) { + c, err := cookie.NewCipher([]byte(secret)) + assert.Equal(t, nil, err) + c2, err := cookie.NewCipher([]byte(altSecret)) + assert.Equal(t, nil, err) + s := &SessionState{ + User: "just-user", + Email: "user@domain.com", + AccessToken: "token1234", + ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh4321", + } + encoded, err := s.EncodeSessionState(c) + assert.Equal(t, nil, err) + assert.Equal(t, 3, strings.Count(encoded, "|")) + + ss, err := DecodeSessionState(encoded, c) + t.Logf("%#v", ss) + assert.Equal(t, nil, err) + assert.Equal(t, s.User, ss.User) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + assert.Equal(t, s.RefreshToken, ss.RefreshToken) + + // ensure a different cipher can't decode properly (ie: it gets gibberish) + ss, err = DecodeSessionState(encoded, c2) + t.Logf("%#v", ss) + assert.Equal(t, nil, err) + assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) @@ -46,7 +85,6 @@ func TestSessionStateSerialization(t *testing.T) { } func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &SessionState{ Email: "user@domain.com", AccessToken: "token1234", @@ -55,25 +93,51 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) - assert.Equal(t, s.Email, encoded) + expected := fmt.Sprintf("email:%s user:", s.Email) + assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) + assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, "", ss.AccessToken) assert.Equal(t, "", ss.RefreshToken) } -func TestSessionStateUserOrEmail(t *testing.T) { +func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { + s := &SessionState{ + User: "just-user", + Email: "user@domain.com", + AccessToken: "token1234", + ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh4321", + } + encoded, err := s.EncodeSessionState(nil) + assert.Equal(t, nil, err) + expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User) + assert.Equal(t, expected, encoded) + // only email should have been serialized + ss, err := DecodeSessionState(encoded, nil) + assert.Equal(t, nil, err) + assert.Equal(t, s.User, ss.User) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, "", ss.AccessToken) + assert.Equal(t, "", ss.RefreshToken) +} + +func TestSessionStateAccountInfo(t *testing.T) { s := &SessionState{ Email: "user@domain.com", User: "just-user", } - assert.Equal(t, "user@domain.com", s.userOrEmail()) + expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User) + assert.Equal(t, expected, s.accountInfo()) + s.Email = "" - assert.Equal(t, "just-user", s.userOrEmail()) + expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User) + assert.Equal(t, expected, s.accountInfo()) } func TestExpired(t *testing.T) {