diff --git a/README.md b/README.md index 83fc5fb..4a5656f 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ Usage of google_auth_proxy: -htpasswd-file="": additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption -http-address="127.0.0.1:4180": [http://]: or unix:// to listen on for HTTP clients -login-url="": Authentication endpoint + -pass-access-token=false: pass OAuth access_token to upstream via X-Forwarded-Access-Token header -pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream -pass-host-header=true: pass the request Host Header to upstream -profile-url="": Profile access endpoint diff --git a/cookies.go b/cookies.go index ffa0004..0ae6a92 100644 --- a/cookies.go +++ b/cookies.go @@ -1,10 +1,14 @@ package main import ( + "crypto/aes" + "crypto/cipher" "crypto/hmac" + "crypto/rand" "crypto/sha1" "encoding/base64" "fmt" + "io" "net/http" "strconv" "strings" @@ -59,3 +63,37 @@ func checkHmac(input, expected string) bool { } return false } + +func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) { + ciphertext := make([]byte, aes.BlockSize+len(access_token)) + iv := ciphertext[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return "", fmt.Errorf("failed to create access code initialization vector") + } + + stream := cipher.NewCFBEncrypter(aes_cipher, iv) + stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token)) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) { + encrypted_access_token, err := base64.StdEncoding.DecodeString( + encoded_access_token) + + if err != nil { + return "", fmt.Errorf("failed to decode access token") + } + + if len(encrypted_access_token) < aes.BlockSize { + return "", fmt.Errorf("encrypted access token should be "+ + "at least %d bytes, but is only %d bytes", + aes.BlockSize, len(encrypted_access_token)) + } + + iv := encrypted_access_token[:aes.BlockSize] + encrypted_access_token = encrypted_access_token[aes.BlockSize:] + stream := cipher.NewCFBDecrypter(aes_cipher, iv) + stream.XORKeyStream(encrypted_access_token, encrypted_access_token) + + return string(encrypted_access_token), nil +} diff --git a/cookies_test.go b/cookies_test.go new file mode 100644 index 0000000..d5470d0 --- /dev/null +++ b/cookies_test.go @@ -0,0 +1,23 @@ +package main + +import ( + "crypto/aes" + "github.com/bmizerany/assert" + "testing" +) + +func TestEncodeAndDecodeAccessToken(t *testing.T) { + const key = "0123456789abcdefghijklmnopqrstuv" + const access_token = "my access token" + c, err := aes.NewCipher([]byte(key)) + assert.Equal(t, nil, err) + + encoded_token, err := encodeAccessToken(c, access_token) + assert.Equal(t, nil, err) + + decoded_token, err := decodeAccessToken(c, encoded_token) + assert.Equal(t, nil, err) + + assert.NotEqual(t, access_token, encoded_token) + assert.Equal(t, access_token, decoded_token) +} diff --git a/main.go b/main.go index 5250dcd..d6c99da 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,7 @@ func main() { flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") + flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") diff --git a/oauthproxy.go b/oauthproxy.go index 06d8621..db4e2e7 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -2,6 +2,8 @@ package main import ( "bytes" + "crypto/aes" + "crypto/cipher" "encoding/base64" "errors" "fmt" @@ -45,6 +47,8 @@ type OauthProxy struct { DisplayHtpasswdForm bool serveMux http.Handler PassBasicAuth bool + PassAccessToken bool + AesCipher cipher.Block skipAuthRegex []string compiledRegex []*regexp.Regexp templates *template.Template @@ -116,6 +120,29 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) + var aes_cipher cipher.Block + + if opts.PassAccessToken == true { + valid_cookie_secret_size := false + for _, i := range []int{16, 24, 32} { + if len(opts.CookieSecret) == i { + valid_cookie_secret_size = true + } + } + if valid_cookie_secret_size == false { + log.Fatal("cookie_secret must be 16, 24, or 32 bytes " + + "to create an AES cipher when " + + "pass_access_token == true") + } + + var err error + aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) + if err != nil { + log.Fatal("error creating AES cipher with "+ + "pass_access_token == true: %s", err) + } + } + return &OauthProxy{ CookieKey: "_oauthproxy", CookieSeed: opts.CookieSecret, @@ -136,6 +163,8 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy { skipAuthRegex: opts.SkipAuthRegex, compiledRegex: opts.CompiledRegex, PassBasicAuth: opts.PassBasicAuth, + PassAccessToken: opts.PassAccessToken, + AesCipher: aes_cipher, templates: loadTemplates(opts.CustomTemplatesDir), } } @@ -337,6 +366,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { var ok bool var user string var email string + var access_token string if req.URL.Path == pingPath { p.PingPage(rw) @@ -390,7 +420,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - _, email, err := p.redeemCode(req.Host, req.Form.Get("code")) + access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code")) if err != nil { log.Printf("%s error redeeming code %s", remoteAddr, err) p.ErrorPage(rw, 500, "Internal Error", err.Error()) @@ -405,7 +435,20 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // set cookie, or deny if p.Validator(email) { log.Printf("%s authenticating %s completed", remoteAddr, email) - p.SetCookie(rw, req, email) + encoded_token := "" + if p.PassAccessToken { + encoded_token, err = encodeAccessToken(p.AesCipher, access_token) + if err != nil { + log.Printf("error encoding access token: %s", err) + } + } + access_token = "" + + if encoded_token != "" { + p.SetCookie(rw, req, email+"|"+encoded_token) + } else { + p.SetCookie(rw, req, email) + } http.Redirect(rw, req, redirect, 302) return } else { @@ -417,7 +460,16 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if !ok { cookie, err := req.Cookie(p.CookieKey) if err == nil { - email, ok = validateCookie(cookie, p.CookieSeed) + var value string + value, ok = validateCookie(cookie, p.CookieSeed) + components := strings.Split(value, "|") + email = components[0] + if len(components) == 2 { + access_token, err = decodeAccessToken(p.AesCipher, components[1]) + if err != nil { + log.Printf("error decoding access token: %s", err) + } + } user = strings.Split(email, "@")[0] } } @@ -437,6 +489,9 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { req.Header["X-Forwarded-User"] = []string{user} req.Header["X-Forwarded-Email"] = []string{email} } + if access_token != "" { + req.Header["X-Forwarded-Access-Token"] = []string{access_token} + } if email == "" { rw.Header().Set("GAP-Auth", user) } else { diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 2a89dbe..bae7e6b 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -1,12 +1,17 @@ package main import ( + "github.com/bitly/go-simplejson" + "github.com/bitly/google_auth_proxy/providers" + "github.com/bmizerany/assert" "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" ) func TestNewReverseProxy(t *testing.T) { @@ -18,8 +23,7 @@ func TestNewReverseProxy(t *testing.T) { defer backend.Close() backendURL, _ := url.Parse(backend.URL) - backendHostname := "upstream.127.0.0.1.xip.io" - _, backendPort, _ := net.SplitHostPort(backendURL.Host) + backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host) backendHost := net.JoinHostPort(backendHostname, backendPort) proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") @@ -61,3 +65,175 @@ func TestEncodedSlashes(t *testing.T) { t.Errorf("got bad request %q expected %q", seen, encodedPath) } } + +type TestProvider struct { + *providers.ProviderData + EmailAddress string +} + +func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json, + unused_access_token string) (string, error) { + return tp.EmailAddress, nil +} + +type PassAccessTokenTest struct { + provider_server *httptest.Server + proxy *OauthProxy + opts *Options +} + +type PassAccessTokenTestOptions struct { + PassAccessToken bool +} + +func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { + t := &PassAccessTokenTest{} + + t.provider_server = httptest.NewServer( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + url := r.URL + payload := "" + switch url.Path { + case "/oauth/token": + payload = `{"access_token": "my_auth_token"}` + default: + token_header := r.Header["X-Forwarded-Access-Token"] + if len(token_header) != 0 { + payload = token_header[0] + } else { + payload = "No access token found." + } + } + w.WriteHeader(200) + w.Write([]byte(payload)) + })) + + t.opts = NewOptions() + t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) + // The CookieSecret must be 32 bytes in order to create the AES + // cipher. + t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" + t.opts.ClientID = "bazquux" + t.opts.ClientSecret = "foobar" + t.opts.CookieSecure = false + t.opts.PassAccessToken = opts.PassAccessToken + t.opts.Validate() + + provider_url, _ := url.Parse(t.provider_server.URL) + const email_address = "michael.bland@gsa.gov" + + t.opts.provider = &TestProvider{ + ProviderData: &providers.ProviderData{ + ProviderName: "Test Provider", + LoginUrl: &url.URL{ + Scheme: "http", + Host: provider_url.Host, + Path: "/oauth/authorize", + }, + RedeemUrl: &url.URL{ + Scheme: "http", + Host: provider_url.Host, + Path: "/oauth/token", + }, + ProfileUrl: &url.URL{ + Scheme: "http", + Host: provider_url.Host, + Path: "/api/v1/profile", + }, + Scope: "profile.email", + }, + EmailAddress: email_address, + } + + t.proxy = NewOauthProxy(t.opts, func(email string) bool { + return email == email_address + }) + return t +} + +func Close(t *PassAccessTokenTest) { + t.provider_server.Close() +} + +func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) { + rw := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code", + strings.NewReader("")) + if err != nil { + return 0, "" + } + pac_test.proxy.ServeHTTP(rw, req) + return rw.Code, rw.HeaderMap["Set-Cookie"][0] +} + +func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int, + access_token string) { + cookie_key := pac_test.proxy.CookieKey + var value string + key_prefix := cookie_key + "=" + + for _, field := range strings.Split(cookie, "; ") { + value = strings.TrimPrefix(field, key_prefix) + if value != field { + break + } else { + value = "" + } + } + if value == "" { + return 0, "" + } + + req, err := http.NewRequest("GET", "/", strings.NewReader("")) + if err != nil { + return 0, "" + } + req.AddCookie(&http.Cookie{ + Name: cookie_key, + Value: value, + Path: "/", + Expires: time.Now().Add(time.Duration(24)), + HttpOnly: true, + }) + + rw := httptest.NewRecorder() + pac_test.proxy.ServeHTTP(rw, req) + return rw.Code, rw.Body.String() +} + +func TestForwardAccessTokenUpstream(t *testing.T) { + pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + PassAccessToken: true, + }) + defer Close(pac_test) + + // A successful validation will redirect and set the auth cookie. + code, cookie := getCallbackEndpoint(pac_test) + assert.Equal(t, 302, code) + assert.NotEqual(t, nil, cookie) + + // Now we make a regular request; the access_token from the cookie is + // forwarded as the "X-Forwarded-Access-Token" header. The token is + // read by the test provider server and written in the response body. + code, payload := getRootEndpoint(pac_test, cookie) + assert.Equal(t, 200, code) + assert.Equal(t, "my_auth_token", payload) +} + +func TestDoNotForwardAccessTokenUpstream(t *testing.T) { + pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + PassAccessToken: false, + }) + defer Close(pac_test) + + // A successful validation will redirect and set the auth cookie. + code, cookie := getCallbackEndpoint(pac_test) + assert.Equal(t, 302, code) + assert.NotEqual(t, nil, cookie) + + // Now we make a regular request, but the access token header should + // not be present. + code, payload := getRootEndpoint(pac_test, cookie) + assert.Equal(t, 200, code) + assert.Equal(t, "No access token found.", payload) +} diff --git a/options.go b/options.go index 85b4c01..e02cfb8 100644 --- a/options.go +++ b/options.go @@ -30,10 +30,11 @@ type Options struct { CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` - Upstreams []string `flag:"upstream" cfg:"upstreams"` - SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` - PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` - PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` + Upstreams []string `flag:"upstream" cfg:"upstreams"` + SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` + PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` + PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"` + PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` // These options allow for other providers besides Google, with // potential overrides. @@ -61,6 +62,7 @@ func NewOptions() *Options { CookieHttpOnly: true, CookieExpire: time.Duration(168) * time.Hour, PassBasicAuth: true, + PassAccessToken: false, PassHostHeader: true, RequestLogging: true, }