Merge pull request #80 from 18F/pass-access-token

Pass the access token to the upstream server
This commit is contained in:
Jehiah Czebotar 2015-04-03 15:45:22 -04:00
commit 864d4787e9
7 changed files with 305 additions and 9 deletions

View File

@ -76,6 +76,7 @@ Usage of google_auth_proxy:
-htpasswd-file="": additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption -htpasswd-file="": additionally authenticate against a htpasswd file. Entries must be created with "htpasswd -s" for SHA encryption
-http-address="127.0.0.1:4180": [http://]<addr>:<port> or unix://<path> to listen on for HTTP clients -http-address="127.0.0.1:4180": [http://]<addr>:<port> or unix://<path> to listen on for HTTP clients
-login-url="": Authentication endpoint -login-url="": Authentication endpoint
-pass-access-token=false: pass OAuth access_token to upstream via X-Forwarded-Access-Token header
-pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream -pass-basic-auth=true: pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream
-pass-host-header=true: pass the request Host Header to upstream -pass-host-header=true: pass the request Host Header to upstream
-profile-url="": Profile access endpoint -profile-url="": Profile access endpoint

View File

@ -1,10 +1,14 @@
package main package main
import ( import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
@ -59,3 +63,37 @@ func checkHmac(input, expected string) bool {
} }
return false return false
} }
func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) {
ciphertext := make([]byte, aes.BlockSize+len(access_token))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", fmt.Errorf("failed to create access code initialization vector")
}
stream := cipher.NewCFBEncrypter(aes_cipher, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) {
encrypted_access_token, err := base64.StdEncoding.DecodeString(
encoded_access_token)
if err != nil {
return "", fmt.Errorf("failed to decode access token")
}
if len(encrypted_access_token) < aes.BlockSize {
return "", fmt.Errorf("encrypted access token should be "+
"at least %d bytes, but is only %d bytes",
aes.BlockSize, len(encrypted_access_token))
}
iv := encrypted_access_token[:aes.BlockSize]
encrypted_access_token = encrypted_access_token[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(aes_cipher, iv)
stream.XORKeyStream(encrypted_access_token, encrypted_access_token)
return string(encrypted_access_token), nil
}

23
cookies_test.go Normal file
View File

@ -0,0 +1,23 @@
package main
import (
"crypto/aes"
"github.com/bmizerany/assert"
"testing"
)
func TestEncodeAndDecodeAccessToken(t *testing.T) {
const key = "0123456789abcdefghijklmnopqrstuv"
const access_token = "my access token"
c, err := aes.NewCipher([]byte(key))
assert.Equal(t, nil, err)
encoded_token, err := encodeAccessToken(c, access_token)
assert.Equal(t, nil, err)
decoded_token, err := decodeAccessToken(c, encoded_token)
assert.Equal(t, nil, err)
assert.NotEqual(t, access_token, encoded_token)
assert.Equal(t, access_token, decoded_token)
}

View File

@ -30,6 +30,7 @@ func main() {
flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"") flagSet.String("redirect-url", "", "the OAuth Redirect URL. ie: \"https://internalapp.yourcompany.com/oauth2/callback\"")
flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path") flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint. If multiple, routing is based on path")
flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream")
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")

View File

