diff --git a/oauthproxy.go b/oauthproxy.go index 08b065b..8392489 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -236,6 +236,22 @@ func (p *OauthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st return "", false } +func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) { + err := req.ParseForm() + + if err != nil { + return "", err + } + + redirect := req.FormValue("rd") + + if redirect == "" { + redirect = "/" + } + + return redirect, err +} + func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // check if this is a redirect back at the end of oauth remoteIP := req.Header.Get("X-Real-IP") @@ -246,13 +262,14 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { var ok bool var user string - redirect := req.FormValue("rd") - - if redirect == "" { - redirect = "/" - } if req.URL.Path == signInPath { + redirect, err := p.GetRedirect(req) + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) + return + } + user, ok = p.ManualSignIn(rw, req) if ok { p.SetCookie(rw, req, user) @@ -263,8 +280,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } if req.URL.Path == oauthStartPath { - // get the ?rd= value - err := req.ParseForm() + redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return