From 2ab8a7d95d172b4afd2b46823f4180734cc7794d Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sun, 5 May 2019 13:33:13 +0100 Subject: [PATCH] Move SessionState to its own package --- oauthproxy.go | 13 ++-- oauthproxy_test.go | 27 ++++---- .../apis/sessions}/session_state.go | 2 +- .../apis/sessions}/session_state_test.go | 61 ++++++++++--------- providers/azure.go | 3 +- providers/azure_test.go | 13 ++-- providers/facebook.go | 5 +- providers/github.go | 5 +- providers/github_test.go | 13 ++-- providers/gitlab.go | 3 +- providers/gitlab_test.go | 7 ++- providers/google.go | 7 ++- providers/internal_util_test.go | 5 +- providers/linkedin.go | 5 +- providers/linkedin_test.go | 7 ++- providers/logingov.go | 5 +- providers/oidc.go | 16 ++--- providers/provider_default.go | 21 ++++--- providers/provider_default_test.go | 3 +- providers/providers.go | 15 ++--- 20 files changed, 127 insertions(+), 109 deletions(-) rename {providers => pkg/apis/sessions}/session_state.go (99%) rename {providers => pkg/apis/sessions}/session_state_test.go (84%) diff --git a/oauthproxy.go b/oauthproxy.go index 52f7c79..02d9ac1 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -16,6 +16,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/providers" "github.com/yhat/wsutil" ) @@ -292,7 +293,7 @@ func (p *OAuthProxy) displayCustomLoginForm() bool { return p.HtpasswdFile != nil && p.DisplayHtpasswdForm } -func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) { +func (p *OAuthProxy) redeemCode(host, code string) (s *sessions.SessionState, err error) { if code == "" { return nil, errors.New("missing code") } @@ -484,7 +485,7 @@ func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, } // LoadCookiedSession reads the user's authentication details from the request -func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { +func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessions.SessionState, time.Duration, error) { var age time.Duration c, err := loadCookie(req, p.CookieName) if err != nil { @@ -506,7 +507,7 @@ func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionSt } // SaveSession creates a new session cookie value and sets this on the response -func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error { +func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *sessions.SessionState) error { value, err := p.provider.CookieForSession(s, p.CookieCipher) if err != nil { return err @@ -693,7 +694,7 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { user, ok := p.ManualSignIn(rw, req) if ok { - session := &providers.SessionState{User: user} + session := &sessions.SessionState{User: user} p.SaveSession(rw, req, session) http.Redirect(rw, req, redirect, 302) } else { @@ -944,7 +945,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int // CheckBasicAuth checks the requests Authorization header for basic auth // credentials and authenticates these against the proxies HtpasswdFile -func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) { +func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*sessions.SessionState, error) { if p.HtpasswdFile == nil { return nil, nil } @@ -966,7 +967,7 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, } if p.HtpasswdFile.Validate(pair[0], pair[1]) { logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File") - return &providers.SessionState{User: pair[0]}, nil + return &sessions.SessionState{User: pair[0]}, nil } logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth: not in Htpasswd File") return nil, nil diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 65a8fe1..914e99f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -16,6 +16,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -253,11 +254,11 @@ func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { } } -func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) { +func (tp *TestProvider) GetEmailAddress(session *sessions.SessionState) (string, error) { return tp.EmailAddress, nil } -func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool { +func (tp *TestProvider) ValidateSessionState(session *sessions.SessionState) bool { return tp.ValidToken } @@ -637,7 +638,7 @@ func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cook return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) } -func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { +func (p *ProcessCookieTest) SaveSession(s *sessions.SessionState, ref time.Time) error { value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher) if err != nil { return err @@ -648,14 +649,14 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time return nil } -func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) { +func (p *ProcessCookieTest) LoadCookiedSession() (*sessions.SessionState, time.Duration, error) { return p.proxy.LoadCookiedSession(p.req) } func TestLoadCookiedSession(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() - startSession := &providers.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "john.doe@example.com", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, time.Now()) session, _, err := pcTest.LoadCookiedSession() @@ -680,7 +681,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour reference := time.Now().Add(time.Duration(-2) * time.Hour) - startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, reference) session, age, err := pcTest.LoadCookiedSession() @@ -695,7 +696,7 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, reference) session, _, err := pcTest.LoadCookiedSession() @@ -709,7 +710,7 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { pcTest := NewProcessCookieTestWithDefaults() pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} + startSession := &sessions.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} pcTest.SaveSession(startSession, reference) pcTest.proxy.CookieRefresh = time.Hour @@ -729,7 +730,7 @@ func NewAuthOnlyEndpointTest() *ProcessCookieTest { func TestAuthOnlyEndpointAccepted(t *testing.T) { test := NewAuthOnlyEndpointTest() - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) @@ -752,7 +753,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { test := NewAuthOnlyEndpointTest() test.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, reference) @@ -764,7 +765,7 @@ func TestAuthOnlyEndpointUnauthorizedOnExpiration(t *testing.T) { func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { test := NewAuthOnlyEndpointTest() - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) test.validateUser = false @@ -795,7 +796,7 @@ func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { pcTest.req, _ = http.NewRequest("GET", pcTest.opts.ProxyPrefix+"/auth", nil) - startSession := &providers.SessionState{ + startSession := &sessions.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} pcTest.SaveSession(startSession, time.Now()) @@ -927,7 +928,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { req := httptest.NewRequest(method, "/foo/bar", bodyBuf) req.Header = st.header - state := &providers.SessionState{ + state := &sessions.SessionState{ Email: "mbland@acm.org", AccessToken: "my_access_token"} value, err := proxy.provider.CookieForSession(state, proxy.CookieCipher) if err != nil { diff --git a/providers/session_state.go b/pkg/apis/sessions/session_state.go similarity index 99% rename from providers/session_state.go rename to pkg/apis/sessions/session_state.go index c3402ac..f6efefd 100644 --- a/providers/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -1,4 +1,4 @@ -package providers +package sessions import ( "encoding/json" diff --git a/providers/session_state_test.go b/pkg/apis/sessions/session_state_test.go similarity index 84% rename from providers/session_state_test.go rename to pkg/apis/sessions/session_state_test.go index 78957c6..83b21a4 100644 --- a/providers/session_state_test.go +++ b/pkg/apis/sessions/session_state_test.go @@ -1,4 +1,4 @@ -package providers +package sessions_test import ( "fmt" @@ -6,6 +6,7 @@ import ( "time" "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -17,7 +18,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, nil, err) c2, err := cookie.NewCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &SessionState{ + s := &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", IDToken: "rawtoken1234", @@ -27,7 +28,7 @@ func TestSessionStateSerialization(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := DecodeSessionState(encoded, c) + ss, err := sessions.DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, "user@domain.com", ss.User) @@ -38,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) + ss, err = sessions.DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.NotEqual(t, "user@domain.com", ss.User) @@ -54,7 +55,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, nil, err) c2, err := cookie.NewCipher([]byte(altSecret)) assert.Equal(t, nil, err) - s := &SessionState{ + s := &sessions.SessionState{ User: "just-user", Email: "user@domain.com", AccessToken: "token1234", @@ -64,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - ss, err := DecodeSessionState(encoded, c) + ss, err := sessions.DecodeSessionState(encoded, c) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) @@ -74,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { assert.Equal(t, s.RefreshToken, ss.RefreshToken) // ensure a different cipher can't decode properly (ie: it gets gibberish) - ss, err = DecodeSessionState(encoded, c2) + ss, err = sessions.DecodeSessionState(encoded, c2) t.Logf("%#v", ss) assert.Equal(t, nil, err) assert.NotEqual(t, s.User, ss.User) @@ -85,7 +86,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { } func TestSessionStateSerializationNoCipher(t *testing.T) { - s := &SessionState{ + s := &sessions.SessionState{ Email: "user@domain.com", AccessToken: "token1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), @@ -95,7 +96,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) + ss, err := sessions.DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, "user@domain.com", ss.User) assert.Equal(t, s.Email, ss.Email) @@ -104,7 +105,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) { } func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { - s := &SessionState{ + s := &sessions.SessionState{ User: "just-user", Email: "user@domain.com", AccessToken: "token1234", @@ -115,7 +116,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { assert.Equal(t, nil, err) // only email should have been serialized - ss, err := DecodeSessionState(encoded, nil) + ss, err := sessions.DecodeSessionState(encoded, nil) assert.Equal(t, nil, err) assert.Equal(t, s.User, ss.User) assert.Equal(t, s.Email, ss.Email) @@ -124,18 +125,18 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) { } func TestExpired(t *testing.T) { - s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} + s := &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)} assert.Equal(t, true, s.IsExpired()) - s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} + s = &sessions.SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)} assert.Equal(t, false, s.IsExpired()) - s = &SessionState{} + s = &sessions.SessionState{} assert.Equal(t, false, s.IsExpired()) } type testCase struct { - SessionState + sessions.SessionState Encoded string Cipher *cookie.Cipher Error bool @@ -150,14 +151,14 @@ func TestEncodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -171,7 +172,7 @@ func TestEncodeSessionState(t *testing.T) { for i, tc := range testCases { encoded, err := tc.EncodeSessionState(tc.Cipher) - t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) + t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, encoded, tc.SessionState, err) if tc.Error { assert.Error(t, err) assert.Empty(t, encoded) @@ -182,7 +183,7 @@ func TestEncodeSessionState(t *testing.T) { } } -// TestDecodeSessionState tests DecodeSessionState with the test vector +// TestDecodeSessionState testssessions.DecodeSessionState with the test vector func TestDecodeSessionState(t *testing.T) { e := time.Now().Add(time.Duration(1) * time.Hour) eJSON, _ := e.MarshalJSON() @@ -194,34 +195,34 @@ func TestDecodeSessionState(t *testing.T) { testCases := []testCase{ { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: `{"Email":"user@domain.com","User":"just-user"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "user@domain.com", }, Encoded: `{"Email":"user@domain.com"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ User: "just-user", }, Encoded: `{"User":"just-user"}`, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, Encoded: fmt.Sprintf(`{"Email":"user@domain.com","User":"just-user","AccessToken":"I6s+ml+/MldBMgHIiC35BTKTh57skGX24w==","IDToken":"xojNdyyjB1HgYWh6XMtXY/Ph5eCVxa1cNsklJw==","RefreshToken":"qEX0x6RmASxo4dhlBG6YuRs9Syn/e9sHu/+K","ExpiresOn":%s}`, eString), }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -233,7 +234,7 @@ func TestDecodeSessionState(t *testing.T) { Cipher: c, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", }, @@ -251,7 +252,7 @@ func TestDecodeSessionState(t *testing.T) { Error: true, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ User: "just-user", Email: "user@domain.com", }, @@ -272,7 +273,7 @@ func TestDecodeSessionState(t *testing.T) { Error: true, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -283,7 +284,7 @@ func TestDecodeSessionState(t *testing.T) { Cipher: c, }, { - SessionState: SessionState{ + SessionState: sessions.SessionState{ Email: "user@domain.com", User: "just-user", AccessToken: "token1234", @@ -297,8 +298,8 @@ func TestDecodeSessionState(t *testing.T) { } for i, tc := range testCases { - ss, err := DecodeSessionState(tc.Encoded, tc.Cipher) - t.Logf("i:%d Encoded:%#v SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) + ss, err := sessions.DecodeSessionState(tc.Encoded, tc.Cipher) + t.Logf("i:%d Encoded:%#vsessions.SessionState:%#v Error:%#v", i, tc.Encoded, ss, err) if tc.Error { assert.Error(t, err) assert.Nil(t, ss) diff --git a/providers/azure.go b/providers/azure.go index baae38f..a7961d2 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -9,6 +9,7 @@ import ( "github.com/bitly/go-simplejson" "github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // AzureProvider represents an Azure based Identity Provider @@ -88,7 +89,7 @@ func getEmailFromJSON(json *simplejson.Json) (string, error) { } // GetEmailAddress returns the Account email address -func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { var email string var err error diff --git a/providers/azure_test.go b/providers/azure_test.go index 469f2d1..8d34bdc 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -128,7 +129,7 @@ func TestAzureProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) @@ -141,7 +142,7 @@ func TestAzureProviderGetEmailAddressMailNull(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) @@ -154,7 +155,7 @@ func TestAzureProviderGetEmailAddressGetUserPrincipalName(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@windows.net", email) @@ -167,7 +168,7 @@ func TestAzureProviderGetEmailAddressFailToGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) @@ -180,7 +181,7 @@ func TestAzureProviderGetEmailAddressEmptyUserPrincipalName(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "", email) @@ -193,7 +194,7 @@ func TestAzureProviderGetEmailAddressIncorrectOtherMails(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testAzureProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, "type assertion to string failed", err.Error()) assert.Equal(t, "", email) diff --git a/providers/facebook.go b/providers/facebook.go index 6f81f15..9897a1b 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -7,6 +7,7 @@ import ( "net/url" "github.com/pusher/oauth2_proxy/api" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // FacebookProvider represents an Facebook based Identity Provider @@ -54,7 +55,7 @@ func getFacebookHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *FacebookProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } @@ -79,6 +80,6 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { } // ValidateSessionState validates the AccessToken -func (p *FacebookProvider) ValidateSessionState(s *SessionState) bool { +func (p *FacebookProvider) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, getFacebookHeader(s.AccessToken)) } diff --git a/providers/github.go b/providers/github.go index f00fc19..b60ffe1 100644 --- a/providers/github.go +++ b/providers/github.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // GitHubProvider represents an GitHub based Identity Provider @@ -200,7 +201,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { } // GetEmailAddress returns the Account email address -func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *GitHubProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { var emails []struct { Email string `json:"email"` @@ -259,7 +260,7 @@ func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) { } // GetUserName returns the Account user name -func (p *GitHubProvider) GetUserName(s *SessionState) (string, error) { +func (p *GitHubProvider) GetUserName(s *sessions.SessionState) (string, error) { var user struct { Login string `json:"login"` Email string `json:"email"` diff --git a/providers/github_test.go b/providers/github_test.go index 4b093ca..2d45b84 100644 --- a/providers/github_test.go +++ b/providers/github_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -103,7 +104,7 @@ func TestGitHubProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) @@ -116,7 +117,7 @@ func TestGitHubProviderGetEmailAddressNotVerified(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Empty(t, "", email) @@ -134,7 +135,7 @@ func TestGitHubProviderGetEmailAddressWithOrg(t *testing.T) { p := testGitHubProvider(bURL.Host) p.Org = "testorg1" - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) @@ -152,7 +153,7 @@ func TestGitHubProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - session := &SessionState{AccessToken: "unexpected_access_token"} + session := &sessions.SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -165,7 +166,7 @@ func TestGitHubProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -178,7 +179,7 @@ func TestGitHubProviderGetUserName(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitHubProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetUserName(session) assert.Equal(t, nil, err) assert.Equal(t, "mbland", email) diff --git a/providers/gitlab.go b/providers/gitlab.go index 1962552..af956c4 100644 --- a/providers/gitlab.go +++ b/providers/gitlab.go @@ -6,6 +6,7 @@ import ( "github.com/pusher/oauth2_proxy/api" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // GitLabProvider represents an GitLab based Identity Provider @@ -44,7 +45,7 @@ func NewGitLabProvider(p *ProviderData) *GitLabProvider { } // GetEmailAddress returns the Account email address -func (p *GitLabProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *GitLabProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { req, err := http.NewRequest("GET", p.ValidateURL.String()+"?access_token="+s.AccessToken, nil) diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index 19038e1..112eb89 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -89,7 +90,7 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitLabProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "michael.bland@gsa.gov", email) @@ -107,7 +108,7 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - session := &SessionState{AccessToken: "unexpected_access_token"} + session := &sessions.SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -120,7 +121,7 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testGitLabProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) diff --git a/providers/google.go b/providers/google.go index e3cb380..f79a131 100644 --- a/providers/google.go +++ b/providers/google.go @@ -14,6 +14,7 @@ import ( "time" "github.com/pusher/oauth2_proxy/logger" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "golang.org/x/oauth2" "golang.org/x/oauth2/google" admin "google.golang.org/api/admin/directory/v1" @@ -96,7 +97,7 @@ func claimsFromIDToken(idToken string) (*claims, error) { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *GoogleProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -145,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err if err != nil { return } - s = &SessionState{ + s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), @@ -258,7 +259,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool { // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { +func (p *GoogleProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 1a03fc5..ba6d470 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -7,6 +7,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -19,13 +20,13 @@ type ValidateSessionStateTestProvider struct { *ProviderData } -func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) { +func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // Note that we're testing the internal validateToken() used to implement // several Provider's ValidateSessionState() implementations -func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool { +func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *sessions.SessionState) bool { return false } diff --git a/providers/linkedin.go b/providers/linkedin.go index 8c392f8..a31b4a1 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -7,6 +7,7 @@ import ( "net/url" "github.com/pusher/oauth2_proxy/api" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // LinkedInProvider represents an LinkedIn based Identity Provider @@ -50,7 +51,7 @@ func getLinkedInHeader(accessToken string) http.Header { } // GetEmailAddress returns the Account email address -func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { +func (p *LinkedInProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { if s.AccessToken == "" { return "", errors.New("missing access token") } @@ -73,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) { } // ValidateSessionState validates the AccessToken -func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool { +func (p *LinkedInProvider) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken)) } diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index 7911522..9910a71 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -6,6 +6,7 @@ import ( "net/url" "testing" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) @@ -97,7 +98,7 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testLinkedInProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.Equal(t, nil, err) assert.Equal(t, "user@linkedin.com", email) @@ -113,7 +114,7 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as // JSON to fail. - session := &SessionState{AccessToken: "unexpected_access_token"} + session := &sessions.SessionState{AccessToken: "unexpected_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) @@ -126,7 +127,7 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { bURL, _ := url.Parse(b.URL) p := testLinkedInProvider(bURL.Host) - session := &SessionState{AccessToken: "imaginary_access_token"} + session := &sessions.SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) assert.NotEqual(t, nil, err) assert.Equal(t, "", email) diff --git a/providers/logingov.go b/providers/logingov.go index 09bd3be..60f4260 100644 --- a/providers/logingov.go +++ b/providers/logingov.go @@ -13,6 +13,7 @@ import ( "time" "github.com/dgrijalva/jwt-go" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "gopkg.in/square/go-jose.v2" ) @@ -173,7 +174,7 @@ func emailFromUserInfo(accessToken string, userInfoEndpoint string) (email strin } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -248,7 +249,7 @@ func (p *LoginGovProvider) Redeem(redirectURL, code string) (s *SessionState, er } // Store the data that we found in the session state - s = &SessionState{ + s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, IDToken: jsonResponse.IDToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), diff --git a/providers/oidc.go b/providers/oidc.go index d751be5..bacabdf 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -5,9 +5,9 @@ import ( "fmt" "time" - "golang.org/x/oauth2" - oidc "github.com/coreos/go-oidc" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "golang.org/x/oauth2" ) // OIDCProvider represents an OIDC based Identity Provider @@ -24,7 +24,7 @@ func NewOIDCProvider(p *ProviderData) *OIDCProvider { } // Redeem exchanges the OAuth2 authentication token for an ID token -func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *OIDCProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { ctx := context.Background() c := oauth2.Config{ ClientID: p.ClientID, @@ -47,7 +47,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er // RefreshSessionIfNeeded checks if the session has expired and uses the // RefreshToken to fetch a new ID token if required -func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { +func (p *OIDCProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { return false, nil } @@ -63,7 +63,7 @@ func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) { return true, nil } -func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { +func (p *OIDCProvider) redeemRefreshToken(s *sessions.SessionState) (err error) { c := oauth2.Config{ ClientID: p.ClientID, ClientSecret: p.ClientSecret, @@ -92,7 +92,7 @@ func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) { return } -func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*SessionState, error) { +func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Token) (*sessions.SessionState, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, fmt.Errorf("token response did not contain an id_token") @@ -122,7 +122,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email) } - return &SessionState{ + return &sessions.SessionState{ AccessToken: token.AccessToken, IDToken: rawIDToken, RefreshToken: token.RefreshToken, @@ -133,7 +133,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok } // ValidateSessionState checks that the session's IDToken is still valid -func (p *OIDCProvider) ValidateSessionState(s *SessionState) bool { +func (p *OIDCProvider) ValidateSessionState(s *sessions.SessionState) bool { ctx := context.Background() _, err := p.Verifier.Verify(ctx, s.IDToken) if err != nil { diff --git a/providers/provider_default.go b/providers/provider_default.go index f8f59ab..cd78251 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -10,10 +10,11 @@ import ( "net/url" "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // Redeem provides a default implementation of the OAuth2 token redemption process -func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err error) { +func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { if code == "" { err = errors.New("missing code") return @@ -59,7 +60,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er } err = json.Unmarshal(body, &jsonResponse) if err == nil { - s = &SessionState{ + s = &sessions.SessionState{ AccessToken: jsonResponse.AccessToken, } return @@ -71,7 +72,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er return } if a := v.Get("access_token"); a != "" { - s = &SessionState{AccessToken: a} + s = &sessions.SessionState{AccessToken: a} } else { err = fmt.Errorf("no access token found %s", body) } @@ -94,22 +95,22 @@ func (p *ProviderData) GetLoginURL(redirectURI, state string) string { } // CookieForSession serializes a session state for storage in a cookie -func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) { +func (p *ProviderData) CookieForSession(s *sessions.SessionState, c *cookie.Cipher) (string, error) { return s.EncodeSessionState(c) } // SessionFromCookie deserializes a session from a cookie value -func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) { - return DecodeSessionState(v, c) +func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *sessions.SessionState, err error) { + return sessions.DecodeSessionState(v, c) } // GetEmailAddress returns the Account email address -func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) { +func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } // GetUserName returns the Account username -func (p *ProviderData) GetUserName(s *SessionState) (string, error) { +func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } @@ -120,12 +121,12 @@ func (p *ProviderData) ValidateGroup(email string) bool { } // ValidateSessionState validates the AccessToken -func (p *ProviderData) ValidateSessionState(s *SessionState) bool { +func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, nil) } // RefreshSessionIfNeeded should refresh the user's session if required and // do nothing if a refresh is not required -func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { +func (p *ProviderData) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { return false, nil } diff --git a/providers/provider_default_test.go b/providers/provider_default_test.go index abff0a9..ffe4aa7 100644 --- a/providers/provider_default_test.go +++ b/providers/provider_default_test.go @@ -4,12 +4,13 @@ import ( "testing" "time" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/stretchr/testify/assert" ) func TestRefresh(t *testing.T) { p := &ProviderData{} - refreshed, err := p.RefreshSessionIfNeeded(&SessionState{ + refreshed, err := p.RefreshSessionIfNeeded(&sessions.SessionState{ ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute), }) assert.Equal(t, false, refreshed) diff --git a/providers/providers.go b/providers/providers.go index 4616153..57ace41 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -2,20 +2,21 @@ package providers import ( "github.com/pusher/oauth2_proxy/cookie" + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" ) // Provider represents an upstream identity provider implementation type Provider interface { Data() *ProviderData - GetEmailAddress(*SessionState) (string, error) - GetUserName(*SessionState) (string, error) - Redeem(string, string) (*SessionState, error) + GetEmailAddress(*sessions.SessionState) (string, error) + GetUserName(*sessions.SessionState) (string, error) + Redeem(string, string) (*sessions.SessionState, error) ValidateGroup(string) bool - ValidateSessionState(*SessionState) bool + ValidateSessionState(*sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string - RefreshSessionIfNeeded(*SessionState) (bool, error) - SessionFromCookie(string, *cookie.Cipher) (*SessionState, error) - CookieForSession(*SessionState, *cookie.Cipher) (string, error) + RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) + SessionFromCookie(string, *cookie.Cipher) (*sessions.SessionState, error) + CookieForSession(*sessions.SessionState, *cookie.Cipher) (string, error) } // New provides a new Provider based on the configured provider string