Add tests for static upstream

This commit is contained in:
Christian Groschupp 2019-09-19 11:03:38 +02:00
parent a7e3c3a7ef
commit 297c3c465d
No known key found for this signature in database
GPG Key ID: F164E00C6EDA908F
2 changed files with 63 additions and 5 deletions

View File

@ -202,17 +202,17 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
} }
for _, u := range opts.proxyURLs { for _, u := range opts.proxyURLs {
path := u.Path path := u.Path
host := u.Host
switch u.Scheme { switch u.Scheme {
case httpScheme, httpsScheme: case httpScheme, httpsScheme:
logger.Printf("mapping path %q => upstream %q", path, u) logger.Printf("mapping path %q => upstream %q", path, u)
proxy := NewWebSocketOrRestReverseProxy(u, opts, auth) proxy := NewWebSocketOrRestReverseProxy(u, opts, auth)
serveMux.Handle(path, proxy) serveMux.Handle(path, proxy)
case "static": case "static":
serveMux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { serveMux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) {
responseCode, err := strconv.Atoi(u.Host) responseCode, err := strconv.Atoi(host)
if err != nil { if err != nil {
logger.Printf("unable to convert %q to int, use default \"200\"", u.Host) logger.Printf("unable to convert %q to int, use default \"200\"", host)
responseCode = 200 responseCode = 200
} }
rw.WriteHeader(responseCode) rw.WriteHeader(responseCode)

View File

@ -365,6 +365,7 @@ type PassAccessTokenTest struct {
type PassAccessTokenTestOptions struct { type PassAccessTokenTestOptions struct {
PassAccessToken bool PassAccessToken bool
ProxyUpstream string
} }
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
@ -372,7 +373,6 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
t.providerServer = httptest.NewServer( t.providerServer = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logger.Printf("%#v", r)
var payload string var payload string
switch r.URL.Path { switch r.URL.Path {
case "/oauth/token": case "/oauth/token":
@ -389,6 +389,9 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes
t.opts = NewOptions() t.opts = NewOptions()
t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL)
if opts.ProxyUpstream != "" {
t.opts.Upstreams = append(t.opts.Upstreams, opts.ProxyUpstream)
}
// The CookieSecret must be 32 bytes in order to create the AES // The CookieSecret must be 32 bytes in order to create the AES
// cipher. // cipher.
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
@ -459,6 +462,39 @@ func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int
return rw.Code, rw.Body.String() return rw.Code, rw.Body.String()
} }
func (patTest *PassAccessTokenTest) getProxyEndpoint(cookie string) (httpCode int, accessToken string) {
cookieName := patTest.proxy.CookieName
var value string
keyPrefix := cookieName + "="
for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, keyPrefix)
if value != field {
break
} else {
value = ""
}
}
if value == "" {
return 0, ""
}
req, err := http.NewRequest("GET", "/static-proxy", strings.NewReader(""))
if err != nil {
return 0, ""
}
req.AddCookie(&http.Cookie{
Name: cookieName,
Value: value,
Path: "/",
Expires: time.Now().Add(time.Duration(24)),
HttpOnly: true,
})
rw := httptest.NewRecorder()
patTest.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}
func TestForwardAccessTokenUpstream(t *testing.T) { func TestForwardAccessTokenUpstream(t *testing.T) {
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true, PassAccessToken: true,
@ -482,6 +518,28 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
assert.Equal(t, "my_auth_token", payload) assert.Equal(t, "my_auth_token", payload)
} }
func TestStaticProxyUpstream(t *testing.T) {
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true,
ProxyUpstream: "static://200/static-proxy",
})
defer patTest.Close()
// A successful validation will redirect and set the auth cookie.
code, cookie := patTest.getCallbackEndpoint()
if code != 302 {
t.Fatalf("expected 302; got %d", code)
}
assert.NotEqual(t, nil, cookie)
code, payload := patTest.getProxyEndpoint(cookie)
if code != 200 {
t.Fatalf("expected 200; got %d", code)
}
assert.Equal(t, "Authenticated", payload)
}
func TestDoNotForwardAccessTokenUpstream(t *testing.T) { func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: false, PassAccessToken: false,