Merge pull request #360 from jehiah/csrf_validation_360
CSRF protection for OAuth flow.
This commit is contained in:
commit
4464655276
16
cookie/nonce.go
Normal file
16
cookie/nonce.go
Normal file
@ -0,0 +1,16 @@
|
||||
package cookie
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func Nonce() (nonce string, err error) {
|
||||
b := make([]byte, 16)
|
||||
_, err = rand.Read(b)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nonce = fmt.Sprintf("%x", b)
|
||||
return
|
||||
}
|
@ -37,6 +37,7 @@ var SignatureHeaders []string = []string{
|
||||
type OAuthProxy struct {
|
||||
CookieSeed string
|
||||
CookieName string
|
||||
CSRFCookieName string
|
||||
CookieDomain string
|
||||
CookieSecure bool
|
||||
CookieHttpOnly bool
|
||||
@ -174,6 +175,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
|
||||
return &OAuthProxy{
|
||||
CookieName: opts.CookieName,
|
||||
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
||||
CookieSeed: opts.CookieSecret,
|
||||
CookieDomain: opts.CookieDomain,
|
||||
CookieSecure: opts.CookieSecure,
|
||||
@ -245,7 +247,22 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
if value != "" {
|
||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
||||
if len(value) > 4096 {
|
||||
// Cookies cannot be larger than 4kb
|
||||
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
|
||||
}
|
||||
}
|
||||
return p.makeCookie(req, p.CookieName, value, expiration, now)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||
domain := req.Host
|
||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||
domain = h
|
||||
@ -257,15 +274,8 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
|
||||
domain = p.CookieDomain
|
||||
}
|
||||
|
||||
if value != "" {
|
||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
||||
if len(value) > 4096 {
|
||||
// Cookies cannot be larger than 4kb
|
||||
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
|
||||
}
|
||||
}
|
||||
return &http.Cookie{
|
||||
Name: p.CookieName,
|
||||
Name: name,
|
||||
Value: value,
|
||||
Path: "/",
|
||||
Domain: domain,
|
||||
@ -275,12 +285,20 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now()))
|
||||
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
|
||||
func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
|
||||
http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
||||
@ -309,7 +327,7 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetCookie(rw, req, value)
|
||||
p.SetSessionCookie(rw, req, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -339,7 +357,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||
p.ClearCookie(rw, req)
|
||||
p.ClearSessionCookie(rw, req)
|
||||
rw.WriteHeader(code)
|
||||
|
||||
redirect_url := req.URL.RequestURI()
|
||||
@ -384,20 +402,18 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) {
|
||||
err := req.ParseForm()
|
||||
|
||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
|
||||
err = req.ParseForm()
|
||||
if err != nil {
|
||||
return "", err
|
||||
return
|
||||
}
|
||||
|
||||
redirect := req.FormValue("rd")
|
||||
|
||||
if redirect == "" {
|
||||
redirect = req.Form.Get("rd")
|
||||
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
return redirect, err
|
||||
return
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
|
||||
@ -459,18 +475,24 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||
p.ClearCookie(rw, req)
|
||||
p.ClearSessionCookie(rw, req)
|
||||
http.Redirect(rw, req, "/", 302)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||
nonce, err := cookie.Nonce()
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
p.SetCSRFCookie(rw, req, nonce)
|
||||
redirect, err := p.GetRedirect(req)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||
return
|
||||
}
|
||||
redirectURI := p.GetRedirectURI(req.Host)
|
||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
|
||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302)
|
||||
}
|
||||
|
||||
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
@ -495,8 +517,26 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
redirect := req.Form.Get("state")
|
||||
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||
s := strings.SplitN(req.Form.Get("state"), ":", 2)
|
||||
if len(s) != 2 {
|
||||
p.ErrorPage(rw, 500, "Internal Error", "Invalid State")
|
||||
return
|
||||
}
|
||||
nonce := s[0]
|
||||
redirect := s[1]
|
||||
c, err := req.Cookie(p.CSRFCookieName)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 403, "Permission Denied", err.Error())
|
||||
return
|
||||
}
|
||||
p.ClearCSRFCookie(rw, req)
|
||||
if c.Value != nonce {
|
||||
log.Printf("%s csrf token mismatch, potential attack", remoteAddr)
|
||||
p.ErrorPage(rw, 403, "Permission Denied", "csrf failed")
|
||||
return
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
@ -595,7 +635,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
||||
}
|
||||
|
||||
if clearSession {
|
||||
p.ClearCookie(rw, req)
|
||||
p.ClearSessionCookie(rw, req)
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
|
@ -170,10 +170,14 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
})
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
|
||||
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
||||
strings.NewReader(""))
|
||||
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
||||
proxy.ServeHTTP(rw, req)
|
||||
cookie := rw.HeaderMap["Set-Cookie"][0]
|
||||
if rw.Code >= 400 {
|
||||
t.Fatalf("expected 3xx got %d", rw.Code)
|
||||
}
|
||||
cookie := rw.HeaderMap["Set-Cookie"][1]
|
||||
|
||||
cookieName := proxy.CookieName
|
||||
var value string
|
||||
@ -196,9 +200,11 @@ func TestBasicAuthPassword(t *testing.T) {
|
||||
Expires: time.Now().Add(time.Duration(24)),
|
||||
HttpOnly: true,
|
||||
})
|
||||
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
||||
|
||||
rw = httptest.NewRecorder()
|
||||
proxy.ServeHTTP(rw, req)
|
||||
|
||||
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
|
||||
assert.Equal(t, expectedHeader, rw.Body.String())
|
||||
provider_server.Close()
|
||||
@ -263,13 +269,14 @@ func (pat_test *PassAccessTokenTest) Close() {
|
||||
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
||||
cookie string) {
|
||||
rw := httptest.NewRecorder()
|
||||
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
|
||||
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
||||
strings.NewReader(""))
|
||||
if err != nil {
|
||||
return 0, ""
|
||||
}
|
||||
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
|
||||
pat_test.proxy.ServeHTTP(rw, req)
|
||||
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
|
||||
return rw.Code, rw.HeaderMap["Set-Cookie"][1]
|
||||
}
|
||||
|
||||
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
|
||||
@ -314,14 +321,18 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
|
||||
|
||||
// A successful validation will redirect and set the auth cookie.
|
||||
code, cookie := pat_test.getCallbackEndpoint()
|
||||
assert.Equal(t, 302, code)
|
||||
if code != 302 {
|
||||
t.Fatalf("expected 302; got %d", code)
|
||||
}
|
||||
assert.NotEqual(t, nil, cookie)
|
||||
|
||||
// Now we make a regular request; the access_token from the cookie is
|
||||
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
||||
// read by the test provider server and written in the response body.
|
||||
code, payload := pat_test.getRootEndpoint(cookie)
|
||||
assert.Equal(t, 200, code)
|
||||
if code != 200 {
|
||||
t.Fatalf("expected 200; got %d", code)
|
||||
}
|
||||
assert.Equal(t, "my_auth_token", payload)
|
||||
}
|
||||
|
||||
@ -333,13 +344,17 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
||||
|
||||
// A successful validation will redirect and set the auth cookie.
|
||||
code, cookie := pat_test.getCallbackEndpoint()
|
||||
assert.Equal(t, 302, code)
|
||||
if code != 302 {
|
||||
t.Fatalf("expected 302; got %d", code)
|
||||
}
|
||||
assert.NotEqual(t, nil, cookie)
|
||||
|
||||
// Now we make a regular request, but the access token header should
|
||||
// not be present.
|
||||
code, payload := pat_test.getRootEndpoint(cookie)
|
||||
assert.Equal(t, 200, code)
|
||||
if code != 200 {
|
||||
t.Fatalf("expected 200; got %d", code)
|
||||
}
|
||||
assert.Equal(t, "No access token found.", payload)
|
||||
}
|
||||
|
||||
@ -457,7 +472,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
|
||||
return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||
}
|
||||
|
||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
||||
@ -465,7 +480,7 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
|
||||
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -697,7 +712,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
cookie := proxy.MakeCookie(req, value, proxy.CookieExpire, time.Now())
|
||||
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
|
||||
req.AddCookie(cookie)
|
||||
// This is used by the upstream to validate the signature.
|
||||
st.authenticator.auth = hmacauth.NewHmacAuth(
|
||||
|
@ -8,7 +8,6 @@ import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/bitly/oauth2_proxy/cookie"
|
||||
)
|
||||
@ -79,7 +78,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
||||
}
|
||||
|
||||
// GetLoginURL with typical oauth parameters
|
||||
func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
|
||||
func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
|
||||
var a url.URL
|
||||
a = *p.LoginURL
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
@ -88,9 +87,7 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
|
||||
params.Add("scope", p.Scope)
|
||||
params.Set("client_id", p.ClientID)
|
||||
params.Set("response_type", "code")
|
||||
if strings.HasPrefix(finalRedirect, "/") && !strings.HasPrefix(finalRedirect,"//") {
|
||||
params.Add("state", finalRedirect)
|
||||
}
|
||||
params.Add("state", state)
|
||||
a.RawQuery = params.Encode()
|
||||
return a.String()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user