Merge pull request #360 from jehiah/csrf_validation_360
CSRF protection for OAuth flow.
This commit is contained in:
commit
4464655276
16
cookie/nonce.go
Normal file
16
cookie/nonce.go
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package cookie
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Nonce() (nonce string, err error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
_, err = rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonce = fmt.Sprintf("%x", b)
|
||||||
|
return
|
||||||
|
}
|
@ -37,6 +37,7 @@ var SignatureHeaders []string = []string{
|
|||||||
type OAuthProxy struct {
|
type OAuthProxy struct {
|
||||||
CookieSeed string
|
CookieSeed string
|
||||||
CookieName string
|
CookieName string
|
||||||
|
CSRFCookieName string
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
CookieSecure bool
|
CookieSecure bool
|
||||||
CookieHttpOnly bool
|
CookieHttpOnly bool
|
||||||
@ -174,6 +175,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
|
|
||||||
return &OAuthProxy{
|
return &OAuthProxy{
|
||||||
CookieName: opts.CookieName,
|
CookieName: opts.CookieName,
|
||||||
|
CSRFCookieName: fmt.Sprintf("%v_%v", opts.CookieName, "csrf"),
|
||||||
CookieSeed: opts.CookieSecret,
|
CookieSeed: opts.CookieSecret,
|
||||||
CookieDomain: opts.CookieDomain,
|
CookieDomain: opts.CookieDomain,
|
||||||
CookieSecure: opts.CookieSecure,
|
CookieSecure: opts.CookieSecure,
|
||||||
@ -245,7 +247,22 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
if value != "" {
|
||||||
|
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
||||||
|
if len(value) > 4096 {
|
||||||
|
// Cookies cannot be larger than 4kb
|
||||||
|
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p.makeCookie(req, p.CookieName, value, expiration, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
|
return p.makeCookie(req, p.CSRFCookieName, value, expiration, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
domain := req.Host
|
domain := req.Host
|
||||||
if h, _, err := net.SplitHostPort(domain); err == nil {
|
if h, _, err := net.SplitHostPort(domain); err == nil {
|
||||||
domain = h
|
domain = h
|
||||||
@ -257,15 +274,8 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
|
|||||||
domain = p.CookieDomain
|
domain = p.CookieDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
if value != "" {
|
|
||||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
|
||||||
if len(value) > 4096 {
|
|
||||||
// Cookies cannot be larger than 4kb
|
|
||||||
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &http.Cookie{
|
return &http.Cookie{
|
||||||
Name: p.CookieName,
|
Name: name,
|
||||||
Value: value,
|
Value: value,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Domain: domain,
|
Domain: domain,
|
||||||
@ -275,12 +285,20 @@ func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
|
||||||
http.SetCookie(rw, p.MakeCookie(req, "", time.Hour*-1, time.Now()))
|
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
|
http.SetCookie(rw, p.MakeCSRFCookie(req, val, p.CookieExpire, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
http.SetCookie(rw, p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
|
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
||||||
@ -309,7 +327,7 @@ func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *p
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.SetCookie(rw, req, value)
|
p.SetSessionCookie(rw, req, value)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -339,7 +357,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, m
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
|
||||||
p.ClearCookie(rw, req)
|
p.ClearSessionCookie(rw, req)
|
||||||
rw.WriteHeader(code)
|
rw.WriteHeader(code)
|
||||||
|
|
||||||
redirect_url := req.URL.RequestURI()
|
redirect_url := req.URL.RequestURI()
|
||||||
@ -384,20 +402,18 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) {
|
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
|
||||||
err := req.ParseForm()
|
err = req.ParseForm()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirect := req.FormValue("rd")
|
redirect = req.Form.Get("rd")
|
||||||
|
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||||
if redirect == "" {
|
|
||||||
redirect = "/"
|
redirect = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
return redirect, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
|
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
|
||||||
@ -459,18 +475,24 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
|
||||||
p.ClearCookie(rw, req)
|
p.ClearSessionCookie(rw, req)
|
||||||
http.Redirect(rw, req, "/", 302)
|
http.Redirect(rw, req, "/", 302)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
nonce, err := cookie.Nonce()
|
||||||
|
if err != nil {
|
||||||
|
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.SetCSRFCookie(rw, req, nonce)
|
||||||
redirect, err := p.GetRedirect(req)
|
redirect, err := p.GetRedirect(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
p.ErrorPage(rw, 500, "Internal Error", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirectURI := p.GetRedirectURI(req.Host)
|
redirectURI := p.GetRedirectURI(req.Host)
|
||||||
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
|
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), 302)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||||
@ -495,7 +517,25 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
redirect := req.Form.Get("state")
|
s := strings.SplitN(req.Form.Get("state"), ":", 2)
|
||||||
|
if len(s) != 2 {
|
||||||
|
p.ErrorPage(rw, 500, "Internal Error", "Invalid State")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonce := s[0]
|
||||||
|
redirect := s[1]
|
||||||
|
c, err := req.Cookie(p.CSRFCookieName)
|
||||||
|
if err != nil {
|
||||||
|
p.ErrorPage(rw, 403, "Permission Denied", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.ClearCSRFCookie(rw, req)
|
||||||
|
if c.Value != nonce {
|
||||||
|
log.Printf("%s csrf token mismatch, potential attack", remoteAddr)
|
||||||
|
p.ErrorPage(rw, 403, "Permission Denied", "csrf failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
||||||
redirect = "/"
|
redirect = "/"
|
||||||
}
|
}
|
||||||
@ -595,7 +635,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
|||||||
}
|
}
|
||||||
|
|
||||||
if clearSession {
|
if clearSession {
|
||||||
p.ClearCookie(rw, req)
|
p.ClearSessionCookie(rw, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
if session == nil {
|
if session == nil {
|
||||||
|
@ -170,10 +170,14 @@ func TestBasicAuthPassword(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
|
req, _ := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
||||||
strings.NewReader(""))
|
strings.NewReader(""))
|
||||||
|
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
||||||
proxy.ServeHTTP(rw, req)
|
proxy.ServeHTTP(rw, req)
|
||||||
cookie := rw.HeaderMap["Set-Cookie"][0]
|
if rw.Code >= 400 {
|
||||||
|
t.Fatalf("expected 3xx got %d", rw.Code)
|
||||||
|
}
|
||||||
|
cookie := rw.HeaderMap["Set-Cookie"][1]
|
||||||
|
|
||||||
cookieName := proxy.CookieName
|
cookieName := proxy.CookieName
|
||||||
var value string
|
var value string
|
||||||
@ -196,9 +200,11 @@ func TestBasicAuthPassword(t *testing.T) {
|
|||||||
Expires: time.Now().Add(time.Duration(24)),
|
Expires: time.Now().Add(time.Duration(24)),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
})
|
})
|
||||||
|
req.AddCookie(proxy.MakeCSRFCookie(req, "nonce", proxy.CookieExpire, time.Now()))
|
||||||
|
|
||||||
rw = httptest.NewRecorder()
|
rw = httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(rw, req)
|
proxy.ServeHTTP(rw, req)
|
||||||
|
|
||||||
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
|
expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword))
|
||||||
assert.Equal(t, expectedHeader, rw.Body.String())
|
assert.Equal(t, expectedHeader, rw.Body.String())
|
||||||
provider_server.Close()
|
provider_server.Close()
|
||||||
@ -263,13 +269,14 @@ func (pat_test *PassAccessTokenTest) Close() {
|
|||||||
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int,
|
||||||
cookie string) {
|
cookie string) {
|
||||||
rw := httptest.NewRecorder()
|
rw := httptest.NewRecorder()
|
||||||
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
|
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:",
|
||||||
strings.NewReader(""))
|
strings.NewReader(""))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, ""
|
return 0, ""
|
||||||
}
|
}
|
||||||
|
req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now()))
|
||||||
pat_test.proxy.ServeHTTP(rw, req)
|
pat_test.proxy.ServeHTTP(rw, req)
|
||||||
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
|
return rw.Code, rw.HeaderMap["Set-Cookie"][1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
|
func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) {
|
||||||
@ -314,14 +321,18 @@ func TestForwardAccessTokenUpstream(t *testing.T) {
|
|||||||
|
|
||||||
// A successful validation will redirect and set the auth cookie.
|
// A successful validation will redirect and set the auth cookie.
|
||||||
code, cookie := pat_test.getCallbackEndpoint()
|
code, cookie := pat_test.getCallbackEndpoint()
|
||||||
assert.Equal(t, 302, code)
|
if code != 302 {
|
||||||
|
t.Fatalf("expected 302; got %d", code)
|
||||||
|
}
|
||||||
assert.NotEqual(t, nil, cookie)
|
assert.NotEqual(t, nil, cookie)
|
||||||
|
|
||||||
// Now we make a regular request; the access_token from the cookie is
|
// Now we make a regular request; the access_token from the cookie is
|
||||||
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
// forwarded as the "X-Forwarded-Access-Token" header. The token is
|
||||||
// read by the test provider server and written in the response body.
|
// read by the test provider server and written in the response body.
|
||||||
code, payload := pat_test.getRootEndpoint(cookie)
|
code, payload := pat_test.getRootEndpoint(cookie)
|
||||||
assert.Equal(t, 200, code)
|
if code != 200 {
|
||||||
|
t.Fatalf("expected 200; got %d", code)
|
||||||
|
}
|
||||||
assert.Equal(t, "my_auth_token", payload)
|
assert.Equal(t, "my_auth_token", payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,13 +344,17 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
|
|||||||
|
|
||||||
// A successful validation will redirect and set the auth cookie.
|
// A successful validation will redirect and set the auth cookie.
|
||||||
code, cookie := pat_test.getCallbackEndpoint()
|
code, cookie := pat_test.getCallbackEndpoint()
|
||||||
assert.Equal(t, 302, code)
|
if code != 302 {
|
||||||
|
t.Fatalf("expected 302; got %d", code)
|
||||||
|
}
|
||||||
assert.NotEqual(t, nil, cookie)
|
assert.NotEqual(t, nil, cookie)
|
||||||
|
|
||||||
// Now we make a regular request, but the access token header should
|
// Now we make a regular request, but the access token header should
|
||||||
// not be present.
|
// not be present.
|
||||||
code, payload := pat_test.getRootEndpoint(cookie)
|
code, payload := pat_test.getRootEndpoint(cookie)
|
||||||
assert.Equal(t, 200, code)
|
if code != 200 {
|
||||||
|
t.Fatalf("expected 200; got %d", code)
|
||||||
|
}
|
||||||
assert.Equal(t, "No access token found.", payload)
|
assert.Equal(t, "No access token found.", payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -457,7 +472,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
|
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
|
||||||
return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
|
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
|
||||||
@ -465,7 +480,7 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
|
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -697,7 +712,7 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
cookie := proxy.MakeCookie(req, value, proxy.CookieExpire, time.Now())
|
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
|
||||||
req.AddCookie(cookie)
|
req.AddCookie(cookie)
|
||||||
// This is used by the upstream to validate the signature.
|
// This is used by the upstream to validate the signature.
|
||||||
st.authenticator.auth = hmacauth.NewHmacAuth(
|
st.authenticator.auth = hmacauth.NewHmacAuth(
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/bitly/oauth2_proxy/cookie"
|
"github.com/bitly/oauth2_proxy/cookie"
|
||||||
)
|
)
|
||||||
@ -79,7 +78,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetLoginURL with typical oauth parameters
|
// GetLoginURL with typical oauth parameters
|
||||||
func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
|
func (p *ProviderData) GetLoginURL(redirectURI, state string) string {
|
||||||
var a url.URL
|
var a url.URL
|
||||||
a = *p.LoginURL
|
a = *p.LoginURL
|
||||||
params, _ := url.ParseQuery(a.RawQuery)
|
params, _ := url.ParseQuery(a.RawQuery)
|
||||||
@ -88,9 +87,7 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
|
|||||||
params.Add("scope", p.Scope)
|
params.Add("scope", p.Scope)
|
||||||
params.Set("client_id", p.ClientID)
|
params.Set("client_id", p.ClientID)
|
||||||
params.Set("response_type", "code")
|
params.Set("response_type", "code")
|
||||||
if strings.HasPrefix(finalRedirect, "/") && !strings.HasPrefix(finalRedirect,"//") {
|
params.Add("state", state)
|
||||||
params.Add("state", finalRedirect)
|
|
||||||
}
|
|
||||||
a.RawQuery = params.Encode()
|
a.RawQuery = params.Encode()
|
||||||
return a.String()
|
return a.String()
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user