Move SessionState to its own package

This commit is contained in:
Joel Speed 2019-05-05 13:33:13 +01:00
parent a1130e41a3
commit 2ab8a7d95d
No known key found for this signature in database
GPG Key ID: 6E80578D6751DEFB
20 changed files with 127 additions and 109 deletions

View File

@ -16,6 +16,7 @@ import (
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/providers"
"github.com/yhat/wsutil"
)
@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
}
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) {
if code == "" {
return nil, errors.New("missing code")
}
@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request,
}
// LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) {
var age time.Duration
c, err := loadCookie(req, p.CookieName)
if err != nil {
@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt
}
// SaveSession creates a new session cookie value and sets this on the response
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher)
if err != nil {
return err
@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
user, ok := p.ManualSignIn(rw, req)
if ok {
session := &providers.SessionState{User: user}
session := &sessions.SessionState{User: user}
p.SaveSession(rw, req, session)
http.Redirect(rw, req, redirect, 302)
} else {
@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
// CheckBasicAuth checks the requests Authorization header for basic auth
// credentials and authenticates these against the proxies HtpasswdFile
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) {
if p.HtpasswdFile == nil {
return nil, nil
}
@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
}
if p.HtpasswdFile.Validate(pair[0], pair[1]) {
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
return &providers.SessionState{User: pair[0]}, nil
return &sessions.SessionState{User: pair[0]}, nil
}
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File")
return nil, nil

View File

@ -16,6 +16,7 @@ import (
"github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/providers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider {
}
}
func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) {
return tp.EmailAddress, nil
}
func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool {
return tp.ValidToken
}
@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
}
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error {
value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
if err != nil {
return err
@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
return nil
}
func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) {
return p.proxy.LoadCookiedSession(p.req)
}
func TestLoadCookiedSession(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults()
startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, time.Now())
session, _, err := pcTest.LoadCookiedSession()
@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference)
session, age, err := pcTest.LoadCookiedSession()
@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference)
session, _, err := pcTest.LoadCookiedSession()
@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pcTest := NewProcessCookieTestWithDefaults()
pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pcTest.SaveSession(startSession, reference)
pcTest.proxy.CookieRefresh = time.Hour
@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest {
func TestAuthOnlyEndpointAccepted(t *testing.T) {
test := NewAuthOnlyEndpointTest()
startSession := &providers.SessionState{
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now())
@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
test := NewAuthOnlyEndpointTest()
test.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
startSession := &providers.SessionState{
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, reference)
@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) {
func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) {
test := NewAuthOnlyEndpointTest()
startSession := &providers.SessionState{
startSession := &sessions.SessionState{
Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
test.SaveSession(startSession, time.Now())
test.validateUser = false
@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) {
pcTest.req, _ = http.NewRequest("GET",
pcTest.opts.ProxyPrefix+"/auth", nil)
startSession := &providers.SessionState{
startSession := &sessions.SessionState{
User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"}
pcTest.SaveSession(startSession, time.Now())
@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
req := httptest.NewRequest(method, "/foo/bar", bodyBuf)
req.Header = st.header
state := &providers.SessionState{
state := &sessions.SessionState{
Email: "mbland@acm.org", AccessToken: "my_access_token"}
value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher)
if err != nil {

View File

@ -1,4 +1,4 @@
package providers
package sessions
import (
"encoding/json"

View File

@ -1,4 +1,4 @@
package providers
package sessions_test
import (
"fmt"
@ -6,6 +6,7 @@ import (
"time"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err)
s := &SessionState{
s := &sessions.SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
IDToken: "rawtoken1234",
@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) {
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
ss, err := DecodeSessionState(encoded, c)
ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User)
@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2)
ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.NotEqual(t, "user@domain.com", ss.User)
@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher([]byte(altSecret))
assert.Equal(t, nil, err)
s := &SessionState{
s := &sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
ss, err := DecodeSessionState(encoded, c)
ss, err := sessions.DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2)
ss, err = sessions.DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.NotEqual(t, s.User, ss.User)
@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
}
func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{
s := &sessions.SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
assert.Equal(t, nil, err)
// only email should have been serialized
ss, err := DecodeSessionState(encoded, nil)
ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, "user@domain.com", ss.User)
assert.Equal(t, s.Email, ss.Email)
@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
}
func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
s := &SessionState{
s := &sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
AccessToken: "token1234",
@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
assert.Equal(t, nil, err)
// only email should have been serialized
ss, err := DecodeSessionState(encoded, nil)
ss, err := sessions.DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
}
func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired())
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
assert.Equal(t, false, s.IsExpired())
s = &SessionState{}
s = &sessions.SessionState{}
assert.Equal(t, false, s.IsExpired())
}
type testCase struct {
SessionState
sessions.SessionState
Encoded string
Cipher *cookie.Cipher
Error bool
@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) {
testCases := []testCase{
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) {
for i, tc := range testCases {
encoded, err := tc.EncodeSessionState(tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err)
if tc.Error {
assert.Error(t, err)
assert.Empty(t, encoded)
@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) {
}
}
// TestDecodeSessionState tests DecodeSessionState with the test vector
// TestDecodeSessionState testssessions.DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
eJSON, _ := e.MarshalJSON()
@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) {
testCases := []testCase{
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: `{"Email":"user@domain.com","User":"just-user"}`,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "user@domain.com",
},
Encoded: `{"Email":"user@domain.com"}`,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
User: "just-user",
},
Encoded: `{"User":"just-user"}`,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString),
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
},
@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) {
Error: true,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
},
@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) {
Error: true,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c,
},
{
SessionState: SessionState{
SessionState: sessions.SessionState{
Email: "user@domain.com",
User: "just-user",
AccessToken: "token1234",
@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) {
}
for i, tc := range testCases {
ss, err := DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher)
t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err)
if tc.Error {
assert.Error(t, err)
assert.Nil(t, ss)

View File

@ -9,6 +9,7 @@ import (
"github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// AzureProvider represents an Azure based Identity Provider
@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) {
}
// GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) {
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var email string
var err error

View File

@ -6,6 +6,7 @@ import (
"net/url"
"testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email)
@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email)
@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@windows.net", email)
@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email)
@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "", email)
@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testAzureProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, "type assertion to string failed", err.Error())
assert.Equal(t, "", email)

