Move ValidateToken() to Provider

This commit is contained in:
Mike Bland 2015-05-12 21:48:13 -04:00 committed by Jehiah Czebotar
parent aca1fe81f4
commit 8471f972e1
11 changed files with 285 additions and 150 deletions

View File

@ -30,3 +30,20 @@ func Request(req *http.Request) (*simplejson.Json, error) {
}
return data, nil
}
func RequestUnparsedResponse(url string, header http.Header) (
response *http.Response, err error) {
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, errors.New("failed building request for " +
url + ": " + err.Error())
}
req.Header = header
httpclient := &http.Client{}
if response, err = httpclient.Do(req); err != nil {
return nil, errors.New("request failed for " +
url + ": " + err.Error())
}
return
}

View File

@ -3,6 +3,7 @@ package api
import (
"github.com/bitly/go-simplejson"
"github.com/bmizerany/assert"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
@ -66,3 +67,61 @@ func TestJsonParsingError(t *testing.T) {
assert.Equal(t, (*simplejson.Json)(nil), resp)
assert.NotEqual(t, nil, err)
}
// Parsing a URL practically never fails, so we won't cover that test case.
func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
token := r.FormValue("access_token")
if r.URL.Path == "/" && token == "my_token" {
w.WriteHeader(200)
w.Write([]byte("some payload"))
} else {
w.WriteHeader(403)
}
}))
defer backend.Close()
response, err := RequestUnparsedResponse(
backend.URL+"?access_token=my_token", nil)
assert.Equal(t, nil, err)
assert.Equal(t, 200, response.StatusCode)
body, err := ioutil.ReadAll(response.Body)
assert.Equal(t, nil, err)
response.Body.Close()
assert.Equal(t, "some payload", string(body))
}
func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) {
backend := testBackend(200, "some payload")
// Close the backend now to force a request failure.
backend.Close()
response, err := RequestUnparsedResponse(
backend.URL+"?access_token=my_token", nil)
assert.NotEqual(t, nil, err)
assert.Equal(t, (*http.Response)(nil), response)
}
func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" {
w.WriteHeader(200)
w.Write([]byte("some payload"))
} else {
w.WriteHeader(403)
}
}))
defer backend.Close()
headers := make(http.Header)
headers.Set("Auth", "my_token")
response, err := RequestUnparsedResponse(backend.URL, headers)
assert.Equal(t, nil, err)
assert.Equal(t, 200, response.StatusCode)
body, err := ioutil.ReadAll(response.Body)
assert.Equal(t, nil, err)
response.Body.Close()
assert.Equal(t, "some payload", string(body))
}

View File

