diff --git a/oauthproxy.go b/oauthproxy.go index f40c249..265252f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -202,17 +202,17 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { } for _, u := range opts.proxyURLs { path := u.Path + host := u.Host switch u.Scheme { case httpScheme, httpsScheme: logger.Printf("mapping path %q => upstream %q", path, u) proxy := NewWebSocketOrRestReverseProxy(u, opts, auth) serveMux.Handle(path, proxy) - case "static": - serveMux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - responseCode, err := strconv.Atoi(u.Host) + serveMux.HandleFunc(path, func(rw http.ResponseWriter, req *http.Request) { + responseCode, err := strconv.Atoi(host) if err != nil { - logger.Printf("unable to convert %q to int, use default \"200\"", u.Host) + logger.Printf("unable to convert %q to int, use default \"200\"", host) responseCode = 200 } rw.WriteHeader(responseCode) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 8dd3adf..563e6ce 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -365,6 +365,7 @@ type PassAccessTokenTest struct { type PassAccessTokenTestOptions struct { PassAccessToken bool + ProxyUpstream string } func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { @@ -372,7 +373,6 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes t.providerServer = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - logger.Printf("%#v", r) var payload string switch r.URL.Path { case "/oauth/token": @@ -389,6 +389,9 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes t.opts = NewOptions() t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) + if opts.ProxyUpstream != "" { + t.opts.Upstreams = append(t.opts.Upstreams, opts.ProxyUpstream) + } // The CookieSecret must be 32 bytes in order to create the AES // cipher. t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" @@ -459,6 +462,39 @@ func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int return rw.Code, rw.Body.String() } +func (patTest *PassAccessTokenTest) getProxyEndpoint(cookie string) (httpCode int, accessToken string) { + cookieName := patTest.proxy.CookieName + var value string + keyPrefix := cookieName + "=" + + for _, field := range strings.Split(cookie, "; ") { + value = strings.TrimPrefix(field, keyPrefix) + if value != field { + break + } else { + value = "" + } + } + if value == "" { + return 0, "" + } + + req, err := http.NewRequest("GET", "/static-proxy", strings.NewReader("")) + if err != nil { + return 0, "" + } + req.AddCookie(&http.Cookie{ + Name: cookieName, + Value: value, + Path: "/", + Expires: time.Now().Add(time.Duration(24)), + HttpOnly: true, + }) + rw := httptest.NewRecorder() + patTest.proxy.ServeHTTP(rw, req) + return rw.Code, rw.Body.String() +} + func TestForwardAccessTokenUpstream(t *testing.T) { patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, @@ -482,6 +518,28 @@ func TestForwardAccessTokenUpstream(t *testing.T) { assert.Equal(t, "my_auth_token", payload) } +func TestStaticProxyUpstream(t *testing.T) { + patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + PassAccessToken: true, + ProxyUpstream: "static://200/static-proxy", + }) + + defer patTest.Close() + + // A successful validation will redirect and set the auth cookie. + code, cookie := patTest.getCallbackEndpoint() + if code != 302 { + t.Fatalf("expected 302; got %d", code) + } + assert.NotEqual(t, nil, cookie) + + code, payload := patTest.getProxyEndpoint(cookie) + if code != 200 { + t.Fatalf("expected 200; got %d", code) + } + assert.Equal(t, "Authenticated", payload) +} + func TestDoNotForwardAccessTokenUpstream(t *testing.T) { patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: false,