View File

@ -7,6 +7,7 @@ import (
"net/url"
"github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// FacebookProvider represents an Facebook based Identity Provider
@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header {
}
// GetEmailAddress returns the Account email address
func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) {
}
// ValidateSessionState validates the AccessToken
func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool {
func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken))
}

View File

@ -11,6 +11,7 @@ import (
"strings"
"github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// GitHubProvider represents an GitHub based Identity Provider
@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
}
// GetEmailAddress returns the Account email address
func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var emails []struct {
Email string `json:"email"`
@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
}
// GetUserName returns the Account user name
func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) {
func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) {
var user struct {
Login string `json:"login"`
Email string `json:"email"`

View File

@ -6,6 +6,7 @@ import (
"net/url"
"testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Empty(t, "", email)
@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) {
p := testGitHubProvider(bURL.Host)
p.Org = "testorg1"
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testGitHubProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetUserName(session)
assert.Equal(t, nil, err)
assert.Equal(t, "mbland", email)

View File

@ -6,6 +6,7 @@ import (
"github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// GitLabProvider represents an GitLab based Identity Provider
@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider {
}
// GetEmailAddress returns the Account email address
func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) {
func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
req, err := http.NewRequest("GET",
p.ValidateURL.String()+"?access_token="+s.AccessToken, nil)

View File

@ -6,6 +6,7 @@ import (
"net/url"
"testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testGitLabProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/pusher/oauth2_proxy/logger"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
admin "google.golang.org/api/admin/directory/v1"
@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) {
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" {
err = errors.New("missing code")
return
@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
if err != nil {
return
}
s = &SessionState{
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}

View File

@ -7,6 +7,7 @@ import (
"net/url"
"testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct {
*ProviderData
}
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented")
}
// Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateSessionState() implementations
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool {
return false
}

View File

@ -7,6 +7,7 @@ import (
"net/url"
"github.com/pusher/oauth2_proxy/api"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// LinkedInProvider represents an LinkedIn based Identity Provider
@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header {
}
// GetEmailAddress returns the Account email address
func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
if s.AccessToken == "" {
return "", errors.New("missing access token")
}
@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
}
// ValidateSessionState validates the AccessToken
func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
}

View File

@ -6,6 +6,7 @@ import (
"net/url"
"testing"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email)
@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as
// JSON to fail.
session := &SessionState{AccessToken: "unexpected_access_token"}
session := &sessions.SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)
@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
bURL, _ := url.Parse(b.URL)
p := testLinkedInProvider(bURL.Host)
session := &SessionState{AccessToken: "imaginary_access_token"}
session := &sessions.SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err)
assert.Equal(t, "", email)

View File

@ -13,6 +13,7 @@ import (
"time"
"github.com/dgrijalva/jwt-go"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"gopkg.in/square/go-jose.v2"
)
@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" {
err = errors.New("missing code")
return
@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er
}
// Store the data that we found in the session state
s = &SessionState{
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),

