156 lines
4.8 KiB
Go
156 lines
4.8 KiB
Go
package providers
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pusher/oauth2_proxy/cookie"
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
// Cipher keys for the serialization tests in this file. Both keys are
// 32 bytes long. altSecret differs from secret only in its first ten bytes,
// so the tests can check that a cipher built from one key cannot recover
// tokens encrypted with the other.
const (
	secret    = "0123456789abcdefghijklmnopqrstuv"
	altSecret = "0000000000abcdefghijklmnopqrstuv"
)
|
|
|
|
func TestSessionStateSerialization(t *testing.T) {
|
|
c, err := cookie.NewCipher([]byte(secret))
|
|
assert.Equal(t, nil, err)
|
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
|
assert.Equal(t, nil, err)
|
|
s := &SessionState{
|
|
Email: "user@domain.com",
|
|
AccessToken: "token1234",
|
|
IDToken: "rawtoken1234",
|
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
|
RefreshToken: "refresh4321",
|
|
}
|
|
encoded, err := s.EncodeSessionState(c)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, 4, strings.Count(encoded, "|"))
|
|
|
|
ss, err := DecodeSessionState(encoded, c)
|
|
t.Logf("%#v", ss)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, "user", ss.User)
|
|
assert.Equal(t, s.Email, ss.Email)
|
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
|
assert.Equal(t, s.IDToken, ss.IDToken)
|
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
|
|
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
|
ss, err = DecodeSessionState(encoded, c2)
|
|
t.Logf("%#v", ss)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, "user", ss.User)
|
|
assert.Equal(t, s.Email, ss.Email)
|
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
|
assert.NotEqual(t, s.IDToken, ss.IDToken)
|
|
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
|
}
|
|
|
|
func TestSessionStateSerializationWithUser(t *testing.T) {
|
|
c, err := cookie.NewCipher([]byte(secret))
|
|
assert.Equal(t, nil, err)
|
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
|
assert.Equal(t, nil, err)
|
|
s := &SessionState{
|
|
User: "just-user",
|
|
Email: "user@domain.com",
|
|
AccessToken: "token1234",
|
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
|
RefreshToken: "refresh4321",
|
|
}
|
|
encoded, err := s.EncodeSessionState(c)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, 4, strings.Count(encoded, "|"))
|
|
|
|
ss, err := DecodeSessionState(encoded, c)
|
|
t.Logf("%#v", ss)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, s.User, ss.User)
|
|
assert.Equal(t, s.Email, ss.Email)
|
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
|
|
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
|
ss, err = DecodeSessionState(encoded, c2)
|
|
t.Logf("%#v", ss)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, s.User, ss.User)
|
|
assert.Equal(t, s.Email, ss.Email)
|
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
|
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
|
}
|
|
|
|
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|
s := &SessionState{
|
|
Email: "user@domain.com",
|
|
AccessToken: "token1234",
|
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
|
RefreshToken: "refresh4321",
|
|
}
|
|
encoded, err := s.EncodeSessionState(nil)
|
|
assert.Equal(t, nil, err)
|
|
expected := fmt.Sprintf("email:%s user:", s.Email)
|
|
assert.Equal(t, expected, encoded)
|
|
|
|
// only email should have been serialized
|
|
ss, err := DecodeSessionState(encoded, nil)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, "user", ss.User)
|
|
assert.Equal(t, s.Email, ss.Email)
|
|
assert.Equal(t, "", ss.AccessToken)
|
|
assert.Equal(t, "", ss.RefreshToken)
|
|
}
|
|
|
|
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|
s := &SessionState{
|
|
User: "just-user",
|
|
Email: "user@domain.com",
|
|
AccessToken: "token1234",
|
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
|
RefreshToken: "refresh4321",
|
|
}
|
|
encoded, err := s.EncodeSessionState(nil)
|
|
assert.Equal(t, nil, err)
|
|
expected := fmt.Sprintf("email:%s user:%s", s.Email, s.User)
|
|
assert.Equal(t, expected, encoded)
|
|
|
|
// only email should have been serialized
|
|
ss, err := DecodeSessionState(encoded, nil)
|
|
assert.Equal(t, nil, err)
|
|
assert.Equal(t, s.User, ss.User)
|
|
assert.Equal(t, s.Email, ss.Email)
|
|
assert.Equal(t, "", ss.AccessToken)
|
|
assert.Equal(t, "", ss.RefreshToken)
|
|
}
|
|
|
|
func TestSessionStateAccountInfo(t *testing.T) {
|
|
s := &SessionState{
|
|
Email: "user@domain.com",
|
|
User: "just-user",
|
|
}
|
|
expected := fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
|
assert.Equal(t, expected, s.accountInfo())
|
|
|
|
s.Email = ""
|
|
expected = fmt.Sprintf("email:%v user:%v", s.Email, s.User)
|
|
assert.Equal(t, expected, s.accountInfo())
|
|
}
|
|
|
|
func TestExpired(t *testing.T) {
|
|
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
|
assert.Equal(t, true, s.IsExpired())
|
|
|
|
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
|
assert.Equal(t, false, s.IsExpired())
|
|
|
|
s = &SessionState{}
|
|
assert.Equal(t, false, s.IsExpired())
|
|
}
|