From d2b1815d4337a43ee89aee18bdac64f88d1bb66b Mon Sep 17 00:00:00 2001 From: Sean O'Connor Date: Tue, 22 Oct 2013 19:56:29 +0000 Subject: [PATCH] After authentication, redirect to original URI. --- oauthproxy.go | 53 ++++++++++++++++++++++++++++++++++----------------- templates.go | 2 ++ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 9ce9ee9..08b065b 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -67,13 +67,16 @@ func (p *OauthProxy) SetRedirectUrl(redirectUrl *url.URL) { p.redirectUrl = redirectUrl } -func (p *OauthProxy) GetLoginURL() string { +func (p *OauthProxy) GetLoginURL(redirectUrl string) string { params := url.Values{} params.Add("redirect_uri", p.redirectUrl.String()) params.Add("approval_prompt", "force") params.Add("scope", p.oauthScope) params.Add("client_id", p.clientID) params.Add("response_type", "code") + if strings.HasPrefix(redirectUrl, "/") { + params.Add("state", redirectUrl) + } return fmt.Sprintf("%s?%s", p.oauthLoginUrl, params.Encode()) } @@ -100,6 +103,9 @@ func apiRequest(req *http.Request) (*simplejson.Json, error) { } func (p *OauthProxy) redeemCode(code string) (string, error) { + if code == "" { + return "", errors.New("missing code") + } params := url.Values{} params.Add("redirect_uri", p.redirectUrl.String()) params.Add("client_id", p.clientID) @@ -197,7 +203,6 @@ func (p *OauthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m } func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { - // TODO: capture state for which url to redirect to at the end p.ClearCookie(rw, req) rw.WriteHeader(code) templates := getTemplates() @@ -205,9 +210,11 @@ func (p *OauthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code t := struct { SignInMessage string Htpasswd bool + Redirect string }{ SignInMessage: p.SignInMessage, Htpasswd: p.HtpasswdFile != nil, + Redirect: req.URL.RequestURI(), } templates.ExecuteTemplate(rw, "sign_in.html", t) } @@ -239,39 +246,46 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { var ok bool var user string + redirect := req.FormValue("rd") + + if redirect == "" { + redirect = "/" + } + if req.URL.Path == signInPath { user, ok = p.ManualSignIn(rw, req) if ok { p.SetCookie(rw, req, user) - http.Redirect(rw, req, "/", 302) + http.Redirect(rw, req, redirect, 302) } else { p.SignInPage(rw, req, 200) } return } if req.URL.Path == oauthStartPath { - http.Redirect(rw, req, p.GetLoginURL(), 302) - return - } - if req.URL.Path == oauthCallbackPath { - // finish the oauth cycle - reqParams, err := url.ParseQuery(req.URL.RawQuery) + // get the ?rd= value + err := req.ParseForm() if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } - errorString, ok := reqParams["error"] - if ok && len(errorString) == 1 { - p.ErrorPage(rw, 403, "Permission Denied", errorString[0]) + http.Redirect(rw, req, p.GetLoginURL(redirect), 302) + return + } + if req.URL.Path == oauthCallbackPath { + // finish the oauth cycle + err := req.ParseForm() + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } - code, ok := reqParams["code"] - if !ok || len(code) != 1 { - p.ErrorPage(rw, 500, "Internal Error", "Invalid API response") + errorString := req.Form.Get("error") + if errorString != "" { + p.ErrorPage(rw, 403, "Permission Denied", errorString) return } - token, err := p.redeemCode(code[0]) + token, err := p.redeemCode(req.Form.Get("code")) if err != nil { log.Printf("error redeeming code %s", err.Error()) p.ErrorPage(rw, 500, "Internal Error", err.Error()) @@ -285,11 +299,16 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + redirect := req.Form.Get("state") + if redirect == "" { + redirect = "/" + } + // set cookie, or deny if p.Validator(email) { log.Printf("authenticating %s completed", email) p.SetCookie(rw, req, email) - http.Redirect(rw, req, "/", 302) + http.Redirect(rw, req, redirect, 302) return } else { p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") diff --git a/templates.go b/templates.go index 5154dec..5929a6f 100644 --- a/templates.go +++ b/templates.go @@ -12,12 +12,14 @@ func getTemplates() *template.Template { Sign In
+ {{.SignInMessage}}
{{ if .Htpasswd }}
+