View File

@ -5,9 +5,9 @@ import (
"fmt"
"time"
"golang.org/x/oauth2"
oidc "github.com/coreos/go-oidc"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"golang.org/x/oauth2"
)
// OIDCProvider represents an OIDC based Identity Provider
@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider {
}
// Redeem exchanges the OAuth2 authentication token for an ID token
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) {
func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
ctx := context.Background()
c := oauth2.Config{
ClientID: p.ClientID,
@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return true, nil
}
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) {
c := oauth2.Config{
ClientID: p.ClientID,
ClientSecret: p.ClientSecret,
@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
return
}
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) {
func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) {
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
return nil, fmt.Errorf("token response did not contain an id_token")
@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
}
return &SessionState{
return &sessions.SessionState{
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
}
// ValidateSessionState checks that the session's IDToken is still valid
func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool {
func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool {
ctx := context.Background()
_, err := p.Verifier.Verify(ctx, s.IDToken)
if err != nil {

View File

@ -10,10 +10,11 @@ import (
"net/url"
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// Redeem provides a default implementation of the OAuth2 token redemption process
func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) {
func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" {
err = errors.New("missing code")
return
@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
}
err = json.Unmarshal(body, &jsonResponse)
if err == nil {
s = &SessionState{
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
}
return
@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
return
}
if a := v.Get("access_token"); a != "" {
s = &SessionState{AccessToken: a}
s = &sessions.SessionState{AccessToken: a}
} else {
err = fmt.Errorf("no access token found %s", body)
}
@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
}
// CookieForSession serializes a session state for storage in a cookie
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) {
return s.EncodeSessionState(c)
}
// SessionFromCookie deserializes a session from a cookie value
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
return DecodeSessionState(v, c)
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) {
return sessions.DecodeSessionState(v, c)
}
// GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented")
}
// GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *SessionState) (string, error) {
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented")
}
@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool {
}
// ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, nil)
}
// RefreshSessionIfNeeded should refresh the user's session if required and
// do nothing if a refresh is not required
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
return false, nil
}

View File

@ -4,12 +4,13 @@ import (
"testing"
"time"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/stretchr/testify/assert"
)
func TestRefresh(t *testing.T) {
p := &ProviderData{}
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
})
assert.Equal(t, false, refreshed)

View File

@ -2,20 +2,21 @@ package providers
import (
"github.com/pusher/oauth2_proxy/cookie"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
)
// Provider represents an upstream identity provider implementation
type Provider interface {
Data() *ProviderData
GetEmailAddress(*SessionState) (string, error)
GetUserName(*SessionState) (string, error)
Redeem(string, string) (*SessionState, error)
GetEmailAddress(*sessions.SessionState) (string, error)
GetUserName(*sessions.SessionState) (string, error)
Redeem(string, string) (*sessions.SessionState, error)
ValidateGroup(string) bool
ValidateSessionState(*SessionState) bool
ValidateSessionState(*sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*SessionState) (bool, error)
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error)
CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error)
}
// New provides a new Provider based on the configured provider string