@ -265,27 +265,6 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire))
}
func (p *OauthProxy) ValidateToken(access_token string) bool {
if access_token == "" || p.oauthValidateUrl == nil {
return false
}
req, err := http.NewRequest("GET",
p.oauthValidateUrl.String()+"?access_token="+access_token, nil)
if err != nil {
log.Printf("failed building token validation request: %s", err)
return false
}
httpclient := &http.Client{}
resp, err := httpclient.Do(req)
if err != nil {
log.Printf("token validation request failed: %s", err)
return false
}
return resp.StatusCode == 200
}
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) {
var value string
var timestamp time.Time
@ -304,7 +283,7 @@ func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (e
expires := timestamp.Add(p.CookieExpire)
refresh_threshold := time.Now().Add(p.CookieRefresh)
if refresh_threshold.Unix() > expires.Unix() {
ok = p.Validator(email) && p.ValidateToken(access_token)
ok = p.Validator(email) && p.provider.ValidateToken(access_token)
if ok {
p.SetCookie(rw, req, value)
}

View File

@ -86,6 +86,7 @@ func TestRobotsTxt(t *testing.T) {
type TestProvider struct {
*providers.ProviderData
EmailAddress string
ValidToken bool
}
func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
@ -93,6 +94,10 @@ func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
return tp.EmailAddress, nil
}
func (tp *TestProvider) ValidateToken(access_token string) bool {
return tp.ValidToken
}
type PassAccessTokenTest struct {
provider_server *httptest.Server
proxy *OauthProxy
@ -322,101 +327,21 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
}
}
type ValidateTokenTest struct {
opts *Options
proxy *OauthProxy
backend *httptest.Server
response_code int
}
func NewValidateTokenTest() *ValidateTokenTest {
var vt_test ValidateTokenTest
vt_test.opts = NewOptions()
vt_test.opts.Upstreams = append(vt_test.opts.Upstreams, "unused")
vt_test.opts.CookieSecret = "foobar"
vt_test.opts.ClientID = "bazquux"
vt_test.opts.ClientSecret = "xyzzyplugh"
vt_test.opts.Validate()
vt_test.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/oauth/tokeninfo":
w.WriteHeader(vt_test.response_code)
w.Write([]byte("only code matters; contents disregarded"))
default:
w.WriteHeader(500)
w.Write([]byte("unknown URL"))
}
}))
backend_url, _ := url.Parse(vt_test.backend.URL)
vt_test.opts.provider.Data().ValidateUrl = &url.URL{
Scheme: "http",
Host: backend_url.Host,
Path: "/oauth/tokeninfo",
}
vt_test.response_code = 200
vt_test.proxy = NewOauthProxy(vt_test.opts, func(email string) bool {
return true
})
return &vt_test
}
func (vt_test *ValidateTokenTest) Close() {
vt_test.backend.Close()
}
func TestValidateTokenEmptyToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
assert.Equal(t, false, vt_test.proxy.ValidateToken(""))
}
func TestValidateTokenEmptyValidateUrl(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.proxy.oauthValidateUrl = nil
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
}
func TestValidateTokenRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateTokenTest()
// Close immediately to simulate a network failure
vt_test.Close()
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
}
func TestValidateTokenExpiredToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.response_code = 401
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
}
func TestValidateTokenValidToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
assert.Equal(t, true, vt_test.proxy.ValidateToken("foobar"))
}
type ProcessCookieTest struct {
opts *Options
proxy *OauthProxy
rw *httptest.ResponseRecorder
req *http.Request
backend *httptest.Server
provider TestProvider
response_code int
validate_user bool
}
func NewProcessCookieTest() *ProcessCookieTest {
type ProcessCookieTestOpts struct {
provider_validate_cookie_response bool
}
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
var pc_test ProcessCookieTest
pc_test.opts = NewOptions()
@ -433,6 +358,9 @@ func NewProcessCookieTest() *ProcessCookieTest {
pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool {
return pc_test.validate_user
})
pc_test.proxy.provider = &TestProvider{
ValidToken: opts.provider_validate_cookie_response,
}
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
// access_token validation.
@ -443,22 +371,10 @@ func NewProcessCookieTest() *ProcessCookieTest {
return &pc_test
}
func (p *ProcessCookieTest) InstantiateBackend() {
p.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(p.response_code)
}))
backend_url, _ := url.Parse(p.backend.URL)
p.proxy.oauthValidateUrl = &url.URL{
Scheme: "http",
Host: backend_url.Host,
Path: "/oauth/tokeninfo",
}
p.response_code = 200
}
func (p *ProcessCookieTest) Close() {
p.backend.Close()
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
return NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: true,
})
}
func (p *ProcessCookieTest) MakeCookie(value, access_token string) *http.Cookie {
@ -476,7 +392,7 @@ func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, o
}
func TestProcessCookie(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test := NewProcessCookieTestWithDefaults()
pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token")
email, user, access_token, ok := pc_test.ProcessCookie()
@ -487,13 +403,13 @@ func TestProcessCookie(t *testing.T) {
}
func TestProcessCookieNoCookieError(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test := NewProcessCookieTestWithDefaults()
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
}
func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test := NewProcessCookieTestWithDefaults()
value, _ := buildCookieValue("michael.bland@gsa.gov",
pc_test.proxy.AesCipher, "my_access_token")
pc_test.req.AddCookie(pc_test.proxy.MakeCookie(
@ -504,10 +420,7 @@ func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
}
func TestProcessCookieRefreshNotSet(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "")
pc_test.req.AddCookie(cookie)
@ -518,10 +431,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
}
func TestProcessCookieRefresh(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
@ -533,10 +443,7 @@ func TestProcessCookieRefresh(t *testing.T) {
}
func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(25) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
@ -548,11 +455,9 @@ func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
}
func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test.response_code = 401
pc_test := NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: false,
})
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
pc_test.req.AddCookie(cookie)
@ -564,9 +469,7 @@ func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
}
func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) {
pc_test := NewProcessCookieTest()
pc_test.InstantiateBackend()
defer pc_test.Close()
pc_test := NewProcessCookieTestWithDefaults()
pc_test.validate_user = false
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour

View File

@ -66,3 +66,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) {
return base64.URLEncoding.DecodeString(seg)
}
func (p *GoogleProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
}

View File

@ -0,0 +1,24 @@
package providers
import (
"github.com/bitly/google_auth_proxy/api"
"log"
"net/http"
)
func validateToken(p Provider, access_token string,
header http.Header) bool {
if access_token == "" || p.Data().ValidateUrl == nil {
return false
}
url := p.Data().ValidateUrl.String()
if len(header) == 0 {
url = url + "?access_token=" + access_token
}
if resp, err := api.RequestUnparsedResponse(url, header); err != nil {
log.Printf("token validation request failed: %s", err)
return false
} else {
return resp.StatusCode == 200
}
}

View File

