Merge pull request #466 from clobrano/github-use-login-as-user
GitHub use login as user
This commit is contained in:
commit
b0c1c85177
@ -244,6 +244,13 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
|
|||||||
if s.Email == "" {
|
if s.Email == "" {
|
||||||
s.Email, err = p.provider.GetEmailAddress(s)
|
s.Email, err = p.provider.GetEmailAddress(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.User == "" {
|
||||||
|
s.User, err = p.provider.GetUserName(s)
|
||||||
|
if err != nil && err.Error() == "not implemented" {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -218,10 +218,10 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
return "", fmt.Errorf("got %d from %q %s",
|
return "", fmt.Errorf("got %d from %q %s",
|
||||||
resp.StatusCode, endpoint.String(), body)
|
resp.StatusCode, endpoint.String(), body)
|
||||||
} else {
|
|
||||||
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||||
|
|
||||||
if err := json.Unmarshal(body, &emails); err != nil {
|
if err := json.Unmarshal(body, &emails); err != nil {
|
||||||
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||||
}
|
}
|
||||||
@ -234,3 +234,46 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
|
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
||||||
|
var user struct {
|
||||||
|
Login string `json:"login"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := &url.URL{
|
||||||
|
Scheme: p.ValidateURL.Scheme,
|
||||||
|
Host: p.ValidateURL.Host,
|
||||||
|
Path: path.Join(p.ValidateURL.Path, "/user"),
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", endpoint.String(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("could not create new GET request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("token %s", s.AccessToken))
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(resp.Body)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
return "", fmt.Errorf("got %d from %q %s",
|
||||||
|
resp.StatusCode, endpoint.String(), body)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("got %d from %q %s", resp.StatusCode, endpoint.String(), body)
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &user); err != nil {
|
||||||
|
return "", fmt.Errorf("%s unmarshaling %s", err, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
return user.Login, nil
|
||||||
|
}
|
||||||
|
146
providers/github_test.go
Normal file
146
providers/github_test.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testGitHubProvider(hostname string) *GitHubProvider {
|
||||||
|
p := NewGitHubProvider(
|
||||||
|
&ProviderData{
|
||||||
|
ProviderName: "",
|
||||||
|
LoginURL: &url.URL{},
|
||||||
|
RedeemURL: &url.URL{},
|
||||||
|
ProfileURL: &url.URL{},
|
||||||
|
ValidateURL: &url.URL{},
|
||||||
|
Scope: ""})
|
||||||
|
if hostname != "" {
|
||||||
|
updateURL(p.Data().LoginURL, hostname)
|
||||||
|
updateURL(p.Data().RedeemURL, hostname)
|
||||||
|
updateURL(p.Data().ProfileURL, hostname)
|
||||||
|
updateURL(p.Data().ValidateURL, hostname)
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func testGitHubBackend(payload string) *httptest.Server {
|
||||||
|
pathToQueryMap := map[string]string{
|
||||||
|
"/user": "",
|
||||||
|
"/user/emails": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
return httptest.NewServer(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
url := r.URL
|
||||||
|
query, ok := pathToQueryMap[url.Path]
|
||||||
|
if !ok {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
} else if url.RawQuery != query {
|
||||||
|
w.WriteHeader(404)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte(payload))
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubProviderDefaults(t *testing.T) {
|
||||||
|
p := testGitHubProvider("")
|
||||||
|
assert.NotEqual(t, nil, p)
|
||||||
|
assert.Equal(t, "GitHub", p.Data().ProviderName)
|
||||||
|
assert.Equal(t, "https://github.com/login/oauth/authorize",
|
||||||
|
p.Data().LoginURL.String())
|
||||||
|
assert.Equal(t, "https://github.com/login/oauth/access_token",
|
||||||
|
p.Data().RedeemURL.String())
|
||||||
|
assert.Equal(t, "https://api.github.com/",
|
||||||
|
p.Data().ValidateURL.String())
|
||||||
|
assert.Equal(t, "user:email", p.Data().Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubProviderOverrides(t *testing.T) {
|
||||||
|
p := NewGitHubProvider(
|
||||||
|
&ProviderData{
|
||||||
|
LoginURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "example.com",
|
||||||
|
Path: "/login/oauth/authorize"},
|
||||||
|
RedeemURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "example.com",
|
||||||
|
Path: "/login/oauth/access_token"},
|
||||||
|
ValidateURL: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "api.example.com",
|
||||||
|
Path: "/"},
|
||||||
|
Scope: "profile"})
|
||||||
|
assert.NotEqual(t, nil, p)
|
||||||
|
assert.Equal(t, "GitHub", p.Data().ProviderName)
|
||||||
|
assert.Equal(t, "https://example.com/login/oauth/authorize",
|
||||||
|
p.Data().LoginURL.String())
|
||||||
|
assert.Equal(t, "https://example.com/login/oauth/access_token",
|
||||||
|
p.Data().RedeemURL.String())
|
||||||
|
assert.Equal(t, "https://api.example.com/",
|
||||||
|
p.Data().ValidateURL.String())
|
||||||
|
assert.Equal(t, "profile", p.Data().Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
||||||
|
b := testGitHubBackend(`[ {"email": "michael.bland@gsa.gov", "primary": true} ]`)
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
bURL, _ := url.Parse(b.URL)
|
||||||
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
|
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||||
|
email, err := p.GetEmailAddress(session)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note that trying to trigger the "failed building request" case is not
|
||||||
|
// practical, since the only way it can fail is if the URL fails to parse.
|
||||||
|
func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||||
|
b := testGitHubBackend("unused payload")
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
bURL, _ := url.Parse(b.URL)
|
||||||
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
|
// We'll trigger a request failure by using an unexpected access
|
||||||
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
|
// JSON to fail.
|
||||||
|
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||||
|
email, err := p.GetEmailAddress(session)
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
assert.Equal(t, "", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||||
|
b := testGitHubBackend("{\"foo\": \"bar\"}")
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
bURL, _ := url.Parse(b.URL)
|
||||||
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
|
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||||
|
email, err := p.GetEmailAddress(session)
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
assert.Equal(t, "", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGitHubProviderGetUserName(t *testing.T) {
|
||||||
|
b := testGitHubBackend(`{"email": "michael.bland@gsa.gov", "login": "mbland"}`)
|
||||||
|
defer b.Close()
|
||||||
|
|
||||||
|
bURL, _ := url.Parse(b.URL)
|
||||||
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
|
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||||
|
email, err := p.GetUserName(session)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "mbland", email)
|
||||||
|
}
|
@ -106,6 +106,11 @@ func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUserName returns the Account username
|
||||||
|
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
|
||||||
|
return "", errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateGroup validates that the provided email exists in the configured provider
|
// ValidateGroup validates that the provided email exists in the configured provider
|
||||||
// email group(s).
|
// email group(s).
|
||||||
func (p *ProviderData) ValidateGroup(email string) bool {
|
func (p *ProviderData) ValidateGroup(email string) bool {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
type Provider interface {
|
type Provider interface {
|
||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
GetEmailAddress(*SessionState) (string, error)
|
GetEmailAddress(*SessionState) (string, error)
|
||||||
|
GetUserName(*SessionState) (string, error)
|
||||||
Redeem(string, string) (*SessionState, error)
|
Redeem(string, string) (*SessionState, error)
|
||||||
ValidateGroup(string) bool
|
ValidateGroup(string) bool
|
||||||
ValidateSessionState(*SessionState) bool
|
ValidateSessionState(*SessionState) bool
|
||||||
|
@ -25,7 +25,7 @@ func (s *SessionState) IsExpired() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionState) String() string {
|
func (s *SessionState) String() string {
|
||||||
o := fmt.Sprintf("Session{%s", s.userOrEmail())
|
o := fmt.Sprintf("Session{%s", s.accountInfo())
|
||||||
if s.AccessToken != "" {
|
if s.AccessToken != "" {
|
||||||
o += " token:true"
|
o += " token:true"
|
||||||
}
|
}
|
||||||
@ -40,17 +40,13 @@ func (s *SessionState) String() string {
|
|||||||
|
|
||||||
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
|
||||||
if c == nil || s.AccessToken == "" {
|
if c == nil || s.AccessToken == "" {
|
||||||
return s.userOrEmail(), nil
|
return s.accountInfo(), nil
|
||||||
}
|
}
|
||||||
return s.EncryptedString(c)
|
return s.EncryptedString(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionState) userOrEmail() string {
|
func (s *SessionState) accountInfo() string {
|
||||||
u := s.User
|
return fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
||||||
if s.Email != "" {
|
|
||||||
u = s.Email
|
|
||||||
}
|
|
||||||
return u
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
||||||
@ -60,56 +56,64 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
|||||||
}
|
}
|
||||||
a := s.AccessToken
|
a := s.AccessToken
|
||||||
if a != "" {
|
if a != "" {
|
||||||
a, err = c.Encrypt(a)
|
if a, err = c.Encrypt(a); err != nil {
|
||||||
if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r := s.RefreshToken
|
r := s.RefreshToken
|
||||||
if r != "" {
|
if r != "" {
|
||||||
r, err = c.Encrypt(r)
|
if r, err = c.Encrypt(r); err != nil {
|
||||||
if err != nil {
|
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil
|
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
|
||||||
|
chunks := strings.Split(v, " ")
|
||||||
|
if len(chunks) != 2 {
|
||||||
|
return nil, fmt.Errorf("could not decode session state: expected 2 chunks got %d", len(chunks))
|
||||||
|
}
|
||||||
|
|
||||||
|
email := strings.TrimPrefix(chunks[0], "email:")
|
||||||
|
user := strings.TrimPrefix(chunks[1], "user:")
|
||||||
|
if user == "" {
|
||||||
|
user = strings.Split(email, "@")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SessionState{User: user, Email: email}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
||||||
chunks := strings.Split(v, "|")
|
if c == nil {
|
||||||
if len(chunks) == 1 {
|
return decodeSessionStatePlain(v)
|
||||||
if strings.Contains(chunks[0], "@") {
|
|
||||||
u := strings.Split(v, "@")[0]
|
|
||||||
return &SessionState{Email: v, User: u}, nil
|
|
||||||
}
|
|
||||||
return &SessionState{User: v}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chunks := strings.Split(v, "|")
|
||||||
if len(chunks) != 4 {
|
if len(chunks) != 4 {
|
||||||
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
|
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
s = &SessionState{}
|
sessionState, err := decodeSessionStatePlain(chunks[0])
|
||||||
if c != nil && chunks[1] != "" {
|
if err != nil {
|
||||||
s.AccessToken, err = c.Decrypt(chunks[1])
|
return nil, err
|
||||||
if err != nil {
|
}
|
||||||
|
|
||||||
|
if chunks[1] != "" {
|
||||||
|
if sessionState.AccessToken, err = c.Decrypt(chunks[1]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c != nil && chunks[3] != "" {
|
|
||||||
s.RefreshToken, err = c.Decrypt(chunks[3])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if u := chunks[0]; strings.Contains(u, "@") {
|
|
||||||
s.Email = u
|
|
||||||
s.User = strings.Split(u, "@")[0]
|
|
||||||
} else {
|
|
||||||
s.User = u
|
|
||||||
}
|
|
||||||
ts, _ := strconv.Atoi(chunks[2])
|
ts, _ := strconv.Atoi(chunks[2])
|
||||||
s.ExpiresOn = time.Unix(int64(ts), 0)
|
sessionState.ExpiresOn = time.Unix(int64(ts), 0)
|
||||||
return
|
|
||||||
|
if chunks[3] != "" {
|
||||||
|
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sessionState, nil
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package providers
|
package providers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -30,6 +31,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "user", ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
@ -39,6 +41,43 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
ss, err = DecodeSessionState(encoded, c2)
|
ss, err = DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "user", ss.User)
|
||||||
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||||
|
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||||
|
c, err := cookie.NewCipher([]byte(secret))
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
s := &SessionState{
|
||||||
|
User: "just-user",
|
||||||
|
Email: "user@domain.com",
|
||||||
|
AccessToken: "token1234",
|
||||||
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
|
RefreshToken: "refresh4321",
|
||||||
|
}
|
||||||
|
encoded, err := s.EncodeSessionState(c)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, 3, strings.Count(encoded, "|"))
|
||||||
|
|
||||||
|
ss, err := DecodeSessionState(encoded, c)
|
||||||
|
t.Logf("%#v", ss)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, s.User, ss.User)
|
||||||
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||||
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
|
ss, err = DecodeSessionState(encoded, c2)
|
||||||
|
t.Logf("%#v", ss)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, s.User, ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||||
@ -46,7 +85,6 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||||
|
|
||||||
s := &SessionState{
|
s := &SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -55,25 +93,51 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
}
|
}
|
||||||
encoded, err := s.EncodeSessionState(nil)
|
encoded, err := s.EncodeSessionState(nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.Email, encoded)
|
expected := fmt.Sprintf("email:%s user:", s.Email)
|
||||||
|
assert.Equal(t, expected, encoded)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := DecodeSessionState(encoded, nil)
|
ss, err := DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, "user", ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, "", ss.AccessToken)
|
assert.Equal(t, "", ss.AccessToken)
|
||||||
assert.Equal(t, "", ss.RefreshToken)
|
assert.Equal(t, "", ss.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateUserOrEmail(t *testing.T) {
|
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||||
|
s := &SessionState{
|
||||||
|
User: "just-user",
|
||||||
|
Email: "user@domain.com",
|
||||||
|
AccessToken: "token1234",
|
||||||
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
|
RefreshToken: "refresh4321",
|
||||||
|
}
|
||||||
|
encoded, err := s.EncodeSessionState(nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
||||||
|
assert.Equal(t, expected, encoded)
|
||||||
|
|
||||||
|
// only email should have been serialized
|
||||||
|
ss, err := DecodeSessionState(encoded, nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, s.User, ss.User)
|
||||||
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
|
assert.Equal(t, "", ss.AccessToken)
|
||||||
|
assert.Equal(t, "", ss.RefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionStateAccountInfo(t *testing.T) {
|
||||||
s := &SessionState{
|
s := &SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
}
|
}
|
||||||
assert.Equal(t, "user@domain.com", s.userOrEmail())
|
expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
||||||
|
assert.Equal(t, expected, s.accountInfo())
|
||||||
|
|
||||||
s.Email = ""
|
s.Email = ""
|
||||||
assert.Equal(t, "just-user", s.userOrEmail())
|
expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
||||||
|
assert.Equal(t, expected, s.accountInfo())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExpired(t *testing.T) {
|
func TestExpired(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user