Move ValidateToken() to Provider
This commit is contained in:
parent
aca1fe81f4
commit
8471f972e1
17
api/api.go
17
api/api.go
@ -30,3 +30,20 @@ func Request(req *http.Request) (*simplejson.Json, error) {
|
|||||||
}
|
}
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func RequestUnparsedResponse(url string, header http.Header) (
|
||||||
|
response *http.Response, err error) {
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New("failed building request for " +
|
||||||
|
url + ": " + err.Error())
|
||||||
|
}
|
||||||
|
req.Header = header
|
||||||
|
|
||||||
|
httpclient := &http.Client{}
|
||||||
|
if response, err = httpclient.Do(req); err != nil {
|
||||||
|
return nil, errors.New("request failed for " +
|
||||||
|
url + ": " + err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
@ -3,6 +3,7 @@ package api
|
|||||||
import (
|
import (
|
||||||
"github.com/bitly/go-simplejson"
|
"github.com/bitly/go-simplejson"
|
||||||
"github.com/bmizerany/assert"
|
"github.com/bmizerany/assert"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
@ -66,3 +67,61 @@ func TestJsonParsingError(t *testing.T) {
|
|||||||
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
assert.Equal(t, (*simplejson.Json)(nil), resp)
|
||||||
assert.NotEqual(t, nil, err)
|
assert.NotEqual(t, nil, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parsing a URL practically never fails, so we won't cover that test case.
|
||||||
|
func TestRequestUnparsedResponseUsingAccessTokenParameter(t *testing.T) {
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
token := r.FormValue("access_token")
|
||||||
|
if r.URL.Path == "/" && token == "my_token" {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte("some payload"))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(403)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
response, err := RequestUnparsedResponse(
|
||||||
|
backend.URL+"?access_token=my_token", nil)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, 200, response.StatusCode)
|
||||||
|
body, err := ioutil.ReadAll(response.Body)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
response.Body.Close()
|
||||||
|
assert.Equal(t, "some payload", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUnparsedResponseUsingAccessTokenParameterFailedResponse(t *testing.T) {
|
||||||
|
backend := testBackend(200, "some payload")
|
||||||
|
// Close the backend now to force a request failure.
|
||||||
|
backend.Close()
|
||||||
|
|
||||||
|
response, err := RequestUnparsedResponse(
|
||||||
|
backend.URL+"?access_token=my_token", nil)
|
||||||
|
assert.NotEqual(t, nil, err)
|
||||||
|
assert.Equal(t, (*http.Response)(nil), response)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequestUnparsedResponseUsingHeaders(t *testing.T) {
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path == "/" && r.Header["Auth"][0] == "my_token" {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
w.Write([]byte("some payload"))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(403)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
headers := make(http.Header)
|
||||||
|
headers.Set("Auth", "my_token")
|
||||||
|
response, err := RequestUnparsedResponse(backend.URL, headers)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.Equal(t, 200, response.StatusCode)
|
||||||
|
body, err := ioutil.ReadAll(response.Body)
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
response.Body.Close()
|
||||||
|
assert.Equal(t, "some payload", string(body))
|
||||||
|
}
|
||||||
|
@ -265,27 +265,6 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st
|
|||||||
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire))
|
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OauthProxy) ValidateToken(access_token string) bool {
|
|
||||||
if access_token == "" || p.oauthValidateUrl == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequest("GET",
|
|
||||||
p.oauthValidateUrl.String()+"?access_token="+access_token, nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("failed building token validation request: %s", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
httpclient := &http.Client{}
|
|
||||||
resp, err := httpclient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("token validation request failed: %s", err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return resp.StatusCode == 200
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) {
|
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) {
|
||||||
var value string
|
var value string
|
||||||
var timestamp time.Time
|
var timestamp time.Time
|
||||||
@ -304,7 +283,7 @@ func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (e
|
|||||||
expires := timestamp.Add(p.CookieExpire)
|
expires := timestamp.Add(p.CookieExpire)
|
||||||
refresh_threshold := time.Now().Add(p.CookieRefresh)
|
refresh_threshold := time.Now().Add(p.CookieRefresh)
|
||||||
if refresh_threshold.Unix() > expires.Unix() {
|
if refresh_threshold.Unix() > expires.Unix() {
|
||||||
ok = p.Validator(email) && p.ValidateToken(access_token)
|
ok = p.Validator(email) && p.provider.ValidateToken(access_token)
|
||||||
if ok {
|
if ok {
|
||||||
p.SetCookie(rw, req, value)
|
p.SetCookie(rw, req, value)
|
||||||
}
|
}
|
||||||
|
@ -86,6 +86,7 @@ func TestRobotsTxt(t *testing.T) {
|
|||||||
type TestProvider struct {
|
type TestProvider struct {
|
||||||
*providers.ProviderData
|
*providers.ProviderData
|
||||||
EmailAddress string
|
EmailAddress string
|
||||||
|
ValidToken bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
|
func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
|
||||||
@ -93,6 +94,10 @@ func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
|
|||||||
return tp.EmailAddress, nil
|
return tp.EmailAddress, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tp *TestProvider) ValidateToken(access_token string) bool {
|
||||||
|
return tp.ValidToken
|
||||||
|
}
|
||||||
|
|
||||||
type PassAccessTokenTest struct {
|
type PassAccessTokenTest struct {
|
||||||
provider_server *httptest.Server
|
provider_server *httptest.Server
|
||||||
proxy *OauthProxy
|
proxy *OauthProxy
|
||||||
@ -322,101 +327,21 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ValidateTokenTest struct {
|
|
||||||
opts *Options
|
|
||||||
proxy *OauthProxy
|
|
||||||
backend *httptest.Server
|
|
||||||
response_code int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewValidateTokenTest() *ValidateTokenTest {
|
|
||||||
var vt_test ValidateTokenTest
|
|
||||||
|
|
||||||
vt_test.opts = NewOptions()
|
|
||||||
vt_test.opts.Upstreams = append(vt_test.opts.Upstreams, "unused")
|
|
||||||
vt_test.opts.CookieSecret = "foobar"
|
|
||||||
vt_test.opts.ClientID = "bazquux"
|
|
||||||
vt_test.opts.ClientSecret = "xyzzyplugh"
|
|
||||||
vt_test.opts.Validate()
|
|
||||||
|
|
||||||
vt_test.backend = httptest.NewServer(
|
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.URL.Path {
|
|
||||||
case "/oauth/tokeninfo":
|
|
||||||
w.WriteHeader(vt_test.response_code)
|
|
||||||
w.Write([]byte("only code matters; contents disregarded"))
|
|
||||||
default:
|
|
||||||
w.WriteHeader(500)
|
|
||||||
w.Write([]byte("unknown URL"))
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
backend_url, _ := url.Parse(vt_test.backend.URL)
|
|
||||||
vt_test.opts.provider.Data().ValidateUrl = &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: backend_url.Host,
|
|
||||||
Path: "/oauth/tokeninfo",
|
|
||||||
}
|
|
||||||
vt_test.response_code = 200
|
|
||||||
|
|
||||||
vt_test.proxy = NewOauthProxy(vt_test.opts, func(email string) bool {
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
return &vt_test
|
|
||||||
}
|
|
||||||
|
|
||||||
func (vt_test *ValidateTokenTest) Close() {
|
|
||||||
vt_test.backend.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateTokenEmptyToken(t *testing.T) {
|
|
||||||
vt_test := NewValidateTokenTest()
|
|
||||||
defer vt_test.Close()
|
|
||||||
|
|
||||||
assert.Equal(t, false, vt_test.proxy.ValidateToken(""))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateTokenEmptyValidateUrl(t *testing.T) {
|
|
||||||
vt_test := NewValidateTokenTest()
|
|
||||||
defer vt_test.Close()
|
|
||||||
|
|
||||||
vt_test.proxy.oauthValidateUrl = nil
|
|
||||||
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateTokenRequestNetworkFailure(t *testing.T) {
|
|
||||||
vt_test := NewValidateTokenTest()
|
|
||||||
// Close immediately to simulate a network failure
|
|
||||||
vt_test.Close()
|
|
||||||
|
|
||||||
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateTokenExpiredToken(t *testing.T) {
|
|
||||||
vt_test := NewValidateTokenTest()
|
|
||||||
defer vt_test.Close()
|
|
||||||
|
|
||||||
vt_test.response_code = 401
|
|
||||||
assert.Equal(t, false, vt_test.proxy.ValidateToken("foobar"))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestValidateTokenValidToken(t *testing.T) {
|
|
||||||
vt_test := NewValidateTokenTest()
|
|
||||||
defer vt_test.Close()
|
|
||||||
|
|
||||||
assert.Equal(t, true, vt_test.proxy.ValidateToken("foobar"))
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProcessCookieTest struct {
|
type ProcessCookieTest struct {
|
||||||
opts *Options
|
opts *Options
|
||||||
proxy *OauthProxy
|
proxy *OauthProxy
|
||||||
rw *httptest.ResponseRecorder
|
rw *httptest.ResponseRecorder
|
||||||
req *http.Request
|
req *http.Request
|
||||||
backend *httptest.Server
|
provider TestProvider
|
||||||
response_code int
|
response_code int
|
||||||
validate_user bool
|
validate_user bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcessCookieTest() *ProcessCookieTest {
|
type ProcessCookieTestOpts struct {
|
||||||
|
provider_validate_cookie_response bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest {
|
||||||
var pc_test ProcessCookieTest
|
var pc_test ProcessCookieTest
|
||||||
|
|
||||||
pc_test.opts = NewOptions()
|
pc_test.opts = NewOptions()
|
||||||
@ -433,6 +358,9 @@ func NewProcessCookieTest() *ProcessCookieTest {
|
|||||||
pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool {
|
pc_test.proxy = NewOauthProxy(pc_test.opts, func(email string) bool {
|
||||||
return pc_test.validate_user
|
return pc_test.validate_user
|
||||||
})
|
})
|
||||||
|
pc_test.proxy.provider = &TestProvider{
|
||||||
|
ValidToken: opts.provider_validate_cookie_response,
|
||||||
|
}
|
||||||
|
|
||||||
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
// Now, zero-out proxy.CookieRefresh for the cases that don't involve
|
||||||
// access_token validation.
|
// access_token validation.
|
||||||
@ -443,22 +371,10 @@ func NewProcessCookieTest() *ProcessCookieTest {
|
|||||||
return &pc_test
|
return &pc_test
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) InstantiateBackend() {
|
func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
||||||
p.backend = httptest.NewServer(
|
return NewProcessCookieTest(ProcessCookieTestOpts{
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
provider_validate_cookie_response: true,
|
||||||
w.WriteHeader(p.response_code)
|
})
|
||||||
}))
|
|
||||||
backend_url, _ := url.Parse(p.backend.URL)
|
|
||||||
p.proxy.oauthValidateUrl = &url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: backend_url.Host,
|
|
||||||
Path: "/oauth/tokeninfo",
|
|
||||||
}
|
|
||||||
p.response_code = 200
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *ProcessCookieTest) Close() {
|
|
||||||
p.backend.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) MakeCookie(value, access_token string) *http.Cookie {
|
func (p *ProcessCookieTest) MakeCookie(value, access_token string) *http.Cookie {
|
||||||
@ -476,7 +392,7 @@ func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, o
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookie(t *testing.T) {
|
func TestProcessCookie(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
|
|
||||||
pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token")
|
pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token")
|
||||||
email, user, access_token, ok := pc_test.ProcessCookie()
|
email, user, access_token, ok := pc_test.ProcessCookie()
|
||||||
@ -487,13 +403,13 @@ func TestProcessCookie(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieNoCookieError(t *testing.T) {
|
func TestProcessCookieNoCookieError(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
_, _, _, ok := pc_test.ProcessCookie()
|
_, _, _, ok := pc_test.ProcessCookie()
|
||||||
assert.Equal(t, false, ok)
|
assert.Equal(t, false, ok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
|
func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
value, _ := buildCookieValue("michael.bland@gsa.gov",
|
value, _ := buildCookieValue("michael.bland@gsa.gov",
|
||||||
pc_test.proxy.AesCipher, "my_access_token")
|
pc_test.proxy.AesCipher, "my_access_token")
|
||||||
pc_test.req.AddCookie(pc_test.proxy.MakeCookie(
|
pc_test.req.AddCookie(pc_test.proxy.MakeCookie(
|
||||||
@ -504,10 +420,7 @@ func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
func TestProcessCookieRefreshNotSet(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
pc_test.InstantiateBackend()
|
|
||||||
defer pc_test.Close()
|
|
||||||
|
|
||||||
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||||
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "")
|
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "")
|
||||||
pc_test.req.AddCookie(cookie)
|
pc_test.req.AddCookie(cookie)
|
||||||
@ -518,10 +431,7 @@ func TestProcessCookieRefreshNotSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieRefresh(t *testing.T) {
|
func TestProcessCookieRefresh(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
pc_test.InstantiateBackend()
|
|
||||||
defer pc_test.Close()
|
|
||||||
|
|
||||||
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||||
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
|
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
|
||||||
pc_test.req.AddCookie(cookie)
|
pc_test.req.AddCookie(cookie)
|
||||||
@ -533,10 +443,7 @@ func TestProcessCookieRefresh(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
|
func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
pc_test.InstantiateBackend()
|
|
||||||
defer pc_test.Close()
|
|
||||||
|
|
||||||
pc_test.proxy.CookieExpire = time.Duration(25) * time.Hour
|
pc_test.proxy.CookieExpire = time.Duration(25) * time.Hour
|
||||||
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
|
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
|
||||||
pc_test.req.AddCookie(cookie)
|
pc_test.req.AddCookie(cookie)
|
||||||
@ -548,11 +455,9 @@ func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
|
func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTest(ProcessCookieTestOpts{
|
||||||
pc_test.InstantiateBackend()
|
provider_validate_cookie_response: false,
|
||||||
defer pc_test.Close()
|
})
|
||||||
pc_test.response_code = 401
|
|
||||||
|
|
||||||
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||||
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
|
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token")
|
||||||
pc_test.req.AddCookie(cookie)
|
pc_test.req.AddCookie(cookie)
|
||||||
@ -564,9 +469,7 @@ func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) {
|
func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) {
|
||||||
pc_test := NewProcessCookieTest()
|
pc_test := NewProcessCookieTestWithDefaults()
|
||||||
pc_test.InstantiateBackend()
|
|
||||||
defer pc_test.Close()
|
|
||||||
pc_test.validate_user = false
|
pc_test.validate_user = false
|
||||||
|
|
||||||
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
|
||||||
|
@ -66,3 +66,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) {
|
|||||||
|
|
||||||
return base64.URLEncoding.DecodeString(seg)
|
return base64.URLEncoding.DecodeString(seg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *GoogleProvider) ValidateToken(access_token string) bool {
|
||||||
|
return validateToken(p, access_token, nil)
|
||||||
|
}
|
||||||
|
24
providers/internal_util.go
Normal file
24
providers/internal_util.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bitly/google_auth_proxy/api"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func validateToken(p Provider, access_token string,
|
||||||
|
header http.Header) bool {
|
||||||
|
if access_token == "" || p.Data().ValidateUrl == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
url := p.Data().ValidateUrl.String()
|
||||||
|
if len(header) == 0 {
|
||||||
|
url = url + "?access_token=" + access_token
|
||||||
|
}
|
||||||
|
if resp, err := api.RequestUnparsedResponse(url, header); err != nil {
|
||||||
|
log.Printf("token validation request failed: %s", err)
|
||||||
|
return false
|
||||||
|
} else {
|
||||||
|
return resp.StatusCode == 200
|
||||||
|
}
|
||||||
|
}
|
122
providers/internal_util_test.go
Normal file
122
providers/internal_util_test.go
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/bitly/go-simplejson"
|
||||||
|
"github.com/bmizerany/assert"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ValidateTokenTestProvider struct {
|
||||||
|
*ProviderData
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tp *ValidateTokenTestProvider) GetEmailAddress(
|
||||||
|
unused_auth_response *simplejson.Json,
|
||||||
|
unused_access_token string) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note that we're testing the internal validateToken() used to implement
|
||||||
|
// several Provider's ValidateToken() implementations
|
||||||
|
func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type ValidateTokenTest struct {
|
||||||
|
backend *httptest.Server
|
||||||
|
response_code int
|
||||||
|
provider *ValidateTokenTestProvider
|
||||||
|
header http.Header
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewValidateTokenTest() *ValidateTokenTest {
|
||||||
|
var vt_test ValidateTokenTest
|
||||||
|
|
||||||
|
vt_test.backend = httptest.NewServer(
|
||||||
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/tokeninfo" {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
w.Write([]byte("unknown URL"))
|
||||||
|
}
|
||||||
|
token_param := r.FormValue("access_token")
|
||||||
|
if token_param == "" {
|
||||||
|
missing := false
|
||||||
|
received_headers := r.Header
|
||||||
|
for k, _ := range vt_test.header {
|
||||||
|
received := received_headers.Get(k)
|
||||||
|
expected := vt_test.header.Get(k)
|
||||||
|
if received == "" || received != expected {
|
||||||
|
missing = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if missing {
|
||||||
|
w.WriteHeader(500)
|
||||||
|
w.Write([]byte("no token param and missing or incorrect headers"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(vt_test.response_code)
|
||||||
|
w.Write([]byte("only code matters; contents disregarded"))
|
||||||
|
|
||||||
|
}))
|
||||||
|
backend_url, _ := url.Parse(vt_test.backend.URL)
|
||||||
|
vt_test.provider = &ValidateTokenTestProvider{
|
||||||
|
ProviderData: &ProviderData{
|
||||||
|
ValidateUrl: &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: backend_url.Host,
|
||||||
|
Path: "/oauth/tokeninfo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
vt_test.response_code = 200
|
||||||
|
return &vt_test
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vt_test *ValidateTokenTest) Close() {
|
||||||
|
vt_test.backend.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenValidToken(t *testing.T) {
|
||||||
|
vt_test := NewValidateTokenTest()
|
||||||
|
defer vt_test.Close()
|
||||||
|
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenValidTokenWithHeaders(t *testing.T) {
|
||||||
|
vt_test := NewValidateTokenTest()
|
||||||
|
defer vt_test.Close()
|
||||||
|
vt_test.header = make(http.Header)
|
||||||
|
vt_test.header.Set("Authorization", "Bearer foobar")
|
||||||
|
assert.Equal(t, true,
|
||||||
|
validateToken(vt_test.provider, "foobar", vt_test.header))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenEmptyToken(t *testing.T) {
|
||||||
|
vt_test := NewValidateTokenTest()
|
||||||
|
defer vt_test.Close()
|
||||||
|
assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenEmptyValidateUrl(t *testing.T) {
|
||||||
|
vt_test := NewValidateTokenTest()
|
||||||
|
defer vt_test.Close()
|
||||||
|
vt_test.provider.Data().ValidateUrl = nil
|
||||||
|
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenRequestNetworkFailure(t *testing.T) {
|
||||||
|
vt_test := NewValidateTokenTest()
|
||||||
|
// Close immediately to simulate a network failure
|
||||||
|
vt_test.Close()
|
||||||
|
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateTokenExpiredToken(t *testing.T) {
|
||||||
|
vt_test := NewValidateTokenTest()
|
||||||
|
defer vt_test.Close()
|
||||||
|
vt_test.response_code = 401
|
||||||
|
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
|
||||||
|
}
|
@ -33,12 +33,23 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider {
|
|||||||
Host: "www.linkedin.com",
|
Host: "www.linkedin.com",
|
||||||
Path: "/v1/people/~/email-address"}
|
Path: "/v1/people/~/email-address"}
|
||||||
}
|
}
|
||||||
|
if p.ValidateUrl.String() == "" {
|
||||||
|
p.ValidateUrl = p.ProfileUrl
|
||||||
|
}
|
||||||
if p.Scope == "" {
|
if p.Scope == "" {
|
||||||
p.Scope = "r_emailaddress r_basicprofile"
|
p.Scope = "r_emailaddress r_basicprofile"
|
||||||
}
|
}
|
||||||
return &LinkedInProvider{ProviderData: p}
|
return &LinkedInProvider{ProviderData: p}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLinkedInHeader(access_token string) http.Header {
|
||||||
|
header := make(http.Header)
|
||||||
|
header.Set("Accept", "application/json")
|
||||||
|
header.Set("x-li-format", "json")
|
||||||
|
header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
|
func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
|
||||||
access_token string) (string, error) {
|
access_token string) (string, error) {
|
||||||
if access_token == "" {
|
if access_token == "" {
|
||||||
@ -49,9 +60,7 @@ func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
req.Header.Set("Accept", "application/json")
|
req.Header = getLinkedInHeader(access_token)
|
||||||
req.Header.Set("x-li-format", "json")
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token))
|
|
||||||
|
|
||||||
json, err := api.Request(req)
|
json, err := api.Request(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -66,3 +75,7 @@ func (p *LinkedInProvider) GetEmailAddress(unused_auth_response *simplejson.Json
|
|||||||
}
|
}
|
||||||
return email, nil
|
return email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *LinkedInProvider) ValidateToken(access_token string) bool {
|
||||||
|
return validateToken(p, access_token, getLinkedInHeader(access_token))
|
||||||
|
}
|
||||||
|
@ -16,6 +16,7 @@ func testLinkedInProvider(hostname string) *LinkedInProvider {
|
|||||||
LoginUrl: &url.URL{},
|
LoginUrl: &url.URL{},
|
||||||
RedeemUrl: &url.URL{},
|
RedeemUrl: &url.URL{},
|
||||||
ProfileUrl: &url.URL{},
|
ProfileUrl: &url.URL{},
|
||||||
|
ValidateUrl: &url.URL{},
|
||||||
Scope: ""})
|
Scope: ""})
|
||||||
if hostname != "" {
|
if hostname != "" {
|
||||||
updateUrl(p.Data().LoginUrl, hostname)
|
updateUrl(p.Data().LoginUrl, hostname)
|
||||||
@ -52,6 +53,8 @@ func TestLinkedInProviderDefaults(t *testing.T) {
|
|||||||
p.Data().RedeemUrl.String())
|
p.Data().RedeemUrl.String())
|
||||||
assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address",
|
assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address",
|
||||||
p.Data().ProfileUrl.String())
|
p.Data().ProfileUrl.String())
|
||||||
|
assert.Equal(t, "https://www.linkedin.com/v1/people/~/email-address",
|
||||||
|
p.Data().ValidateUrl.String())
|
||||||
assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope)
|
assert.Equal(t, "r_emailaddress r_basicprofile", p.Data().Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -70,6 +73,10 @@ func TestLinkedInProviderOverrides(t *testing.T) {
|
|||||||
Scheme: "https",
|
Scheme: "https",
|
||||||
Host: "example.com",
|
Host: "example.com",
|
||||||
Path: "/oauth/profile"},
|
Path: "/oauth/profile"},
|
||||||
|
ValidateUrl: &url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: "example.com",
|
||||||
|
Path: "/oauth/tokeninfo"},
|
||||||
Scope: "profile"})
|
Scope: "profile"})
|
||||||
assert.NotEqual(t, nil, p)
|
assert.NotEqual(t, nil, p)
|
||||||
assert.Equal(t, "LinkedIn", p.Data().ProviderName)
|
assert.Equal(t, "LinkedIn", p.Data().ProviderName)
|
||||||
@ -79,6 +86,8 @@ func TestLinkedInProviderOverrides(t *testing.T) {
|
|||||||
p.Data().RedeemUrl.String())
|
p.Data().RedeemUrl.String())
|
||||||
assert.Equal(t, "https://example.com/oauth/profile",
|
assert.Equal(t, "https://example.com/oauth/profile",
|
||||||
p.Data().ProfileUrl.String())
|
p.Data().ProfileUrl.String())
|
||||||
|
assert.Equal(t, "https://example.com/oauth/tokeninfo",
|
||||||
|
p.Data().ValidateUrl.String())
|
||||||
assert.Equal(t, "profile", p.Data().Scope)
|
assert.Equal(t, "profile", p.Data().Scope)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -58,3 +58,7 @@ func (p *MyUsaProvider) GetEmailAddress(auth_response *simplejson.Json,
|
|||||||
}
|
}
|
||||||
return json.Get("email").String()
|
return json.Get("email").String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *MyUsaProvider) ValidateToken(access_token string) bool {
|
||||||
|
return validateToken(p, access_token, nil)
|
||||||
|
}
|
||||||
|
@ -8,6 +8,7 @@ type Provider interface {
|
|||||||
Data() *ProviderData
|
Data() *ProviderData
|
||||||
GetEmailAddress(auth_response *simplejson.Json,
|
GetEmailAddress(auth_response *simplejson.Json,
|
||||||
access_token string) (string, error)
|
access_token string) (string, error)
|
||||||
|
ValidateToken(access_token string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(provider string, p *ProviderData) Provider {
|
func New(provider string, p *ProviderData) Provider {
|
||||||
|
Loading…
Reference in New Issue
Block a user