From 8471f972e141a44f2319e7c7f35ad10c6195c2fa Mon Sep 17 00:00:00 2001 From: Mike Bland Date: Tue, 12 May 2015 21:48:13 -0400 Subject: [PATCH] Move ValidateToken() to Provider --- api/api.go | 17 ++++ api/api_test.go | 59 ++++++++++++ oauthproxy.go | 23 +---- oauthproxy_test.go | 153 ++++++-------------------------- providers/google.go | 4 + providers/internal_util.go | 24 +++++ providers/internal_util_test.go | 122 +++++++++++++++++++++++++ providers/linkedin.go | 19 +++- providers/linkedin_test.go | 9 ++ providers/myusa.go | 4 + providers/providers.go | 1 + 11 files changed, 285 insertions(+), 150 deletions(-) create mode 100644 providers/internal_util.go create mode 100644 providers/internal_util_test.go diff --git a/api/api.go b/api/api.go index 3cd88d2..d2de0ce 100644 --- a/api/api.go +++ b/api/api.go @@ -30,3 +30,20 @@ func Request(req *http.Request) (*simplejson.Json, error) { } return data, nil } + +func RequestUnparsedResponse(url string, header http.Header) ( + response *http.Response, err error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, errors.New("failed building request for " + + url + ": " + err.Error()) + } + req.Header = header + + httpclient := &http.Client{} + if response, err = httpclient.Do(req); err != nil { + return nil, errors.New("request failed for " + + url + ": " + err.Error()) + } + return +} diff --git a/api/api_test.go b/api/api_test.go index 8327a28..515d4da 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,6 +3,7 @@ package api import ( "github.com/bitly/go-simplejson" "github.com/bmizerany/assert" + "io/ioutil" "net/http" "net/http/httptest" "strings" @@ -66,3 +67,61 @@ func TestJsonParsingError(t *testing.T) { assert.Equal(t, (*simplejson.Json)(nil), resp) assert.NotEqual(t, nil, err) } + +// Parsing a URL practically never fails, so we won't cover that test case. +func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + token := r.FormValue("access_token") + if r.URL.Path == "/" && token == "my_token" { + w.WriteHeader(200) + w.Write([]byte("some payload")) + } else { + w.WriteHeader(403) + } + })) + defer backend.Close() + + response, err := RequestUnparsedResponse( + backend.URL+"?access_token=my_token", nil) + assert.Equal(t, nil, err) + assert.Equal(t, 200, response.StatusCode) + body, err := ioutil.ReadAll(response.Body) + assert.Equal(t, nil, err) + response.Body.Close() + assert.Equal(t, "some payload", string(body)) +} + +func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) { + backend := testBackend(200, "some payload") + // Close the backend now to force a request failure. + backend.Close() + + response, err := RequestUnparsedResponse( + backend.URL+"?access_token=my_token", nil) + assert.NotEqual(t, nil, err) + assert.Equal(t, (*http.Response)(nil), response) +} + +func TestRequestUnparsedResponseUsingHeaders(t *testing.T) { + backend := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" { + w.WriteHeader(200) + w.Write([]byte("some payload")) + } else { + w.WriteHeader(403) + } + })) + defer backend.Close() + + headers := make(http.Header) + headers.Set("Auth", "my_token") + response, err := RequestUnparsedResponse(backend.URL, headers) + assert.Equal(t, nil, err) + assert.Equal(t, 200, response.StatusCode) + body, err := ioutil.ReadAll(response.Body) + assert.Equal(t, nil, err) + response.Body.Close() + assert.Equal(t, "some payload", string(body)) +} diff --git a/oauthproxy.go b/oauthproxy.go index e24a9a0..7738c35 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -265,27 +265,6 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire)) } -func (p *OauthProxy) ValidateToken(access_token string) bool { - if access_token == "" || p.oauthValidateUrl == nil { - return false - } - - req, err := http.NewRequest("GET", - p.oauthValidateUrl.String()+"?access_token="+access_token, nil) - if err != nil { - log.Printf("failed building token validation request: %s", err) - return false - } - - httpclient := &http.Client{} - resp, err := httpclient.Do(req) - if err != nil { - log.Printf("token validation request failed: %s", err) - return false - } - return resp.StatusCode == 200 -} - func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) { var value string var timestamp time.Time @@ -304,7 +283,7 @@ func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (e expires := timestamp.Add(p.CookieExpire) refresh_threshold := time.Now().Add(p.CookieRefresh) if refresh_threshold.Unix() > expires.Unix() { - ok = p.Validator(email) && p.ValidateToken(access_token) + ok = p.Validator(email) && p.provider.ValidateToken(access_token) if ok { p.SetCookie(rw, req, value) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index fac0457..409f1d5 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -86,6 +86,7 @@ func TestRobotsTxt(t *testing.T) { type TestProvider struct { *providers.ProviderData EmailAddress string + ValidToken bool } func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, @@ -93,6 +94,10 @@ func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, return tp.EmailAddress, nil } +func (tp *TestProvider) ValidateToken(access_token string) bool { + return tp.ValidToken +} + type PassAccessTokenTest struct { provider_server *httptest.Server proxy *OauthProxy @@ -322,101 +327,21 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { } } -type ValidateTokenTest struct { - opts *Options - proxy *OauthProxy - backend *httptest.Server - response_code int -} - -func NewValidateTokenTest() *ValidateTokenTest { - var vt_test ValidateTokenTest - - vt_test.opts = NewOptions() - vt_test.opts.Upstreams = append(vt_test.opts.Upstreams, "unused") - vt_test.opts.CookieSecret = "foobar" - vt_test.opts.ClientID = "bazquux" - vt_test.opts.ClientSecret = "xyzzyplugh" - vt_test.opts.Validate() - - vt_test.backend = httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.URL.Path { - case "/oauth/tokeninfo": - w.WriteHeader(vt_test.response_code) - w.Write([]byte("only code matters; contents disregarded")) - default: - w.WriteHeader(500) - w.Write([]byte("unknown URL")) - } - })) - backend_url, _ := url.Parse(vt_test.backend.URL) - vt_test.opts.provider.Data().ValidateUrl = &url.URL{ - Scheme: "http", - Host: backend_url.Host, - Path: "/oauth/tokeninfo", - } - vt_test.response_code = 200 - - vt_test.proxy = NewOauthProxy(vt_test.opts, func(email string) bool { - return true - }) - return &vt_test -} - -func (vt_test *ValidateTokenTest) Close() { - vt_test.backend.Close() -} - -func TestValidateTokenEmptyToken(t *testing.T) { - vt_test := NewValidateTokenTest() - defer vt_test.Close() - - assert.Equal(t, false, vt_test.proxy.ValidateToken("")) -} - -func TestValidateTokenEmptyValidateUrl(t *testing.T) { - vt_test := NewValidateTokenTest() - defer vt_test.Close() - - vt_test.proxy.oauthValidateUrl = nil - assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar")) -} - -func TestValidateTokenRequestNetworkFailure(t *testing.T) { - vt_test := NewValidateTokenTest() - // Close immediately to simulate a network failure - vt_test.Close() - - assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar")) -} - -func TestValidateTokenExpiredToken(t *testing.T) { - vt_test := NewValidateTokenTest() - defer vt_test.Close() - - vt_test.response_code = 401 - assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar")) -} - -func TestValidateTokenValidToken(t *testing.T) { - vt_test := NewValidateTokenTest() - defer vt_test.Close() - - assert.Equal(t, true, vt_test.proxy.ValidateToken("foobar")) -} - type ProcessCookieTest struct { opts *Options proxy *OauthProxy rw *httptest.ResponseRecorder req *http.Request - backend *httptest.Server + provider TestProvider response_code int validate_user bool } -func NewProcessCookieTest() *ProcessCookieTest { +type ProcessCookieTestOpts struct { + provider_validate_cookie_response bool +} + +func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { var pc_test ProcessCookieTest pc_test.opts = NewOptions() @@ -433,6 +358,9 @@ func NewProcessCookieTest() *ProcessCookieTest { pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool { return pc_test.validate_user }) + pc_test.proxy.provider = &TestProvider{ + ValidToken: opts.provider_validate_cookie_response, + } // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. @@ -443,22 +371,10 @@ func NewProcessCookieTest() *ProcessCookieTest { return &pc_test } -func (p *ProcessCookieTest) InstantiateBackend() { - p.backend = httptest.NewServer( - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(p.response_code) - })) - backend_url, _ := url.Parse(p.backend.URL) - p.proxy.oauthValidateUrl = &url.URL{ - Scheme: "http", - Host: backend_url.Host, - Path: "/oauth/tokeninfo", - } - p.response_code = 200 -} - -func (p *ProcessCookieTest) Close() { - p.backend.Close() +func NewProcessCookieTestWithDefaults() *ProcessCookieTest { + return NewProcessCookieTest(ProcessCookieTestOpts{ + provider_validate_cookie_response: true, + }) } func (p *ProcessCookieTest) MakeCookie(value, access_token string) *http.Cookie { @@ -476,7 +392,7 @@ func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, o } func TestProcessCookie(t *testing.T) { - pc_test := NewProcessCookieTest() + pc_test := NewProcessCookieTestWithDefaults() pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token") email, user, access_token, ok := pc_test.ProcessCookie() @@ -487,13 +403,13 @@ func TestProcessCookie(t *testing.T) { } func TestProcessCookieNoCookieError(t *testing.T) { - pc_test := NewProcessCookieTest() + pc_test := NewProcessCookieTestWithDefaults() _, _, _, ok := pc_test.ProcessCookie() assert.Equal(t, false, ok) } func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) { - pc_test := NewProcessCookieTest() + pc_test := NewProcessCookieTestWithDefaults() value, _ := buildCookieValue("michael.bland@gsa.gov", pc_test.proxy.AesCipher, "my_access_token") pc_test.req.AddCookie(pc_test.proxy.MakeCookie( @@ -504,10 +420,7 @@ func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) { } func TestProcessCookieRefreshNotSet(t *testing.T) { - pc_test := NewProcessCookieTest() - pc_test.InstantiateBackend() - defer pc_test.Close() - + pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "") pc_test.req.AddCookie(cookie) @@ -518,10 +431,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { } func TestProcessCookieRefresh(t *testing.T) { - pc_test := NewProcessCookieTest() - pc_test.InstantiateBackend() - defer pc_test.Close() - + pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token") pc_test.req.AddCookie(cookie) @@ -533,10 +443,7 @@ func TestProcessCookieRefresh(t *testing.T) { } func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) { - pc_test := NewProcessCookieTest() - pc_test.InstantiateBackend() - defer pc_test.Close() - + pc_test := NewProcessCookieTestWithDefaults() pc_test.proxy.CookieExpire = time.Duration(25) * time.Hour cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token") pc_test.req.AddCookie(cookie) @@ -548,11 +455,9 @@ func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) { - pc_test := NewProcessCookieTest() - pc_test.InstantiateBackend() - defer pc_test.Close() - pc_test.response_code = 401 - + pc_test := NewProcessCookieTest(ProcessCookieTestOpts{ + provider_validate_cookie_response: false, + }) pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token") pc_test.req.AddCookie(cookie) @@ -564,9 +469,7 @@ func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) { - pc_test := NewProcessCookieTest() - pc_test.InstantiateBackend() - defer pc_test.Close() + pc_test := NewProcessCookieTestWithDefaults() pc_test.validate_user = false pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour diff --git a/providers/google.go b/providers/google.go index 5fc94be..aa162d9 100644 --- a/providers/google.go +++ b/providers/google.go @@ -66,3 +66,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) { return base64.URLEncoding.DecodeString(seg) } + +func (p *GoogleProvider) ValidateToken(access_token string) bool { + return validateToken(p, access_token, nil) +} diff --git a/providers/internal_util.go b/providers/internal_util.go new file mode 100644 index 0000000..4955430 --- /dev/null +++ b/providers/internal_util.go @@ -0,0 +1,24 @@ +package providers + +import ( + "github.com/bitly/google_auth_proxy/api" + "log" + "net/http" +) + +func validateToken(p Provider, access_token string, + header http.Header) bool { + if access_token == "" || p.Data().ValidateUrl == nil { + return false + } + url := p.Data().ValidateUrl.String() + if len(header) == 0 { + url = url + "?access_token=" + access_token + } + if resp, err := api.RequestUnparsedResponse(url, header); err != nil { + log.Printf("token validation request failed: %s", err) + return false + } else { + return resp.StatusCode == 200 + } +} diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go new file mode 100644 index 0000000..39b8ab3 --- /dev/null +++ b/providers/internal_util_test.go @@ -0,0 +1,122 @@ +package providers + +import ( + "github.com/bitly/go-simplejson" + "github.com/bmizerany/assert" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +type ValidateTokenTestProvider struct { + *ProviderData +} + +func (tp *ValidateTokenTestProvider) GetEmailAddress( + unused_auth_response *simplejson.Json, + unused_access_token string) (string, error) { + return "", nil +} + +// Note that we're testing the internal validateToken() used to implement +// several Provider's ValidateToken() implementations +func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool { + return false +} + +type ValidateTokenTest struct { + backend *httptest.Server + response_code int + provider *ValidateTokenTestProvider + header http.Header +} + +func NewValidateTokenTest() *ValidateTokenTest { + var vt_test ValidateTokenTest + + vt_test.backend = httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/tokeninfo" { + w.WriteHeader(500) + w.Write([]byte("unknown URL")) + } + token_param := r.FormValue("access_token") + if token_param == "" { + missing := false + received_headers := r.Header + for k, _ := range vt_test.header { + received := received_headers.Get(k) + expected := vt_test.header.Get(k) + if received == "" || received != expected { + missing = true + } + } + if missing { + w.WriteHeader(500) + w.Write([]byte("no token param and missing or incorrect headers")) + } + } + w.WriteHeader(vt_test.response_code) + w.Write([]byte("only code matters; contents disregarded")) + + })) + backend_url, _ := url.Parse(vt_test.backend.URL) + vt_test.provider = &ValidateTokenTestProvider{ + ProviderData: &ProviderData{ + ValidateUrl: &url.URL{ + Scheme: "http", + Host: backend_url.Host, + Path: "/oauth/tokeninfo", + }, + }, + } + vt_test.response_code = 200 + return &vt_test +} + +func (vt_test *ValidateTokenTest) Close() { + vt_test.backend.Close() +} + +func TestValidateTokenValidToken(t *testing.T) { + vt_test := NewValidateTokenTest() + defer vt_test.Close() + assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) +} + +func TestValidateTokenValidTokenWithHeaders(t *testing.T) { + vt_test := NewValidateTokenTest() + defer vt_test.Close() + vt_test.header = make(http.Header) + vt_test.header.Set("Authorization", "Bearer foobar") + assert.Equal(t, true, + validateToken(vt_test.provider, "foobar", vt_test.header)) +} + +func TestValidateTokenEmptyToken(t *testing.T) { + vt_test := NewValidateTokenTest() + defer vt_test.Close() + assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) +} + +func TestValidateTokenEmptyValidateUrl(t *testing.T) { + vt_test := NewValidateTokenTest() + defer vt_test.Close() + vt_test.provider.Data().ValidateUrl = nil + assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) +} + +func TestValidateTokenRequestNetworkFailure(t *testing.T) { + vt_test := NewValidateTokenTest() + // Close immediately to simulate a network failure + vt_test.Close() + assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) +} + +func TestValidateTokenExpiredToken(t *testing.T) { + vt_test := NewValidateTokenTest() + defer vt_test.Close() + vt_test.response_code = 401 + assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) +} diff --git a/providers/linkedin.go b/providers/linkedin.go index 539cd97..4eea3f7 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -33,12 +33,23 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { Host: "www.linkedin.com", Path: "/v1/people/~/email-address"} } + if p.ValidateUrl.String() == "" { + p.ValidateUrl = p.ProfileUrl + } if p.Scope == "" { p.Scope = "r_emailaddress r_basicprofile" } return &LinkedInProvider{ProviderData: p} } +func getLinkedInHeader(access_token string) http.Header { + header := make(http.Header) + header.Set("Accept", "application/json") + header.Set("x-li-format", "json") + header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) + return header +} + func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json, access_token string) (string, error) { if access_token == "" { @@ -49,9 +60,7 @@ func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json if err != nil { return "", err } - req.Header.Set("Accept", "application/json") - req.Header.Set("x-li-format", "json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) + req.Header = getLinkedInHeader(access_token) json, err := api.Request(req) if err != nil { @@ -66,3 +75,7 @@ func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json } return email, nil } + +func (p *LinkedInProvider) ValidateToken(access_token string) bool { + return validateToken(p, access_token, getLinkedInHeader(access_token)) +} diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index 5aa4353..be8d05e 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -16,6 +16,7 @@ func testLinkedInProvider(hostname string) *LinkedInProvider { LoginUrl: &url.URL{}, RedeemUrl: &url.URL{}, ProfileUrl: &url.URL{}, + ValidateUrl: &url.URL{}, Scope: ""}) if hostname != "" { updateUrl(p.Data().LoginUrl, hostname) @@ -52,6 +53,8 @@ func TestLinkedInProviderDefaults(t *testing.T) { p.Data().RedeemUrl.String()) assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", p.Data().ProfileUrl.String()) + assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address", + p.Data().ValidateUrl.String()) assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope) } @@ -70,6 +73,10 @@ func TestLinkedInProviderOverrides(t *testing.T) { Scheme: "https", Host: "example.com", Path: "/oauth/profile"}, + ValidateUrl: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/tokeninfo"}, Scope: "profile"}) assert.NotEqual(t, nil, p) assert.Equal(t, "LinkedIn", p.Data().ProviderName) @@ -79,6 +86,8 @@ func TestLinkedInProviderOverrides(t *testing.T) { p.Data().RedeemUrl.String()) assert.Equal(t, "https://example.com/oauth/profile", p.Data().ProfileUrl.String()) + assert.Equal(t, "https://example.com/oauth/tokeninfo", + p.Data().ValidateUrl.String()) assert.Equal(t, "profile", p.Data().Scope) } diff --git a/providers/myusa.go b/providers/myusa.go index 69014ba..05ba33f 100644 --- a/providers/myusa.go +++ b/providers/myusa.go @@ -58,3 +58,7 @@ func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json, } return json.Get("email").String() } + +func (p *MyUsaProvider) ValidateToken(access_token string) bool { + return validateToken(p, access_token, nil) +} diff --git a/providers/providers.go b/providers/providers.go index 486668b..8f5091a 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -8,6 +8,7 @@ type Provider interface { Data() *ProviderData GetEmailAddress(auth_response *simplejson.Json, access_token string) (string, error) + ValidateToken(access_token string) bool } func New(provider string, p *ProviderData) Provider {