Add tests for static upstream
This commit is contained in:
parent
a7e3c3a7ef
commit
297c3c465d
@ -202,17 +202,17 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
}
|
}
|
||||||
for _, u := range opts.proxyURLs {
|
for _, u := range opts.proxyURLs {
|
||||||
path := u.Path
|
path := u.Path
|
||||||
|
host := u.Host
|
||||||
switch u.Scheme {
|
switch u.Scheme {
|
||||||
case httpScheme, httpsScheme:
|
case httpScheme, httpsScheme:
|
||||||
logger.Printf("mapping path %q => upstream %q", path, u)
|
logger.Printf("mapping path %q => upstream %q", path, u)
|
||||||
proxy := NewWebSocketOrRestReverseProxy(u, opts, auth)
|
proxy := NewWebSocketOrRestReverseProxy(u, opts, auth)
|
||||||
serveMux.Handle(path, proxy)
|
serveMux.Handle(path, proxy)
|
||||||
|
|
||||||
case "static":
|
case "static":
|
||||||
serveMux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) {
|
serveMux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) {
|
||||||
responseCode, err := strconv.Atoi(u.Host)
|
responseCode, err := strconv.Atoi(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Printf("unable to convert %q to int, use default \"200\"", u.Host)
|
logger.Printf("unable to convert %q to int, use default \"200\"", host)
|
||||||
responseCode = 200
|
responseCode = 200
|
||||||
}
|
}
|
||||||
rw.WriteHeader(responseCode)
|
rw.WriteHeader(responseCode)
|
||||||
|
@ -365,6 +365,7 @@ type PassAccessTokenTest struct {
|
|||||||
|
|
||||||
type PassAccessTokenTestOptions struct {
|
type PassAccessTokenTestOptions struct {
|
||||||
PassAccessToken bool
|
PassAccessToken bool
|
||||||
|
ProxyUpstream string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
|
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
|
||||||
@ -372,7 +373,6 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
|
|||||||
|
|
||||||
t.providerServer = httptest.NewServer(
|
t.providerServer = httptest.NewServer(
|
||||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
logger.Printf("%#v", r)
|
|
||||||
var payload string
|
var payload string
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case "/oauth/token":
|
case "/oauth/token":
|
||||||
@ -389,6 +389,9 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
|
|||||||
|
|
||||||
t.opts = NewOptions()
|
t.opts = NewOptions()
|
||||||
t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
|
t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
|
||||||
|
if opts.ProxyUpstream != "" {
|
||||||
|
t.opts.Upstreams = append(t.opts.Upstreams, opts.ProxyUpstream)
|
||||||
|
}
|
||||||
// The CookieSecret must be 32 bytes in order to create the AES
|
// The CookieSecret must be 32 bytes in order to create the AES
|
||||||
// cipher.
|
// cipher.
|
||||||
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
|
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
|
||||||
@ -459,6 +462,39 @@ func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int
|
|||||||
return rw.Code, rw.Body.String()
|
return rw.Code, rw.Body.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (patTest *PassAccessTokenTest) getProxyEndpoint(cookie string) (httpCode int, accessToken string) {
|
||||||
|
cookieName := patTest.proxy.CookieName
|
||||||
|
var value string
|
||||||
|
keyPrefix := cookieName + "="
|
||||||
|
|
||||||
|
for _, field := range strings.Split(cookie, "; ") {
|
||||||
|
value = strings.TrimPrefix(field, keyPrefix)
|
||||||
|
if value != field {
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
value = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if value == "" {
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("GET", "/static-proxy", strings.NewReader(""))
|
||||||
|
if err != nil {
|
||||||
|
return 0, ""
|
||||||
|
}
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: cookieName,
|
||||||
|
Value: value,
|
||||||
|
Path: "/",
|
||||||
|
Expires: time.Now().Add(time.Duration(24)),
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
rw := httptest.NewRecorder()
|
||||||
|
patTest.proxy.ServeHTTP(rw, req)
|
||||||
|
return rw.Code, rw.Body.String()
|
||||||
|
}
|
||||||
|
|
||||||
func TestForwardAccessTokenUpstream(t *testing.T) {
|
func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||||
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||||
PassAccessToken: true,
|
PassAccessToken: true,
|
||||||
@ -482,6 +518,28 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
|
|||||||
assert.Equal(t, "my_auth_token", payload)
|
assert.Equal(t, "my_auth_token", payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStaticProxyUpstream(t *testing.T) {
|
||||||
|
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||||
|
PassAccessToken: true,
|
||||||
|
ProxyUpstream: "static://200/static-proxy",
|
||||||
|
})
|
||||||
|
|
||||||
|
defer patTest.Close()
|
||||||
|
|
||||||
|
// A successful validation will redirect and set the auth cookie.
|
||||||
|
code, cookie := patTest.getCallbackEndpoint()
|
||||||
|
if code != 302 {
|
||||||
|
t.Fatalf("expected 302; got %d", code)
|
||||||
|
}
|
||||||
|
assert.NotEqual(t, nil, cookie)
|
||||||
|
|
||||||
|
code, payload := patTest.getProxyEndpoint(cookie)
|
||||||
|
if code != 200 {
|
||||||
|
t.Fatalf("expected 200; got %d", code)
|
||||||
|
}
|
||||||
|
assert.Equal(t, "Authenticated", payload)
|
||||||
|
}
|
||||||
|
|
||||||
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||||
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
|
||||||
PassAccessToken: false,
|
PassAccessToken: false,
|
||||||
|
Loading…
Reference in New Issue
Block a user