@ -2,6 +2,8 @@ package main
import ( import (
"bytes" "bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@ -45,6 +47,8 @@ type OauthProxy struct {
DisplayHtpasswdForm bool DisplayHtpasswdForm bool
serveMux http.Handler serveMux http.Handler
PassBasicAuth bool PassBasicAuth bool
PassAccessToken bool
AesCipher cipher.Block
skipAuthRegex []string skipAuthRegex []string
compiledRegex []*regexp.Regexp compiledRegex []*regexp.Regexp
templates *template.Template templates *template.Template
@ -116,6 +120,29 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain) log.Printf("Cookie settings: secure (https):%v httponly:%v expiry:%s domain:%s", opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain)
var aes_cipher cipher.Block
if opts.PassAccessToken == true {
valid_cookie_secret_size := false
for _, i := range []int{16, 24, 32} {
if len(opts.CookieSecret) == i {
valid_cookie_secret_size = true
}
}
if valid_cookie_secret_size == false {
log.Fatal("cookie_secret must be 16, 24, or 32 bytes " +
"to create an AES cipher when " +
"pass_access_token == true")
}
var err error
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret))
if err != nil {
log.Fatal("error creating AES cipher with "+
"pass_access_token == true: %s", err)
}
}
return &OauthProxy{ return &OauthProxy{
CookieKey: "_oauthproxy", CookieKey: "_oauthproxy",
CookieSeed: opts.CookieSecret, CookieSeed: opts.CookieSecret,
@ -136,6 +163,8 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
skipAuthRegex: opts.SkipAuthRegex, skipAuthRegex: opts.SkipAuthRegex,
compiledRegex: opts.CompiledRegex, compiledRegex: opts.CompiledRegex,
PassBasicAuth: opts.PassBasicAuth, PassBasicAuth: opts.PassBasicAuth,
PassAccessToken: opts.PassAccessToken,
AesCipher: aes_cipher,
templates: loadTemplates(opts.CustomTemplatesDir), templates: loadTemplates(opts.CustomTemplatesDir),
} }
} }
@ -337,6 +366,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var ok bool var ok bool
var user string var user string
var email string var email string
var access_token string
if req.URL.Path == pingPath { if req.URL.Path == pingPath {
p.PingPage(rw) p.PingPage(rw)
@ -390,7 +420,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return return
} }
_, email, err := p.redeemCode(req.Host, req.Form.Get("code")) access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code"))
if err != nil { if err != nil {
log.Printf("%s error redeeming code %s", remoteAddr, err) log.Printf("%s error redeeming code %s", remoteAddr, err)
p.ErrorPage(rw, 500, "Internal Error", err.Error()) p.ErrorPage(rw, 500, "Internal Error", err.Error())
@ -405,7 +435,20 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// set cookie, or deny // set cookie, or deny
if p.Validator(email) { if p.Validator(email) {
log.Printf("%s authenticating %s completed", remoteAddr, email) log.Printf("%s authenticating %s completed", remoteAddr, email)
encoded_token := ""
if p.PassAccessToken {
encoded_token, err = encodeAccessToken(p.AesCipher, access_token)
if err != nil {
log.Printf("error encoding access token: %s", err)
}
}
access_token = ""
if encoded_token != "" {
p.SetCookie(rw, req, email+"|"+encoded_token)
} else {
p.SetCookie(rw, req, email) p.SetCookie(rw, req, email)
}
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, 302)
return return
} else { } else {
@ -417,7 +460,16 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if !ok { if !ok {
cookie, err := req.Cookie(p.CookieKey) cookie, err := req.Cookie(p.CookieKey)
if err == nil { if err == nil {
email, ok = validateCookie(cookie, p.CookieSeed) var value string
value, ok = validateCookie(cookie, p.CookieSeed)
components := strings.Split(value, "|")
email = components[0]
if len(components) == 2 {
access_token, err = decodeAccessToken(p.AesCipher, components[1])
if err != nil {
log.Printf("error decoding access token: %s", err)
}
}
user = strings.Split(email, "@")[0] user = strings.Split(email, "@")[0]
} }
} }
@ -437,6 +489,9 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req.Header["X-Forwarded-User"] = []string{user} req.Header["X-Forwarded-User"] = []string{user}
req.Header["X-Forwarded-Email"] = []string{email} req.Header["X-Forwarded-Email"] = []string{email}
} }
if access_token != "" {
req.Header["X-Forwarded-Access-Token"] = []string{access_token}
}
if email == "" { if email == "" {
rw.Header().Set("GAP-Auth", user) rw.Header().Set("GAP-Auth", user)
} else { } else {

View File

@ -1,12 +1,17 @@
package main package main
import ( import (
"github.com/bitly/go-simplejson"
"github.com/bitly/google_auth_proxy/providers"
"github.com/bmizerany/assert"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"strings"
"testing" "testing"
"time"
) )
func TestNewReverseProxy(t *testing.T) { func TestNewReverseProxy(t *testing.T) {
@ -18,8 +23,7 @@ func TestNewReverseProxy(t *testing.T) {
defer backend.Close() defer backend.Close()
backendURL, _ := url.Parse(backend.URL) backendURL, _ := url.Parse(backend.URL)
backendHostname := "upstream.127.0.0.1.xip.io" backendHostname, backendPort, _ := net.SplitHostPort(backendURL.Host)
_, backendPort, _ := net.SplitHostPort(backendURL.Host)
backendHost := net.JoinHostPort(backendHostname, backendPort) backendHost := net.JoinHostPort(backendHostname, backendPort)
proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/") proxyURL, _ := url.Parse(backendURL.Scheme + "://" + backendHost + "/")
@ -61,3 +65,175 @@ func TestEncodedSlashes(t *testing.T) {
t.Errorf("got bad request %q expected %q", seen, encodedPath) t.Errorf("got bad request %q expected %q", seen, encodedPath)
} }
} }
type TestProvider struct {
*providers.ProviderData
EmailAddress string
}
func (tp *TestProvider) GetEmailAddress(unused_auth_response *simplejson.Json,
unused_access_token string) (string, error) {
return tp.EmailAddress, nil
}
type PassAccessTokenTest struct {
provider_server *httptest.Server
proxy *OauthProxy
opts *Options
}
type PassAccessTokenTestOptions struct {
PassAccessToken bool
}
func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest {
t := &PassAccessTokenTest{}
t.provider_server = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := r.URL
payload := ""
switch url.Path {
case "/oauth/token":
payload = `{"access_token": "my_auth_token"}`
default:
token_header := r.Header["X-Forwarded-Access-Token"]
if len(token_header) != 0 {
payload = token_header[0]
} else {
payload = "No access token found."
}
}
w.WriteHeader(200)
w.Write([]byte(payload))
}))
t.opts = NewOptions()
t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL)
// The CookieSecret must be 32 bytes in order to create the AES
// cipher.
t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp"
t.opts.ClientID = "bazquux"
t.opts.ClientSecret = "foobar"
t.opts.CookieSecure = false
t.opts.PassAccessToken = opts.PassAccessToken
t.opts.Validate()
provider_url, _ := url.Parse(t.provider_server.URL)
const email_address = "michael.bland@gsa.gov"
t.opts.provider = &TestProvider{
ProviderData: &providers.ProviderData{
ProviderName: "Test Provider",
LoginUrl: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Path: "/oauth/authorize",
},
RedeemUrl: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Path: "/oauth/token",
},
ProfileUrl: &url.URL{
Scheme: "http",
Host: provider_url.Host,
Path: "/api/v1/profile",
},
Scope: "profile.email",
},
EmailAddress: email_address,
}
t.proxy = NewOauthProxy(t.opts, func(email string) bool {
return email == email_address
})
return t
}
func Close(t *PassAccessTokenTest) {
t.provider_server.Close()
}
func getCallbackEndpoint(pac_test *PassAccessTokenTest) (http_code int, cookie string) {
rw := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code",
strings.NewReader(""))
if err != nil {
return 0, ""
}
pac_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.HeaderMap["Set-Cookie"][0]
}
func getRootEndpoint(pac_test *PassAccessTokenTest, cookie string) (http_code int,
access_token string) {
cookie_key := pac_test.proxy.CookieKey
var value string
key_prefix := cookie_key + "="
for _, field := range strings.Split(cookie, "; ") {
value = strings.TrimPrefix(field, key_prefix)
if value != field {
break
} else {
value = ""
}
}
if value == "" {
return 0, ""
}
req, err := http.NewRequest("GET", "/", strings.NewReader(""))
if err != nil {
return 0, ""
}
req.AddCookie(&http.Cookie{
Name: cookie_key,
Value: value,
Path: "/",
Expires: time.Now().Add(time.Duration(24)),
HttpOnly: true,
})
rw := httptest.NewRecorder()
pac_test.proxy.ServeHTTP(rw, req)
return rw.Code, rw.Body.String()
}
func TestForwardAccessTokenUpstream(t *testing.T) {
pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: true,
})
defer Close(pac_test)
// A successful validation will redirect and set the auth cookie.
code, cookie := getCallbackEndpoint(pac_test)
assert.Equal(t, 302, code)
assert.NotEqual(t, nil, cookie)
// Now we make a regular request; the access_token from the cookie is
// forwarded as the "X-Forwarded-Access-Token" header. The token is
// read by the test provider server and written in the response body.
code, payload := getRootEndpoint(pac_test, cookie)
assert.Equal(t, 200, code)
assert.Equal(t, "my_auth_token", payload)
}
func TestDoNotForwardAccessTokenUpstream(t *testing.T) {
pac_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{
PassAccessToken: false,
})
defer Close(pac_test)
// A successful validation will redirect and set the auth cookie.
code, cookie := getCallbackEndpoint(pac_test)
assert.Equal(t, 302, code)
assert.NotEqual(t, nil, cookie)
// Now we make a regular request, but the access token header should
// not be present.
code, payload := getRootEndpoint(pac_test, cookie)
assert.Equal(t, 200, code)
assert.Equal(t, "No access token found.", payload)
}

View File

@ -33,6 +33,7 @@ type Options struct {
Upstreams []string `flag:"upstream" cfg:"upstreams"` Upstreams []string `flag:"upstream" cfg:"upstreams"`
SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"`
PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"` PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth"`
PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token"`
PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"` PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header"`
// These options allow for other providers besides Google, with // These options allow for other providers besides Google, with
@ -61,6 +62,7 @@ func NewOptions() *Options {
CookieHttpOnly: true, CookieHttpOnly: true,
CookieExpire: time.Duration(168) * time.Hour, CookieExpire: time.Duration(168) * time.Hour,
PassBasicAuth: true, PassBasicAuth: true,
PassAccessToken: false,
PassHostHeader: true, PassHostHeader: true,
RequestLogging: true, RequestLogging: true,
} }