@ -0,0 +1,122 @@
package providers
import (
"github.com/bitly/go-simplejson"
"github.com/bmizerany/assert"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
type ValidateTokenTestProvider struct {
*ProviderData
}
func (tp *ValidateTokenTestProvider) GetEmailAddress(
unused_auth_response *simplejson.Json,
unused_access_token string) (string, error) {
return "", nil
}
// Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateToken() implementations
func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool {
return false
}
type ValidateTokenTest struct {
backend *httptest.Server
response_code int
provider *ValidateTokenTestProvider
header http.Header
}
func NewValidateTokenTest() *ValidateTokenTest {
var vt_test ValidateTokenTest
vt_test.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/tokeninfo" {
w.WriteHeader(500)
w.Write([]byte("unknown URL"))
}
token_param := r.FormValue("access_token")
if token_param == "" {
missing := false
received_headers := r.Header
for k, _ := range vt_test.header {
received := received_headers.Get(k)
expected := vt_test.header.Get(k)
if received == "" || received != expected {
missing = true
}
}
if missing {
w.WriteHeader(500)
w.Write([]byte("no token param and missing or incorrect headers"))
}
}
w.WriteHeader(vt_test.response_code)
w.Write([]byte("only code matters; contents disregarded"))
}))
backend_url, _ := url.Parse(vt_test.backend.URL)
vt_test.provider = &ValidateTokenTestProvider{
ProviderData: &ProviderData{
ValidateUrl: &url.URL{
Scheme: "http",
Host: backend_url.Host,
Path: "/oauth/tokeninfo",
},
},
}
vt_test.response_code = 200
return &vt_test
}
func (vt_test *ValidateTokenTest) Close() {
vt_test.backend.Close()
}
func TestValidateTokenValidToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
}
func TestValidateTokenValidTokenWithHeaders(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.header = make(http.Header)
vt_test.header.Set("Authorization", "Bearer foobar")
assert.Equal(t, true,
validateToken(vt_test.provider, "foobar", vt_test.header))
}
func TestValidateTokenEmptyToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
}
func TestValidateTokenEmptyValidateUrl(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.provider.Data().ValidateUrl = nil
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
}
func TestValidateTokenRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateTokenTest()
// Close immediately to simulate a network failure
vt_test.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
}
func TestValidateTokenExpiredToken(t *testing.T) {
vt_test := NewValidateTokenTest()
defer vt_test.Close()
vt_test.response_code = 401
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
}

View File

@ -33,12 +33,23 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
Host: "www.linkedin.com",
Path: "/v1/people/~/email-address"}
}
if p.ValidateUrl.String() == "" {
p.ValidateUrl = p.ProfileUrl
}
if p.Scope == "" {
p.Scope = "r_emailaddress r_basicprofile"
}
return &LinkedInProvider{ProviderData: p}
}
func getLinkedInHeader(access_token string) http.Header {
header := make(http.Header)
header.Set("Accept", "application/json")
header.Set("x-li-format", "json")
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
return header
}
func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
access_token string) (string, error) {
if access_token == "" {
@ -49,9 +60,7 @@ func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json
if err != nil {
return "", err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("x-li-format", "json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
req.Header = getLinkedInHeader(access_token)
json, err := api.Request(req)
if err != nil {
@ -66,3 +75,7 @@ func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json
}
return email, nil
}
func (p *LinkedInProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, getLinkedInHeader(access_token))
}

View File

@ -16,6 +16,7 @@ func testLinkedInProvider(hostname string) *LinkedInProvider {
LoginUrl: &url.URL{},
RedeemUrl: &url.URL{},
ProfileUrl: &url.URL{},
ValidateUrl: &url.URL{},
Scope: ""})
if hostname != "" {
updateUrl(p.Data().LoginUrl, hostname)
@ -52,6 +53,8 @@ func TestLinkedInProviderDefaults(t *testing.T) {
p.Data().RedeemUrl.String())
assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address",
p.Data().ProfileUrl.String())
assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address",
p.Data().ValidateUrl.String())
assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope)
}
@ -70,6 +73,10 @@ func TestLinkedInProviderOverrides(t *testing.T) {
Scheme: "https",
Host: "example.com",
Path: "/oauth/profile"},
ValidateUrl: &url.URL{
Scheme: "https",
Host: "example.com",
Path: "/oauth/tokeninfo"},
Scope: "profile"})
assert.NotEqual(t, nil, p)
assert.Equal(t, "LinkedIn", p.Data().ProviderName)
@ -79,6 +86,8 @@ func TestLinkedInProviderOverrides(t *testing.T) {
p.Data().RedeemUrl.String())
assert.Equal(t, "https://example.com/oauth/profile",
p.Data().ProfileUrl.String())
assert.Equal(t, "https://example.com/oauth/tokeninfo",
p.Data().ValidateUrl.String())
assert.Equal(t, "profile", p.Data().Scope)
}

View File

@ -58,3 +58,7 @@ func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json,
}
return json.Get("email").String()
}
func (p *MyUsaProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
}

View File

@ -8,6 +8,7 @@ type Provider interface {
Data() *ProviderData
GetEmailAddress(auth_response *simplejson.Json,
access_token string) (string, error)
ValidateToken(access_token string) bool
}
func New(provider string, p *ProviderData) Provider {