Move SessionState to its own package
This commit is contained in:
parent
a1130e41a3
commit
2ab8a7d95d
@ -16,6 +16,7 @@ import (
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/yhat/wsutil"
|
||||
)
|
||||
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
|
||||
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
|
||||
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
return nil, errors.New("missing code")
|
||||
}
|
||||
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
|
||||
}
|
||||
|
||||
// LoadCookiedSession reads the user's authentication details from the request
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
|
||||
var age time.Duration
|
||||
c, err := loadCookie(req, p.CookieName)
|
||||
if err != nil {
|
||||
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
|
||||
}
|
||||
|
||||
// SaveSession creates a new session cookie value and sets this on the response
|
||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
|
||||
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
|
||||
value, err := p.provider.CookieForSession(s, p.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
|
||||
user, ok := p.ManualSignIn(rw, req)
|
||||
if ok {
|
||||
session := &providers.SessionState{User: user}
|
||||
session := &sessions.SessionState{User: user}
|
||||
p.SaveSession(rw, req, session)
|
||||
http.Redirect(rw, req, redirect, 302)
|
||||
} else {
|
||||
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
||||
|
||||
// CheckBasicAuth checks the requests Authorization header for basic auth
|
||||
// credentials and authenticates these against the proxies HtpasswdFile
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
|
||||
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
|
||||
if p.HtpasswdFile == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
|
||||
}
|
||||
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
|
||||
return &providers.SessionState{User: pair[0]}, nil
|
||||
return &sessions.SessionState{User: pair[0]}, nil
|
||||
}
|
||||
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
|
||||
return nil, nil
|
||||
|
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/providers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
|
||||
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
|
||||
return tp.EmailAddress, nil
|
||||
}
|
||||
|
||||
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
|
||||
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
|
||||
return tp.ValidToken
|
||||
}
|
||||
|
||||
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
|
||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
||||
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
|
||||
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
|
||||
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
|
||||
return p.proxy.LoadCookiedSession(p.req)
|
||||
}
|
||||
|
||||
func TestLoadCookiedSession(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
|
||||
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(-2) * time.Hour)
|
||||
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
session, age, err := pcTest.LoadCookiedSession()
|
||||
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
session, _, err := pcTest.LoadCookiedSession()
|
||||
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
|
||||
pcTest := NewProcessCookieTestWithDefaults()
|
||||
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
pcTest.SaveSession(startSession, reference)
|
||||
|
||||
pcTest.proxy.CookieRefresh = time.Hour
|
||||
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
|
||||
|
||||
func TestAuthOnlyEndpointAccepted(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
|
||||
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
test.proxy.CookieExpire = time.Duration(24) * time.Hour
|
||||
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, reference)
|
||||
|
||||
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
|
||||
|
||||
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
|
||||
test := NewAuthOnlyEndpointTest()
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
|
||||
test.SaveSession(startSession, time.Now())
|
||||
test.validateUser = false
|
||||
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
|
||||
pcTest.req, _ = http.NewRequest("GET",
|
||||
pcTest.opts.ProxyPrefix+"/auth", nil)
|
||||
|
||||
startSession := &providers.SessionState{
|
||||
startSession := &sessions.SessionState{
|
||||
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
|
||||
pcTest.SaveSession(startSession, time.Now())
|
||||
|
||||
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
||||
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
|
||||
req.Header = st.header
|
||||
|
||||
state := &providers.SessionState{
|
||||
state := &sessions.SessionState{
|
||||
Email: "mbland@acm.org", AccessToken: "my_access_token"}
|
||||
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
|
||||
if err != nil {
|
||||
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"encoding/json"
|
@ -1,4 +1,4 @@
|
||||
package providers
|
||||
package sessions_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
IDToken: "rawtoken1234",
|
||||
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@domain.com", ss.User)
|
||||
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, "user@domain.com", ss.User)
|
||||
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
c2, err := cookie.NewCipher([]byte(altSecret))
|
||||
assert.Equal(t, nil, err)
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
encoded, err := s.EncodeSessionState(c)
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
ss, err := DecodeSessionState(encoded, c)
|
||||
ss, err := sessions.DecodeSessionState(encoded, c)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||
|
||||
// ensure a different cipher can't decode properly (ie: it gets gibberish)
|
||||
ss, err = DecodeSessionState(encoded, c2)
|
||||
ss, err = sessions.DecodeSessionState(encoded, c2)
|
||||
t.Logf("%#v", ss)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, s.User, ss.User)
|
||||
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@domain.com", ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
s := &SessionState{
|
||||
s := &sessions.SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
AccessToken: "token1234",
|
||||
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
assert.Equal(t, nil, err)
|
||||
|
||||
// only email should have been serialized
|
||||
ss, err := DecodeSessionState(encoded, nil)
|
||||
ss, err := sessions.DecodeSessionState(encoded, nil)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, s.User, ss.User)
|
||||
assert.Equal(t, s.Email, ss.Email)
|
||||
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestExpired(t *testing.T) {
|
||||
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
|
||||
assert.Equal(t, true, s.IsExpired())
|
||||
|
||||
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
||||
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
|
||||
s = &SessionState{}
|
||||
s = &sessions.SessionState{}
|
||||
assert.Equal(t, false, s.IsExpired())
|
||||
}
|
||||
|
||||
type testCase struct {
|
||||
SessionState
|
||||
sessions.SessionState
|
||||
Encoded string
|
||||
Cipher *cookie.Cipher
|
||||
Error bool
|
||||
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
|
||||
for i, tc := range testCases {
|
||||
encoded, err := tc.EncodeSessionState(tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Empty(t, encoded)
|
||||
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecodeSessionState tests DecodeSessionState with the test vector
|
||||
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
|
||||
func TestDecodeSessionState(t *testing.T) {
|
||||
e := time.Now().Add(time.Duration(1) * time.Hour)
|
||||
eJSON, _ := e.MarshalJSON()
|
||||
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "user@domain.com",
|
||||
},
|
||||
Encoded: `{"Email":"user@domain.com"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: `{"User":"just-user"}`,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
},
|
||||
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
User: "just-user",
|
||||
Email: "user@domain.com",
|
||||
},
|
||||
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Error: true,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
Cipher: c,
|
||||
},
|
||||
{
|
||||
SessionState: SessionState{
|
||||
SessionState: sessions.SessionState{
|
||||
Email: "user@domain.com",
|
||||
User: "just-user",
|
||||
AccessToken: "token1234",
|
||||
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
|
||||
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
|
||||
if tc.Error {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, ss)
|
@ -9,6 +9,7 @@ import (
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// AzureProvider represents an Azure based Identity Provider
|
||||
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
var email string
|
||||
var err error
|
||||
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@windows.net", email)
|
||||
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||
assert.Equal(t, "", email)
|
||||
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testAzureProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, "type assertion to string failed", err.Error())
|
||||
assert.Equal(t, "", email)
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// FacebookProvider represents an Facebook based Identity Provider
|
||||
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// GitHubProvider represents an GitHub based Identity Provider
|
||||
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
|
||||
var emails []struct {
|
||||
Email string `json:"email"`
|
||||
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
}
|
||||
|
||||
// GetUserName returns the Account user name
|
||||
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
|
||||
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
|
||||
var user struct {
|
||||
Login string `json:"login"`
|
||||
Email string `json:"email"`
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Empty(t, "", email)
|
||||
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
p.Org = "testorg1"
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitHubProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetUserName(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "mbland", email)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// GitLabProvider represents an GitLab based Identity Provider
|
||||
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
|
||||
req, err := http.NewRequest("GET",
|
||||
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "michael.bland@gsa.gov", email)
|
||||
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testGitLabProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/logger"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
admin "google.golang.org/api/admin/directory/v1"
|
||||
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s = &SessionState{
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
|
||||
*ProviderData
|
||||
}
|
||||
|
||||
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// Note that we're testing the internal validateToken() used to implement
|
||||
// several Provider's ValidateSessionState() implementations
|
||||
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/api"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// LinkedInProvider represents an LinkedIn based Identity Provider
|
||||
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
if s.AccessToken == "" {
|
||||
return "", errors.New("missing access token")
|
||||
}
|
||||
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.Equal(t, nil, err)
|
||||
assert.Equal(t, "user@linkedin.com", email)
|
||||
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
|
||||
// We'll trigger a request failure by using an unexpected access
|
||||
// token. Alternatively, we could allow the parsing of the payload as
|
||||
// JSON to fail.
|
||||
session := &SessionState{AccessToken: "unexpected_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
|
||||
bURL, _ := url.Parse(b.URL)
|
||||
p := testLinkedInProvider(bURL.Host)
|
||||
|
||||
session := &SessionState{AccessToken: "imaginary_access_token"}
|
||||
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
|
||||
email, err := p.GetEmailAddress(session)
|
||||
assert.NotEqual(t, nil, err)
|
||||
assert.Equal(t, "", email)
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
|
||||
}
|
||||
|
||||
// Store the data that we found in the session state
|
||||
s = &SessionState{
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
|
@ -5,9 +5,9 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OIDCProvider represents an OIDC based Identity Provider
|
||||
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an ID token
|
||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
ctx := context.Background()
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new ID token if required
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
||||
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
|
||||
c := oauth2.Config{
|
||||
ClientID: p.ClientID,
|
||||
ClientSecret: p.ClientSecret,
|
||||
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) {
|
||||
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||
}
|
||||
|
||||
return &SessionState{
|
||||
return &sessions.SessionState{
|
||||
AccessToken: token.AccessToken,
|
||||
IDToken: rawIDToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
|
||||
}
|
||||
|
||||
// ValidateSessionState checks that the session's IDToken is still valid
|
||||
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
ctx := context.Background()
|
||||
_, err := p.Verifier.Verify(ctx, s.IDToken)
|
||||
if err != nil {
|
||||
|
@ -10,10 +10,11 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// Redeem provides a default implementation of the OAuth2 token redemption process
|
||||
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
|
||||
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
if err == nil {
|
||||
s = &SessionState{
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
}
|
||||
return
|
||||
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
return
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
s = &SessionState{AccessToken: a}
|
||||
s = &sessions.SessionState{AccessToken: a}
|
||||
} else {
|
||||
err = fmt.Errorf("no access token found %s", body)
|
||||
}
|
||||
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
|
||||
}
|
||||
|
||||
// CookieForSession serializes a session state for storage in a cookie
|
||||
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
|
||||
func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
|
||||
return s.EncodeSessionState(c)
|
||||
}
|
||||
|
||||
// SessionFromCookie deserializes a session from a cookie value
|
||||
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
|
||||
return DecodeSessionState(v, c)
|
||||
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
|
||||
return sessions.DecodeSessionState(v, c)
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
|
||||
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// GetUserName returns the Account username
|
||||
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
|
||||
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
|
||||
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, nil)
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded should refresh the user's session if required and
|
||||
// do nothing if a refresh is not required
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -4,12 +4,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRefresh(t *testing.T) {
|
||||
p := &ProviderData{}
|
||||
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
|
||||
refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
|
||||
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
|
||||
})
|
||||
assert.Equal(t, false, refreshed)
|
||||
|
@ -2,20 +2,21 @@ package providers
|
||||
|
||||
import (
|
||||
"github.com/pusher/oauth2_proxy/cookie"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
)
|
||||
|
||||
// Provider represents an upstream identity provider implementation
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetEmailAddress(*SessionState) (string, error)
|
||||
GetUserName(*SessionState) (string, error)
|
||||
Redeem(string, string) (*SessionState, error)
|
||||
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||
GetUserName(*sessions.SessionState) (string, error)
|
||||
Redeem(string, string) (*sessions.SessionState, error)
|
||||
ValidateGroup(string) bool
|
||||
ValidateSessionState(*SessionState) bool
|
||||
ValidateSessionState(*sessions.SessionState) bool
|
||||
GetLoginURL(redirectURI, finalRedirect string) string
|
||||
RefreshSessionIfNeeded(*SessionState) (bool, error)
|
||||
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
|
||||
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
|
||||
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||
SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
|
||||
CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
|
||||
}
|
||||
|
||||
// New provides a new Provider based on the configured provider string
|
||||
|
Loading…
Reference in New Issue
Block a user