Move SessionState to its own package
This commit is contained in:
parent
a1130e41a3
commit
2ab8a7d95d
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"github.com/yhat/wsutil"
|
"github.com/yhat/wsutil"
|
||||||
)
|
)
|
||||||
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
|
|||||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
|
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
return nil, errors.New("missing code")
|
return nil, errors.New("missing code")
|
||||||
}
|
}
|
||||||
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LoadCookiedSession reads the user's authentication details from the request
|
// LoadCookiedSession reads the user's authentication details from the request
|
||||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
|
||||||
var age time.Duration
|
var age time.Duration
|
||||||
c, err := loadCookie(req, p.CookieName)
|
c, err := loadCookie(req, p.CookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SaveSession creates a new session cookie value and sets this on the response
|
// SaveSession creates a new session cookie value and sets this on the response
|
||||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
|
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
user, ok := p.ManualSignIn(rw, req)
|
user, ok := p.ManualSignIn(rw, req)
|
||||||
if ok {
|
if ok {
|
||||||
session := &providers.SessionState{User: user}
|
session := &sessions.SessionState{User: user}
|
||||||
p.SaveSession(rw, req, session)
|
p.SaveSession(rw, req, session)
|
||||||
http.Redirect(rw, req, redirect, 302)
|
http.Redirect(rw, req, redirect, 302)
|
||||||
} else {
|
} else {
|
||||||
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
|||||||
|
|
||||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||||
// credentials and authenticates these against the proxies HtpasswdFile
|
// credentials and authenticates these against the proxies HtpasswdFile
|
||||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
|
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
|
||||||
if p.HtpasswdFile == nil {
|
if p.HtpasswdFile == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
|
|||||||
}
|
}
|
||||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||||
return &providers.SessionState{User: pair[0]}, nil
|
return &sessions.SessionState{User: pair[0]}, nil
|
||||||
}
|
}
|
||||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -16,6 +16,7 @@ import (
|
|||||||
|
|
||||||
"github.com/mbland/hmacauth"
|
"github.com/mbland/hmacauth"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/pusher/oauth2_proxy/providers"
|
"github.com/pusher/oauth2_proxy/providers"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
|
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
|
||||||
return tp.EmailAddress, nil
|
return tp.EmailAddress, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
|
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
||||||
return tp.ValidToken
|
return tp.ValidToken
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
|
|||||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
|
||||||
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
|
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
|
||||||
return p.proxy.LoadCookiedSession(p.req)
|
return p.proxy.LoadCookiedSession(p.req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestLoadCookiedSession(t *testing.T) {
|
func TestLoadCookiedSession(t *testing.T) {
|
||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
|
|
||||||
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, time.Now())
|
pcTest.SaveSession(startSession, time.Now())
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, _, err := pcTest.LoadCookiedSession()
|
||||||
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|||||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||||
|
|
||||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession, reference)
|
||||||
|
|
||||||
session, age, err := pcTest.LoadCookiedSession()
|
session, age, err := pcTest.LoadCookiedSession()
|
||||||
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
|||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession, reference)
|
||||||
|
|
||||||
session, _, err := pcTest.LoadCookiedSession()
|
session, _, err := pcTest.LoadCookiedSession()
|
||||||
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
|||||||
pcTest := NewProcessCookieTestWithDefaults()
|
pcTest := NewProcessCookieTestWithDefaults()
|
||||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
pcTest.SaveSession(startSession, reference)
|
pcTest.SaveSession(startSession, reference)
|
||||||
|
|
||||||
pcTest.proxy.CookieRefresh = time.Hour
|
pcTest.proxy.CookieRefresh = time.Hour
|
||||||
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
|||||||
|
|
||||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
test.SaveSession(startSession, time.Now())
|
test.SaveSession(startSession, time.Now())
|
||||||
|
|
||||||
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
test.SaveSession(startSession, reference)
|
test.SaveSession(startSession, reference)
|
||||||
|
|
||||||
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
|||||||
|
|
||||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||||
test := NewAuthOnlyEndpointTest()
|
test := NewAuthOnlyEndpointTest()
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||||
test.SaveSession(startSession, time.Now())
|
test.SaveSession(startSession, time.Now())
|
||||||
test.validateUser = false
|
test.validateUser = false
|
||||||
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
|||||||
pcTest.req, _ = http.NewRequest("GET",
|
pcTest.req, _ = http.NewRequest("GET",
|
||||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||||
|
|
||||||
startSession := &providers.SessionState{
|
startSession := &sessions.SessionState{
|
||||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
||||||
pcTest.SaveSession(startSession, time.Now())
|
pcTest.SaveSession(startSession, time.Now())
|
||||||
|
|
||||||
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
|||||||
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
||||||
req.Header = st.header
|
req.Header = st.header
|
||||||
|
|
||||||
state := &providers.SessionState{
|
state := &sessions.SessionState{
|
||||||
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
||||||
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
package providers
|
package sessions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
@ -1,4 +1,4 @@
|
|||||||
package providers
|
package sessions_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
IDToken: "rawtoken1234",
|
IDToken: "rawtoken1234",
|
||||||
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@domain.com", ss.User)
|
assert.Equal(t, "user@domain.com", ss.User)
|
||||||
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
ss, err = DecodeSessionState(encoded, c2)
|
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||||
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||||
ss, err = DecodeSessionState(encoded, c2)
|
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.NotEqual(t, s.User, ss.User)
|
assert.NotEqual(t, s.User, ss.User)
|
||||||
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := DecodeSessionState(encoded, nil)
|
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@domain.com", ss.User)
|
assert.Equal(t, "user@domain.com", ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||||
s := &SessionState{
|
s := &sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
|
|
||||||
// only email should have been serialized
|
// only email should have been serialized
|
||||||
ss, err := DecodeSessionState(encoded, nil)
|
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, s.User, ss.User)
|
assert.Equal(t, s.User, ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestExpired(t *testing.T) {
|
func TestExpired(t *testing.T) {
|
||||||
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||||
assert.Equal(t, true, s.IsExpired())
|
assert.Equal(t, true, s.IsExpired())
|
||||||
|
|
||||||
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
||||||
assert.Equal(t, false, s.IsExpired())
|
assert.Equal(t, false, s.IsExpired())
|
||||||
|
|
||||||
s = &SessionState{}
|
s = &sessions.SessionState{}
|
||||||
assert.Equal(t, false, s.IsExpired())
|
assert.Equal(t, false, s.IsExpired())
|
||||||
}
|
}
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
SessionState
|
sessions.SessionState
|
||||||
Encoded string
|
Encoded string
|
||||||
Cipher *cookie.Cipher
|
Cipher *cookie.Cipher
|
||||||
Error bool
|
Error bool
|
||||||
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||||
if tc.Error {
|
if tc.Error {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Empty(t, encoded)
|
assert.Empty(t, encoded)
|
||||||
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestDecodeSessionState tests DecodeSessionState with the test vector
|
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||||
func TestDecodeSessionState(t *testing.T) {
|
func TestDecodeSessionState(t *testing.T) {
|
||||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||||
eJSON, _ := e.MarshalJSON()
|
eJSON, _ := e.MarshalJSON()
|
||||||
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
|
|
||||||
testCases := []testCase{
|
testCases := []testCase{
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "user@domain.com",
|
User: "user@domain.com",
|
||||||
},
|
},
|
||||||
Encoded: `{"Email":"user@domain.com"}`,
|
Encoded: `{"Email":"user@domain.com"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: `{"User":"just-user"}`,
|
Encoded: `{"User":"just-user"}`,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Cipher: c,
|
Cipher: c,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
},
|
},
|
||||||
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Error: true,
|
Error: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
},
|
},
|
||||||
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Error: true,
|
Error: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
Cipher: c,
|
Cipher: c,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
SessionState: SessionState{
|
SessionState: sessions.SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
User: "just-user",
|
User: "just-user",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i, tc := range testCases {
|
for i, tc := range testCases {
|
||||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||||
if tc.Error {
|
if tc.Error {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, ss)
|
assert.Nil(t, ss)
|
@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/bitly/go-simplejson"
|
"github.com/bitly/go-simplejson"
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AzureProvider represents an Azure based Identity Provider
|
// AzureProvider represents an Azure based Identity Provider
|
||||||
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
var email string
|
var email string
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", email)
|
||||||
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", email)
|
||||||
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@windows.net", email)
|
assert.Equal(t, "user@windows.net", email)
|
||||||
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testAzureProvider(bURL.Host)
|
p := testAzureProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FacebookProvider represents an Facebook based Identity Provider
|
// FacebookProvider represents an Facebook based Identity Provider
|
||||||
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
|
func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
|
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitHubProvider represents an GitHub based Identity Provider
|
// GitHubProvider represents an GitHub based Identity Provider
|
||||||
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
|
|
||||||
var emails []struct {
|
var emails []struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account user name
|
// GetUserName returns the Account user name
|
||||||
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
|
||||||
var user struct {
|
var user struct {
|
||||||
Login string `json:"login"`
|
Login string `json:"login"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Empty(t, "", email)
|
assert.Empty(t, "", email)
|
||||||
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
|||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
p.Org = "testorg1"
|
p.Org = "testorg1"
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// We'll trigger a request failure by using an unexpected access
|
// We'll trigger a request failure by using an unexpected access
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitHubProvider(bURL.Host)
|
p := testGitHubProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetUserName(session)
|
email, err := p.GetUserName(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "mbland", email)
|
assert.Equal(t, "mbland", email)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GitLabProvider represents an GitLab based Identity Provider
|
// GitLabProvider represents an GitLab based Identity Provider
|
||||||
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
|
|
||||||
req, err := http.NewRequest("GET",
|
req, err := http.NewRequest("GET",
|
||||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitLabProvider(bURL.Host)
|
p := testGitLabProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||||
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// We'll trigger a request failure by using an unexpected access
|
// We'll trigger a request failure by using an unexpected access
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testGitLabProvider(bURL.Host)
|
p := testGitLabProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/logger"
|
"github.com/pusher/oauth2_proxy/logger"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
admin "google.golang.org/api/admin/directory/v1"
|
admin "google.golang.org/api/admin/directory/v1"
|
||||||
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
return
|
return
|
||||||
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s = &SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
|
|||||||
*ProviderData
|
*ProviderData
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note that we're testing the internal validateToken() used to implement
|
// Note that we're testing the internal validateToken() used to implement
|
||||||
// several Provider's ValidateSessionState() implementations
|
// several Provider's ValidateSessionState() implementations
|
||||||
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
|
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/api"
|
"github.com/pusher/oauth2_proxy/api"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LinkedInProvider represents an LinkedIn based Identity Provider
|
// LinkedInProvider represents an LinkedIn based Identity Provider
|
||||||
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
if s.AccessToken == "" {
|
if s.AccessToken == "" {
|
||||||
return "", errors.New("missing access token")
|
return "", errors.New("missing access token")
|
||||||
}
|
}
|
||||||
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
|
func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
|
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testLinkedInProvider(bURL.Host)
|
p := testLinkedInProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, "user@linkedin.com", email)
|
assert.Equal(t, "user@linkedin.com", email)
|
||||||
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
|||||||
// We'll trigger a request failure by using an unexpected access
|
// We'll trigger a request failure by using an unexpected access
|
||||||
// token. Alternatively, we could allow the parsing of the payload as
|
// token. Alternatively, we could allow the parsing of the payload as
|
||||||
// JSON to fail.
|
// JSON to fail.
|
||||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
|||||||
bURL, _ := url.Parse(b.URL)
|
bURL, _ := url.Parse(b.URL)
|
||||||
p := testLinkedInProvider(bURL.Host)
|
p := testLinkedInProvider(bURL.Host)
|
||||||
|
|
||||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||||
email, err := p.GetEmailAddress(session)
|
email, err := p.GetEmailAddress(session)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
assert.Equal(t, "", email)
|
assert.Equal(t, "", email)
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/dgrijalva/jwt-go"
|
"github.com/dgrijalva/jwt-go"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
return
|
return
|
||||||
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store the data that we found in the session state
|
// Store the data that we found in the session state
|
||||||
s = &SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
IDToken: jsonResponse.IDToken,
|
IDToken: jsonResponse.IDToken,
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
|
@ -5,9 +5,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
oidc "github.com/coreos/go-oidc"
|
oidc "github.com/coreos/go-oidc"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OIDCProvider represents an OIDC based Identity Provider
|
// OIDCProvider represents an OIDC based Identity Provider
|
||||||
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
ClientID: p.ClientID,
|
ClientID: p.ClientID,
|
||||||
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
|
|
||||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
// RefreshToken to fetch a new ID token if required
|
// RefreshToken to fetch a new ID token if required
|
||||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
|
||||||
c := oauth2.Config{
|
c := oauth2.Config{
|
||||||
ClientID: p.ClientID,
|
ClientID: p.ClientID,
|
||||||
ClientSecret: p.ClientSecret,
|
ClientSecret: p.ClientSecret,
|
||||||
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) {
|
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||||
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
|||||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SessionState{
|
return &sessions.SessionState{
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
IDToken: rawIDToken,
|
IDToken: rawIDToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState checks that the session's IDToken is still valid
|
// ValidateSessionState checks that the session's IDToken is still valid
|
||||||
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool {
|
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -10,10 +10,11 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||||
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
if code == "" {
|
if code == "" {
|
||||||
err = errors.New("missing code")
|
err = errors.New("missing code")
|
||||||
return
|
return
|
||||||
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
}
|
}
|
||||||
err = json.Unmarshal(body, &jsonResponse)
|
err = json.Unmarshal(body, &jsonResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
s = &SessionState{
|
s = &sessions.SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if a := v.Get("access_token"); a != "" {
|
if a := v.Get("access_token"); a != "" {
|
||||||
s = &SessionState{AccessToken: a}
|
s = &sessions.SessionState{AccessToken: a}
|
||||||
} else {
|
} else {
|
||||||
err = fmt.Errorf("no access token found %s", body)
|
err = fmt.Errorf("no access token found %s", body)
|
||||||
}
|
}
|
||||||
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CookieForSession serializes a session state for storage in a cookie
|
// CookieForSession serializes a session state for storage in a cookie
|
||||||
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
|
func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
|
||||||
return s.EncodeSessionState(c)
|
return s.EncodeSessionState(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SessionFromCookie deserializes a session from a cookie value
|
// SessionFromCookie deserializes a session from a cookie value
|
||||||
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
|
||||||
return DecodeSessionState(v, c)
|
return sessions.DecodeSessionState(v, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetEmailAddress returns the Account email address
|
// GetEmailAddress returns the Account email address
|
||||||
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
|
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserName returns the Account username
|
// GetUserName returns the Account username
|
||||||
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
|
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
|
||||||
return "", errors.New("not implemented")
|
return "", errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ValidateSessionState validates the AccessToken
|
// ValidateSessionState validates the AccessToken
|
||||||
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
|
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||||
return validateToken(p, s.AccessToken, nil)
|
return validateToken(p, s.AccessToken, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||||
// do nothing if a refresh is not required
|
// do nothing if a refresh is not required
|
||||||
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -4,12 +4,13 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRefresh(t *testing.T) {
|
func TestRefresh(t *testing.T) {
|
||||||
p := &ProviderData{}
|
p := &ProviderData{}
|
||||||
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
|
refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
|
||||||
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
|
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
|
||||||
})
|
})
|
||||||
assert.Equal(t, false, refreshed)
|
assert.Equal(t, false, refreshed)
|
||||||
|
@ -2,20 +2,21 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/pusher/oauth2_proxy/cookie"
|
"github.com/pusher/oauth2_proxy/cookie"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Provider represents an upstream identity provider implementation
|
// Provider represents an upstream identity provider implementation
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
GetEmailAddress(*SessionState) (string, error)
|
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||||
GetUserName(*SessionState) (string, error)
|
GetUserName(*sessions.SessionState) (string, error)
|
||||||
Redeem(string, string) (*SessionState, error)
|
Redeem(string, string) (*sessions.SessionState, error)
|
||||||
ValidateGroup(string) bool
|
ValidateGroup(string) bool
|
||||||
ValidateSessionState(*SessionState) bool
|
ValidateSessionState(*sessions.SessionState) bool
|
||||||
GetLoginURL(redirectURI, finalRedirect string) string
|
GetLoginURL(redirectURI, finalRedirect string) string
|
||||||
RefreshSessionIfNeeded(*SessionState) (bool, error)
|
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||||
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
|
SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
|
||||||
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
|
CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New provides a new Provider based on the configured provider string
|
// New provides a new Provider based on the configured provider string
|
||||||
|
Loading…
Reference in New Issue
Block a user