diff --git a/CHANGELOG.md b/CHANGELOG.md index 83c4e2c..51086d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ - [#21](https://github.com/pusher/oauth2_proxy/pull/21) Docker Improvement (@yaegashi) - Move Docker base image from debian to alpine - Install ca-certificates in docker image +- [#24](https://github.com/pusher/oauth2_proxy/pull/24) Redirect fix (@agentgonzo) + - After a successful login, you will be redirected to your original URL rather than / # v3.0.0 diff --git a/oauthproxy.go b/oauthproxy.go index 8cbb4e5..64503cb 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -564,7 +564,10 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) redirect = req.Form.Get("rd") if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { - redirect = "/" + redirect = req.URL.Path + if strings.HasPrefix(redirect, p.ProxyPrefix) { + redirect = "/" + } } return diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 1f9914e..adc3cfb 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -18,6 +18,7 @@ import ( "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func init() { @@ -837,3 +838,36 @@ func TestRequestSignaturePostRequest(t *testing.T) { assert.Equal(t, 200, st.rw.Code) assert.Equal(t, st.rw.Body.String(), "signatures match") } + +func TestGetRedirect(t *testing.T) { + options := NewOptions() + _ = options.Validate() + require.NotEmpty(t, options.ProxyPrefix) + proxy := NewOAuthProxy(options, func(s string) bool { return false }) + + tests := []struct { + name string + url string + expectedRedirect string + }{ + { + name: "request outside of ProxyPrefix redirects to original URL", + url: "/foo/bar", + expectedRedirect: "/foo/bar", + }, + { + name: "request under ProxyPrefix redirects to root", + url: proxy.ProxyPrefix + "/foo/bar", + expectedRedirect: "/", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", tt.url, nil) + redirect, err := proxy.GetRedirect(req) + + assert.NoError(t, err) + assert.Equal(t, tt.expectedRedirect, redirect) + }) + } +}