diff --git a/oauthproxy.go b/oauthproxy.go index bc69369..f94aa6e 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e if s.Email == "" { s.Email, err = p.provider.GetEmailAddress(s) } + + if s.User == "" { + s.User, err = p.provider.GetUserName(s) + if err != nil && err.Error() == "not implemented" { + err = nil + } + } return } diff --git a/providers/github.go b/providers/github.go index 512eed8..f3af86f 100644 --- a/providers/github.go +++ b/providers/github.go @@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { if resp.StatusCode != 200 { return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) - } else { - log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) } + log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + if err := json.Unmarshal(body, &emails); err != nil { return "", fmt.Errorf("%s unmarshaling %s", err, body) } @@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { return "", nil } + +func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { + var user struct { + Login string `json:"login"` + Email string `json:"email"` + } + + endpoint := &url.URL{ + Scheme: p.ValidateURL.Scheme, + Host: p.ValidateURL.Host, + Path: path.Join(p.ValidateURL.Path, "/user"), + } + + req, err := http.NewRequest("GET", endpoint.String(), nil) + if err != nil { + return "", fmt.Errorf("could not create new GET request: %v", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken)) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", err + } + + body, err := ioutil.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { + return "", err + } + + if resp.StatusCode != 200 { + return "", fmt.Errorf("got %d from %q %s", + resp.StatusCode, endpoint.String(), body) + } + + log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body) + + if err := json.Unmarshal(body, &user); err != nil { + return "", fmt.Errorf("%s unmarshaling %s", err, body) + } + + return user.Login, nil +} diff --git a/providers/github_test.go b/providers/github_test.go new file mode 100644 index 0000000..8080525 --- /dev/null +++ b/providers/github_test.go @@ -0,0 +1,146 @@ +package providers + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +func testGitHubProvider(hostname string) *GitHubProvider { + p := NewGitHubProvider( + &ProviderData{ + ProviderName: "", + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, + Scope: ""}) + if hostname != "" { + updateURL(p.Data().LoginURL, hostname) + updateURL(p.Data().RedeemURL, hostname) + updateURL(p.Data().ProfileURL, hostname) + updateURL(p.Data().ValidateURL, hostname) + } + return p +} + +func testGitHubBackend(payload string) *httptest.Server { + pathToQueryMap := map[string]string{ + "/user": "", + "/user/emails": "", + } + + return httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + url := r.URL + query, ok := pathToQueryMap[url.Path] + if !ok { + w.WriteHeader(404) + } else if url.RawQuery != query { + w.WriteHeader(404) + } else { + w.WriteHeader(200) + w.Write([]byte(payload)) + } + })) +} + +func TestGitHubProviderDefaults(t *testing.T) { + p := testGitHubProvider("") + assert.NotEqual(t, nil, p) + assert.Equal(t, "GitHub", p.Data().ProviderName) + assert.Equal(t, "https://github.com/login/oauth/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://github.com/login/oauth/access_token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://api.github.com/", + p.Data().ValidateURL.String()) + assert.Equal(t, "user:email", p.Data().Scope) +} + +func TestGitHubProviderOverrides(t *testing.T) { + p := NewGitHubProvider( + &ProviderData{ + LoginURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/login/oauth/authorize"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/login/oauth/access_token"}, + ValidateURL: &url.URL{ + Scheme: "https", + Host: "api.example.com", + Path: "/"}, + Scope: "profile"}) + assert.NotEqual(t, nil, p) + assert.Equal(t, "GitHub", p.Data().ProviderName) + assert.Equal(t, "https://example.com/login/oauth/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://example.com/login/oauth/access_token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://api.example.com/", + p.Data().ValidateURL.String()) + assert.Equal(t, "profile", p.Data().Scope) +} + +func TestGitHubProviderGetEmailAddress(t *testing.T) { + b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`) + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) + assert.Equal(t, nil, err) + assert.Equal(t, "michael.bland@gsa.gov", email) +} + +// Note that trying to trigger the "failed building request" case is not +// practical, since the only way it can fail is if the URL fails to parse. +func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { + b := testGitHubBackend("unused payload") + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + // We'll trigger a request failure by using an unexpected access + // token. Alternatively, we could allow the parsing of the payload as + // JSON to fail. + session := &SessionState{AccessToken: "unexpected_access_token"} + email, err := p.GetEmailAddress(session) + assert.NotEqual(t, nil, err) + assert.Equal(t, "", email) +} + +func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { + b := testGitHubBackend("{\"foo\": \"bar\"}") + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) + assert.NotEqual(t, nil, err) + assert.Equal(t, "", email) +} + +func TestGitHubProviderGetUserName(t *testing.T) { + b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`) + defer b.Close() + + bURL, _ := url.Parse(b.URL) + p := testGitHubProvider(bURL.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetUserName(session) + assert.Equal(t, nil, err) + assert.Equal(t, "mbland", email) +} diff --git a/providers/provider_default.go b/providers/provider_default.go index 1d1daea..355e6c3 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { return "", errors.New("not implemented") } +// GetUserName returns the Account username +func (p *ProviderData) GetUserName(s *SessionState) (string, error) { + return "", errors.New("not implemented") +} + // ValidateGroup validates that the provided email exists in the configured provider // email group(s). func (p *ProviderData) ValidateGroup(email string) bool { diff --git a/providers/providers.go b/providers/providers.go index 8a4e7ca..70e707b 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -7,6 +7,7 @@ import ( type Provider interface { Data() *ProviderData GetEmailAddress(*SessionState) (string, error) + GetUserName(*SessionState) (string, error) Redeem(string, string) (*SessionState, error) ValidateGroup(string) bool ValidateSessionState(*SessionState) bool diff --git a/providers/session_state.go b/providers/session_state.go index 214b5a4..805c702 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool { } func (s *SessionState) String() string { - o := fmt.Sprintf("Session{%s", s.userOrEmail()) + o := fmt.Sprintf("Session{%s", s.accountInfo()) if s.AccessToken != "" { o += " token:true" } @@ -40,17 +40,13 @@ func (s *SessionState) String() string { func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) { if c == nil || s.AccessToken == "" { - return s.userOrEmail(), nil + return s.accountInfo(), nil } return s.EncryptedString(c) } -func (s *SessionState) userOrEmail() string { - u := s.User - if s.Email != "" { - u = s.Email - } - return u +func (s *SessionState) accountInfo() string { + return fmt.Sprintf("email:%s user:%s", s.Email, s.User) } func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { @@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { } a := s.AccessToken if a != "" { - a, err = c.Encrypt(a) - if err != nil { + if a, err = c.Encrypt(a); err != nil { return "", err } } r := s.RefreshToken if r != "" { - r, err = c.Encrypt(r) - if err != nil { + if r, err = c.Encrypt(r); err != nil { return "", err } } - return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil + return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil +} + +func decodeSessionStatePlain(v string) (s *SessionState, err error) { + chunks := strings.Split(v, " ") + if len(chunks) != 2 { + return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks)) + } + + email := strings.TrimPrefix(chunks[0], "email:") + user := strings.TrimPrefix(chunks[1], "user:") + if user == "" { + user = strings.Split(email, "@")[0] + } + + return &SessionState{User: user, Email: email}, nil } func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) { - chunks := strings.Split(v, "|") - if len(chunks) == 1 { - if strings.Contains(chunks[0], "@") { - u := strings.Split(v, "@")[0] - return &SessionState{Email: v, User: u}, nil - } - return &SessionState{User: v}, nil + if c == nil { + return decodeSessionStatePlain(v) } + chunks := strings.Split(v, "|") if len(chunks) != 4 { err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) return } - s = &SessionState{} - if c != nil && chunks[1] != "" { - s.AccessToken, err = c.Decrypt(chunks[1]) - if err != nil { + sessionState, err := decodeSessionStatePlain(chunks[0]) + if err != nil { + return nil, err + } + + if chunks[1] != "" { + if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil { return nil, err } } - if c != nil && chunks[3] != "" { - s.RefreshToken, err = c.Decrypt(chunks[3]) - if err != nil { - return nil, err - } - } - if u := chunks[0]; strings.Contains(u, "@") { - s.Email = u - s.User = strings.Split(u, "@")[0] - } else { - s.User = u - } + ts, _ := strconv.Atoi(chunks[2]) - s.ExpiresOn = time.Unix(int64(ts), 0) - return + sessionState.ExpiresOn = time.Unix(int64(ts), 0) + + if chunks[3] != "" { + if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { + return nil, err + } + } + + return sessionState, nil } diff --git a/providers/session_state_test.go b/providers/session_state_test.go index 0cf6d3e..d3cc8f8 100644 --- a/providers/session_state_test.go +++ b/providers/session_state_test.go @@ -1,6 +1,7 @@ package providers import ( + "fmt" "strings" "testing" "time" @@ -30,6 +31,7 @@ func TestSessionStateSerialization(t *testing.T) { ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) + assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) @@ -39,6 +41,43 @@ func TestSessionStateSerialization(t *testing.T) { ss, err = DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) + assert.Equal(t, "user", ss.User) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + assert.NotEqual(t, s.AccessToken, ss.AccessToken) + assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) +} + +func TestSessionStateSerializationWithUser(t *testing.T) { + c, err := cookie.NewCipher([]byte(secret)) + assert.Equal(t, nil, err) + c2, err := cookie.NewCipher([]byte(altSecret)) + assert.Equal(t, nil, err) + s := &SessionState{ + User: "just-user", + Email: "user@domain.com", + AccessToken: "token1234", + ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh4321", + } + encoded, err := s.EncodeSessionState(c) + assert.Equal(t, nil, err) + assert.Equal(t, 3, strings.Count(encoded, "|")) + + ss, err := DecodeSessionState(encoded, c) + t.Logf("%#v", ss) + assert.Equal(t, nil, err) + assert.Equal(t, s.User, ss.User) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) + assert.Equal(t, s.RefreshToken, ss.RefreshToken) + + // ensure a different cipher can't decode properly (ie: it gets gibberish) + ss, err = DecodeSessionState(encoded, c2) + t.Logf("%#v", ss) + assert.Equal(t, nil, err) + assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) @@ -46,7 +85,6 @@ func TestSessionStateSerialization(t *testing.T) { } func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &SessionState{ Email: "user@domain.com", AccessToken: "token1234", @@ -55,25 +93,51 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } encoded, err := s.EncodeSessionState(nil) assert.Equal(t, nil, err) - assert.Equal(t, s.Email, encoded) + expected := fmt.Sprintf("email:%s user:", s.Email) + assert.Equal(t, expected, encoded) // only email should have been serialized ss, err := DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) + assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, "", ss.AccessToken) assert.Equal(t, "", ss.RefreshToken) } -func TestSessionStateUserOrEmail(t *testing.T) { +func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { + s := &SessionState{ + User: "just-user", + Email: "user@domain.com", + AccessToken: "token1234", + ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), + RefreshToken: "refresh4321", + } + encoded, err := s.EncodeSessionState(nil) + assert.Equal(t, nil, err) + expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User) + assert.Equal(t, expected, encoded) + // only email should have been serialized + ss, err := DecodeSessionState(encoded, nil) + assert.Equal(t, nil, err) + assert.Equal(t, s.User, ss.User) + assert.Equal(t, s.Email, ss.Email) + assert.Equal(t, "", ss.AccessToken) + assert.Equal(t, "", ss.RefreshToken) +} + +func TestSessionStateAccountInfo(t *testing.T) { s := &SessionState{ Email: "user@domain.com", User: "just-user", } - assert.Equal(t, "user@domain.com", s.userOrEmail()) + expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User) + assert.Equal(t, expected, s.accountInfo()) + s.Email = "" - assert.Equal(t, "just-user", s.userOrEmail()) + expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User) + assert.Equal(t, expected, s.accountInfo()) } func TestExpired(t *testing.T) {