diff --git a/oauthproxy.go b/oauthproxy.go index 76fa385..f2c4c7a 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -55,7 +55,18 @@ func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) { director := proxy.Director proxy.Director = func(req *http.Request) { director(req) - req.Host = target.Host + // use RequestURI so that we aren't unescaping encoded slashes in the request path + req.URL.Opaque = fmt.Sprintf("//%s%s", target.Host, req.RequestURI) + req.URL.RawQuery = "" + } +} +func setProxyDirector(proxy *httputil.ReverseProxy) { + director := proxy.Director + proxy.Director = func(req *http.Request) { + director(req) + // use RequestURI so that we aren't unescaping encoded slashes in the request path + req.URL.Opaque = fmt.Sprintf("//%s%s", req.URL.Host, req.RequestURI) + req.URL.RawQuery = "" } } @@ -70,6 +81,8 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { proxy := NewReverseProxy(u) if !opts.PassHostHeader { setProxyUpstreamHostHeader(proxy, u) + } else { + setProxyDirector(proxy) } serveMux.Handle(path, proxy) } diff --git a/oauthproxy_test.go b/oauthproxy_test.go index c4e3836..b7969b4 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -35,3 +35,30 @@ func TestNewReverseProxy(t *testing.T) { t.Errorf("got body %q; expected %q", g, e) } } + +func TestEncodedSlashes(t *testing.T) { + var seen string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + seen = r.RequestURI + })) + defer backend.Close() + + b, _ := url.Parse(backend.URL) + proxyHandler := NewReverseProxy(b) + setProxyDirector(proxyHandler) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + f, _ := url.Parse(frontend.URL) + encodedPath := "/a%2Fb/" + getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}} + _, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("err %s", err) + } + expected := backend.URL + encodedPath + if seen != expected { + t.Errorf("got bad request %q expected %q", seen, expected) + } +}