pass raw unencoded request URI upstream

This commit is contained in:
Jehiah Czebotar 2015-03-17 17:17:40 -04:00
parent 85e025db25
commit 71ae70834d
2 changed files with 41 additions and 1 deletions

View File

@ -55,7 +55,18 @@ func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) {
director := proxy.Director director := proxy.Director
proxy.Director = func(req *http.Request) { proxy.Director = func(req *http.Request) {
director(req) director(req)
req.Host = target.Host // use RequestURI so that we aren't unescaping encoded slashes in the request path
req.URL.Opaque = fmt.Sprintf("//%s%s", target.Host, req.RequestURI)
req.URL.RawQuery = ""
}
}
func setProxyDirector(proxy *httputil.ReverseProxy) {
director := proxy.Director
proxy.Director = func(req *http.Request) {
director(req)
// use RequestURI so that we aren't unescaping encoded slashes in the request path
req.URL.Opaque = fmt.Sprintf("//%s%s", req.URL.Host, req.RequestURI)
req.URL.RawQuery = ""
} }
} }
@ -70,6 +81,8 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
proxy := NewReverseProxy(u) proxy := NewReverseProxy(u)
if !opts.PassHostHeader { if !opts.PassHostHeader {
setProxyUpstreamHostHeader(proxy, u) setProxyUpstreamHostHeader(proxy, u)
} else {
setProxyDirector(proxy)
} }
serveMux.Handle(path, proxy) serveMux.Handle(path, proxy)
} }

View File

@ -35,3 +35,30 @@ func TestNewReverseProxy(t *testing.T) {
t.Errorf("got body %q; expected %q", g, e) t.Errorf("got body %q; expected %q", g, e)
} }
} }
func TestEncodedSlashes(t *testing.T) {
var seen string
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
seen = r.RequestURI
}))
defer backend.Close()
b, _ := url.Parse(backend.URL)
proxyHandler := NewReverseProxy(b)
setProxyDirector(proxyHandler)
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
f, _ := url.Parse(frontend.URL)
encodedPath := "/a%2Fb/"
getReq := &http.Request{URL: &url.URL{Scheme: "http", Host: f.Host, Opaque: encodedPath}}
_, err := http.DefaultClient.Do(getReq)
if err != nil {
t.Fatalf("err %s", err)
}
expected := backend.URL + encodedPath
if seen != expected {
t.Errorf("got bad request %q expected %q", seen, expected)
}
}