From 55085d9697962668fd4e43e8e4644144fe83cd93 Mon Sep 17 00:00:00 2001 From: Colin Arnott Date: Mon, 27 Mar 2017 21:14:38 -0400 Subject: [PATCH] csrf protection; always set state --- cookie/nonce.go | 16 ++++++ oauthproxy.go | 96 +++++++++++++++++++++++++---------- oauthproxy_test.go | 37 ++++++++++---- providers/provider_default.go | 7 +-- 4 files changed, 112 insertions(+), 44 deletions(-) create mode 100644 cookie/nonce.go diff --git a/cookie/nonce.go b/cookie/nonce.go new file mode 100644 index 0000000..3012ce2 --- /dev/null +++ b/cookie/nonce.go @@ -0,0 +1,16 @@ +package cookie + +import ( + "crypto/rand" + "fmt" +) + +func Nonce() (nonce string, err error) { + b := make([]byte, 16) + _, err = rand.Read(b) + if err != nil { + return + } + nonce = fmt.Sprintf("%x", b) + return +} diff --git a/oauthproxy.go b/oauthproxy.go index 5d325e1..630b88a 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -37,6 +37,7 @@ var SignatureHeaders []string = []string{ type OAuthProxy struct { CookieSeed string CookieName string + CSRFCookieName string CookieDomain string CookieSecure bool CookieHttpOnly bool @@ -174,6 +175,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { return &OAuthProxy{ CookieName: opts.CookieName, + CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"), CookieSeed: opts.CookieSecret, CookieDomain: opts.CookieDomain, CookieSecure: opts.CookieSecure, @@ -245,7 +247,22 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e return } -func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { +func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + if value != "" { + value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) + if len(value) > 4096 { + // Cookies cannot be larger than 4kb + log.Printf("WARNING - Cookie Size: %d bytes", len(value)) + } + } + return p.makeCookie(req, p.CookieName, value, expiration, now) +} + +func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { + return p.makeCookie(req, p.CSRFCookieName, value, expiration, now) +} + +func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie { domain := req.Host if h, _, err := net.SplitHostPort(domain); err == nil { domain = h @@ -257,15 +274,8 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time domain = p.CookieDomain } - if value != "" { - value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) - if len(value) > 4096 { - // Cookies cannot be larger than 4kb - log.Printf("WARNING - Cookie Size: %d bytes", len(value)) - } - } return &http.Cookie{ - Name: p.CookieName, + Name: name, Value: value, Path: "/", Domain: domain, @@ -275,12 +285,20 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time } } -func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) { - http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now())) +func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) } -func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now())) +func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now())) +} + +func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())) +} + +func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { + http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) } func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { @@ -309,7 +327,7 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p if err != nil { return err } - p.SetCookie(rw, req, value) + p.SetSessionCookie(rw, req, value) return nil } @@ -339,7 +357,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m } func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) { - p.ClearCookie(rw, req) + p.ClearSessionCookie(rw, req) rw.WriteHeader(code) redirect_url := req.URL.RequestURI() @@ -384,20 +402,18 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st return "", false } -func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) { - err := req.ParseForm() - +func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) { + err = req.ParseForm() if err != nil { - return "", err + return } - redirect := req.FormValue("rd") - - if redirect == "" { + redirect = req.Form.Get("rd") + if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { redirect = "/" } - return redirect, err + return } func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) { @@ -459,18 +475,24 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) { } func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) { - p.ClearCookie(rw, req) + p.ClearSessionCookie(rw, req) http.Redirect(rw, req, "/", 302) } func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { + nonce, err := cookie.Nonce() + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", err.Error()) + return + } + p.SetCSRFCookie(rw, req, nonce) redirect, err := p.GetRedirect(req) if err != nil { p.ErrorPage(rw, 500, "Internal Error", err.Error()) return } redirectURI := p.GetRedirectURI(req.Host) - http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) + http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302) } func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { @@ -495,8 +517,26 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - redirect := req.Form.Get("state") - if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { + s := strings.SplitN(req.Form.Get("state"), ":", 2) + if len(s) != 2 { + p.ErrorPage(rw, 500, "Internal Error", "Invalid State") + return + } + nonce := s[0] + redirect := s[1] + c, err := req.Cookie(p.CSRFCookieName) + if err != nil { + p.ErrorPage(rw, 403, "Permission Denied", err.Error()) + return + } + p.ClearCSRFCookie(rw, req) + if c.Value != nonce { + log.Printf("%s csrf token mismatch, potential attack", remoteAddr) + p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") + return + } + + if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") { redirect = "/" } @@ -595,7 +635,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int } if clearSession { - p.ClearCookie(rw, req) + p.ClearSessionCookie(rw, req) } if session == nil { diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 9a7f5e3..26a942d 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -170,10 +170,14 @@ func TestBasicAuthPassword(t *testing.T) { }) rw := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code", + req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) + req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) proxy.ServeHTTP(rw, req) - cookie := rw.HeaderMap["Set-Cookie"][0] + if rw.Code >= 400 { + t.Fatalf("expected 3xx got %d", rw.Code) + } + cookie := rw.HeaderMap["Set-Cookie"][1] cookieName := proxy.CookieName var value string @@ -196,9 +200,11 @@ func TestBasicAuthPassword(t *testing.T) { Expires: time.Now().Add(time.Duration(24)), HttpOnly: true, }) + req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now())) rw = httptest.NewRecorder() proxy.ServeHTTP(rw, req) + expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword)) assert.Equal(t, expectedHeader, rw.Body.String()) provider_server.Close() @@ -263,13 +269,14 @@ func (pat_test *PassAccessTokenTest) Close() { func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, cookie string) { rw := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code", + req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", strings.NewReader("")) if err != nil { return 0, "" } + req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) pat_test.proxy.ServeHTTP(rw, req) - return rw.Code, rw.HeaderMap["Set-Cookie"][0] + return rw.Code, rw.HeaderMap["Set-Cookie"][1] } func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { @@ -314,14 +321,18 @@ func TestForwardAccessTokenUpstream(t *testing.T) { // A successful validation will redirect and set the auth cookie. code, cookie := pat_test.getCallbackEndpoint() - assert.Equal(t, 302, code) + if code != 302 { + t.Fatalf("expected 302; got %d", code) + } assert.NotEqual(t, nil, cookie) // Now we make a regular request; the access_token from the cookie is // forwarded as the "X-Forwarded-Access-Token" header. The token is // read by the test provider server and written in the response body. code, payload := pat_test.getRootEndpoint(cookie) - assert.Equal(t, 200, code) + if code != 200 { + t.Fatalf("expected 200; got %d", code) + } assert.Equal(t, "my_auth_token", payload) } @@ -333,13 +344,17 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { // A successful validation will redirect and set the auth cookie. code, cookie := pat_test.getCallbackEndpoint() - assert.Equal(t, 302, code) + if code != 302 { + t.Fatalf("expected 302; got %d", code) + } assert.NotEqual(t, nil, cookie) // Now we make a regular request, but the access token header should // not be present. code, payload := pat_test.getRootEndpoint(cookie) - assert.Equal(t, 200, code) + if code != 200 { + t.Fatalf("expected 200; got %d", code) + } assert.Equal(t, "No access token found.", payload) } @@ -457,7 +472,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { } func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie { - return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref) + return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) } func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error { @@ -465,7 +480,7 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time if err != nil { return err } - p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref)) + p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref)) return nil } @@ -697,7 +712,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { if err != nil { panic(err) } - cookie := proxy.MakeCookie(req, value, proxy.CookieExpire, time.Now()) + cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) req.AddCookie(cookie) // This is used by the upstream to validate the signature. st.authenticator.auth = hmacauth.NewHmacAuth( diff --git a/providers/provider_default.go b/providers/provider_default.go index 6b8ec40..1d1daea 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -8,7 +8,6 @@ import ( "io/ioutil" "net/http" "net/url" - "strings" "github.com/bitly/oauth2_proxy/cookie" ) @@ -79,7 +78,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er } // GetLoginURL with typical oauth parameters -func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { +func (p *ProviderData) GetLoginURL(redirectURI, state string) string { var a url.URL a = *p.LoginURL params, _ := url.ParseQuery(a.RawQuery) @@ -88,9 +87,7 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { params.Add("scope", p.Scope) params.Set("client_id", p.ClientID) params.Set("response_type", "code") - if strings.HasPrefix(finalRedirect, "/") && !strings.HasPrefix(finalRedirect,"//") { - params.Add("state", finalRedirect) - } + params.Add("state", state) a.RawQuery = params.Encode() return a.String() }