Move SessionState to its own package

This commit is contained in:
Joel Speed 2019-05-05 13:33:13 +01:00
parent a1130e41a3
commit 2ab8a7d95d
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
20 changed files with 127 additions and 109 deletions

View File

@ -16,6 +16,7 @@ import (
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"github.com/yhat/wsutil" "github.com/yhat/wsutil"
) )
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
} }
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
return nil, errors.New("missing code") return nil, errors.New("missing code")
} }
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
} }
// LoadCookiedSession reads the user's authentication details from the request // LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
var age time.Duration var age time.Duration
c, err := loadCookie(req, p.CookieName) c, err := loadCookie(req, p.CookieName)
if err != nil { if err != nil {
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
} }
// SaveSession creates a new session cookie value and sets this on the response // SaveSession creates a new session cookie value and sets this on the response
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher) value, err := p.provider.CookieForSession(s, p.CookieCipher)
if err != nil { if err != nil {
return err return err
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
user, ok := p.ManualSignIn(rw, req) user, ok := p.ManualSignIn(rw, req)
if ok { if ok {
session := &providers.SessionState{User: user} session := &sessions.SessionState{User: user}
p.SaveSession(rw, req, session) p.SaveSession(rw, req, session)
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, 302)
} else { } else {
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
// CheckBasicAuth checks the requests Authorization header for basic auth // CheckBasicAuth checks the requests Authorization header for basic auth
// credentials and authenticates these against the proxies HtpasswdFile // credentials and authenticates these against the proxies HtpasswdFile
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
if p.HtpasswdFile == nil { if p.HtpasswdFile == nil {
return nil, nil return nil, nil
} }
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
} }
if p.HtpasswdFile.Validate(pair[0], pair[1]) { if p.HtpasswdFile.Validate(pair[0], pair[1]) {
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
return &providers.SessionState{User: pair[0]}, nil return &sessions.SessionState{User: pair[0]}, nil
} }
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
return nil, nil return nil, nil

View File

@ -16,6 +16,7 @@ import (
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/providers" "github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
} }
} }
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
return tp.EmailAddress, nil return tp.EmailAddress, nil
} }
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
return tp.ValidToken return tp.ValidToken
} }
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
} }
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
if err != nil { if err != nil {
return err return err
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
return nil return nil
} }
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
return p.proxy.LoadCookiedSession(p.req) return p.proxy.LoadCookiedSession(p.req)
} }
func TestLoadCookiedSession(t *testing.T) { func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
session, _, err := pcTest.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour) reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, age, err := pcTest.LoadCookiedSession() session, age, err := pcTest.LoadCookiedSession()
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
session, _, err := pcTest.LoadCookiedSession() session, _, err := pcTest.LoadCookiedSession()
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults() pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference) pcTest.SaveSession(startSession, reference)
pcTest.proxy.CookieRefresh = time.Hour pcTest.proxy.CookieRefresh = time.Hour
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
func TestAuthOnlyEndpointAccepted(t *testing.T) { func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession, time.Now())
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
test.proxy.CookieExpire = time.Duration(24) * time.Hour test.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, reference) test.SaveSession(startSession, reference)
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest() test := NewAuthOnlyEndpointTest()
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now()) test.SaveSession(startSession, time.Now())
test.validateUser = false test.validateUser = false
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET", pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil) pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &providers.SessionState{ startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
pcTest.SaveSession(startSession, time.Now()) pcTest.SaveSession(startSession, time.Now())
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
req := httptest.NewRequest(method, "/foo/bar", bodyBuf) req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
req.Header = st.header req.Header = st.header
state := &providers.SessionState{ state := &sessions.SessionState{
Email: "mbland@acm.org", AccessToken: "my_access_token"} Email: "mbland@acm.org", AccessToken: "my_access_token"}
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package providers package sessions
import ( import (
"encoding/json" "encoding/json"

View File

@ -1,4 +1,4 @@
package providers package sessions_test
import ( import (
"fmt" "fmt"
@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret)) c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
s := &SessionState{ s := &sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
IDToken: "rawtoken1234", IDToken: "rawtoken1234",
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
ss, err := DecodeSessionState(encoded, c) ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User) assert.Equal(t, "user@domain.com", ss.User)
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish) // ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2) ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, "user@domain.com", ss.User) assert.NotEqual(t, "user@domain.com", ss.User)
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret)) c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
s := &SessionState{ s := &sessions.SessionState{
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
encoded, err := s.EncodeSessionState(c) encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
ss, err := DecodeSessionState(encoded, c) ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User) assert.Equal(t, s.User, ss.User)
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.RefreshToken, ss.RefreshToken) assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish) // ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2) ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss) t.Logf("%#v", ss)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.NotEqual(t, s.User, ss.User) assert.NotEqual(t, s.User, ss.User)
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
} }
func TestSessionStateSerializationNoCipher(t *testing.T) { func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{ s := &sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User) assert.Equal(t, "user@domain.com", ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
} }
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
s := &SessionState{ s := &sessions.SessionState{
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
AccessToken: "token1234", AccessToken: "token1234",
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// only email should have been serialized // only email should have been serialized
ss, err := DecodeSessionState(encoded, nil) ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User) assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.Email, ss.Email)
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
} }
func TestExpired(t *testing.T) { func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired()) assert.Equal(t, true, s.IsExpired())
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
assert.Equal(t, false, s.IsExpired()) assert.Equal(t, false, s.IsExpired())
s = &SessionState{} s = &sessions.SessionState{}
assert.Equal(t, false, s.IsExpired()) assert.Equal(t, false, s.IsExpired())
} }
type testCase struct { type testCase struct {
SessionState sessions.SessionState
Encoded string Encoded string
Cipher *cookie.Cipher Cipher *cookie.Cipher
Error bool Error bool
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
testCases := []testCase{ testCases := []testCase{
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: `{"Email":"user@domain.com","User":"just-user"}`, Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
for i, tc := range testCases { for i, tc := range testCases {
encoded, err := tc.EncodeSessionState(tc.Cipher) encoded, err := tc.EncodeSessionState(tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
if tc.Error { if tc.Error {
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, encoded) assert.Empty(t, encoded)
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
} }
} }
// TestDecodeSessionState tests DecodeSessionState with the test vector // TestDecodeSessionState testssessions.DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) { func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour) e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON() eJSON, _ := e.MarshalJSON()
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
testCases := []testCase{ testCases := []testCase{
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: `{"Email":"user@domain.com","User":"just-user"}`, Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "user@domain.com", User: "user@domain.com",
}, },
Encoded: `{"Email":"user@domain.com"}`, Encoded: `{"Email":"user@domain.com"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
User: "just-user", User: "just-user",
}, },
Encoded: `{"User":"just-user"}`, Encoded: `{"User":"just-user"}`,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c, Cipher: c,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
}, },
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
Error: true, Error: true,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
User: "just-user", User: "just-user",
Email: "user@domain.com", Email: "user@domain.com",
}, },
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
Error: true, Error: true,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c, Cipher: c,
}, },
{ {
SessionState: SessionState{ SessionState: sessions.SessionState{
Email: "user@domain.com", Email: "user@domain.com",
User: "just-user", User: "just-user",
AccessToken: "token1234", AccessToken: "token1234",
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
} }
for i, tc := range testCases { for i, tc := range testCases {
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
if tc.Error { if tc.Error {
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, ss) assert.Nil(t, ss)

View File

@ -9,6 +9,7 @@ import (
"github.com/bitly/go-simplejson" "github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// AzureProvider represents an Azure based Identity Provider // AzureProvider represents an Azure based Identity Provider
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var email string var email string
var err error var err error

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email) assert.Equal(t, "user@windows.net", email)
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host) p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email) assert.Equal(t, "", email)

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// FacebookProvider represents an Facebook based Identity Provider // FacebookProvider represents an Facebook based Identity Provider
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
} }

View File

@ -11,6 +11,7 @@ import (
"strings" "strings"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// GitHubProvider represents an GitHub based Identity Provider // GitHubProvider represents an GitHub based Identity Provider
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
} }
// GetUserName returns the Account user name // GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
var user struct { var user struct {
Login string `json:"login"` Login string `json:"login"`
Email string `json:"email"` Email string `json:"email"`

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Empty(t, "", email) assert.Empty(t, "", email)
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
p.Org = "testorg1" p.Org = "testorg1"
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host) p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session) email, err := p.GetUserName(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email) assert.Equal(t, "mbland", email)

View File

@ -6,6 +6,7 @@ import (
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// GitLabProvider represents an GitLab based Identity Provider // GitLabProvider represents an GitLab based Identity Provider
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
req, err := http.NewRequest("GET", req, err := http.NewRequest("GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host) p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)

View File

@ -14,6 +14,7 @@ import (
"time" "time"
"github.com/pusher/oauth2_proxy/logger" "github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1" admin "google.golang.org/api/admin/directory/v1"
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
if err != nil { if err != nil {
return return
} }
s = &SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
*ProviderData *ProviderData
} }
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// Note that we're testing the internal validateToken() used to implement // Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateSessionState() implementations // several Provider's ValidateSessionState() implementations
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
return false return false
} }

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// LinkedInProvider represents an LinkedIn based Identity Provider // LinkedInProvider represents an LinkedIn based Identity Provider
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
} }

View File

@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"testing" "testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email) assert.Equal(t, "user@linkedin.com", email)
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"} session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL) bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host) p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"} session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session) email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)

View File

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
) )
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
} }
// Store the data that we found in the session state // Store the data that we found in the session state
s = &SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken, IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),

View File

@ -5,9 +5,9 @@ import (
"fmt" "fmt"
"time" "time"
"golang.org/x/oauth2"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"golang.org/x/oauth2"
) )
// OIDCProvider represents an OIDC based Identity Provider // OIDCProvider represents an OIDC based Identity Provider
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
} }
// Redeem exchanges the OAuth2 authentication token for an ID token // Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
ctx := context.Background() ctx := context.Background()
c := oauth2.Config{ c := oauth2.Config{
ClientID: p.ClientID, ClientID: p.ClientID,
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
// RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required // RefreshToken to fetch a new ID token if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil return false, nil
} }
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return true, nil return true, nil
} }
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
c := oauth2.Config{ c := oauth2.Config{
ClientID: p.ClientID, ClientID: p.ClientID,
ClientSecret: p.ClientSecret, ClientSecret: p.ClientSecret,
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
return return
} }
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string) rawIDToken, ok := token.Extra("id_token").(string)
if !ok { if !ok {
return nil, fmt.Errorf("token response did not contain an id_token") return nil, fmt.Errorf("token response did not contain an id_token")
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
} }
return &SessionState{ return &sessions.SessionState{
AccessToken: token.AccessToken, AccessToken: token.AccessToken,
IDToken: rawIDToken, IDToken: rawIDToken,
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
} }
// ValidateSessionState checks that the session's IDToken is still valid // ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background() ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken) _, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil { if err != nil {

View File

@ -10,10 +10,11 @@ import (
"net/url" "net/url"
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// Redeem provides a default implementation of the OAuth2 token redemption process // Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &jsonResponse)
if err == nil { if err == nil {
s = &SessionState{ s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken, AccessToken: jsonResponse.AccessToken,
} }
return return
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
return return
} }
if a := v.Get("access_token"); a != "" { if a := v.Get("access_token"); a != "" {
s = &SessionState{AccessToken: a} s = &sessions.SessionState{AccessToken: a}
} else { } else {
err = fmt.Errorf("no access token found %s", body) err = fmt.Errorf("no access token found %s", body)
} }
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
} }
// CookieForSession serializes a session state for storage in a cookie // CookieForSession serializes a session state for storage in a cookie
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
return s.EncodeSessionState(c) return s.EncodeSessionState(c)
} }
// SessionFromCookie deserializes a session from a cookie value // SessionFromCookie deserializes a session from a cookie value
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
return DecodeSessionState(v, c) return sessions.DecodeSessionState(v, c)
} }
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
// GetUserName returns the Account username // GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) { func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
} }
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *SessionState) bool { func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, nil) return validateToken(p, s.AccessToken, nil)
} }
// RefreshSessionIfNeeded should refresh the user's session if required and // RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required // do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
return false, nil return false, nil
} }

View File

@ -4,12 +4,13 @@ import (
"testing" "testing"
"time" "time"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRefresh(t *testing.T) { func TestRefresh(t *testing.T) {
p := &ProviderData{} p := &ProviderData{}
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
}) })
assert.Equal(t, false, refreshed) assert.Equal(t, false, refreshed)

View File

@ -2,20 +2,21 @@ package providers
import ( import (
"github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
) )
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*SessionState) (string, error) GetEmailAddress(*sessions.SessionState) (string, error)
GetUserName(*SessionState) (string, error) GetUserName(*sessions.SessionState) (string, error)
Redeem(string, string) (*SessionState, error) Redeem(string, string) (*sessions.SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool ValidateSessionState(*sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*SessionState) (bool, error) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
CookieForSession(*SessionState, *cookie.Cipher) (string, error) CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
} }
// New provides a new Provider based on the configured provider string // New provides a new Provider based on the configured provider string