Compare commits
11 Commits
master
...
oidc-refre
Author | SHA1 | Date | |
---|---|---|---|
|
acdd66c70c | ||
|
2bdc656590 | ||
|
bcd5ac513c | ||
|
7b7cc8fdc4 | ||
|
c1e1f38621 | ||
|
b19f87a884 | ||
|
dace5cde18 | ||
|
3940d7e1cd | ||
|
543575a7ad | ||
|
b31369d71d | ||
|
2e75a863be |
10
Dockerfile
Normal file
10
Dockerfile
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
FROM golang:1.9 AS builder
|
||||||
|
WORKDIR /go/src/github.com/bitly/oauth2_proxy
|
||||||
|
COPY . .
|
||||||
|
RUN go get -d -v; \
|
||||||
|
CGO_ENABLED=0 GOOS=linux go build
|
||||||
|
|
||||||
|
FROM scratch
|
||||||
|
COPY --from=builder /go/src/github.com/bitly/oauth2_proxy/oauth2_proxy /bin/oauth2_proxy
|
||||||
|
|
||||||
|
ENTRYPOINT ["/bin/oauth2_proxy"]
|
4
main.go
4
main.go
@ -18,6 +18,7 @@ func main() {
|
|||||||
flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError)
|
flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError)
|
||||||
|
|
||||||
emailDomains := StringArray{}
|
emailDomains := StringArray{}
|
||||||
|
whitelistDomains := StringArray{}
|
||||||
upstreams := StringArray{}
|
upstreams := StringArray{}
|
||||||
skipAuthRegex := StringArray{}
|
skipAuthRegex := StringArray{}
|
||||||
googleGroups := StringArray{}
|
googleGroups := StringArray{}
|
||||||
@ -37,12 +38,15 @@ func main() {
|
|||||||
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
|
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
|
||||||
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
|
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
|
||||||
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
|
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
|
||||||
|
flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream")
|
||||||
|
flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)")
|
||||||
flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
|
flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)")
|
||||||
flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start")
|
flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start")
|
||||||
flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests")
|
flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests")
|
||||||
flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS")
|
flagSet.Bool("ssl-insecure-skip-verify", false, "skip validation of certificates presented when using HTTPS")
|
||||||
|
|
||||||
flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
|
flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
|
||||||
|
flagSet.Var(&whitelistDomains, "whitelist-domain", "allowed domains for redirection after authentication")
|
||||||
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
|
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
|
||||||
flagSet.String("github-org", "", "restrict logins to members of this organisation")
|
flagSet.String("github-org", "", "restrict logins to members of this organisation")
|
||||||
flagSet.String("github-team", "", "restrict logins to members of this team")
|
flagSet.String("github-team", "", "restrict logins to members of this team")
|
||||||
|
142
oauthproxy.go
142
oauthproxy.go
@ -54,6 +54,7 @@ type OAuthProxy struct {
|
|||||||
AuthOnlyPath string
|
AuthOnlyPath string
|
||||||
|
|
||||||
redirectURL *url.URL // the url to receive requests at
|
redirectURL *url.URL // the url to receive requests at
|
||||||
|
whitelistDomains []string
|
||||||
provider providers.Provider
|
provider providers.Provider
|
||||||
ProxyPrefix string
|
ProxyPrefix string
|
||||||
SignInMessage string
|
SignInMessage string
|
||||||
@ -66,6 +67,8 @@ type OAuthProxy struct {
|
|||||||
PassUserHeaders bool
|
PassUserHeaders bool
|
||||||
BasicAuthPassword string
|
BasicAuthPassword string
|
||||||
PassAccessToken bool
|
PassAccessToken bool
|
||||||
|
SetAuthorization bool
|
||||||
|
PassAuthorization bool
|
||||||
CookieCipher *cookie.Cipher
|
CookieCipher *cookie.Cipher
|
||||||
skipAuthRegex []string
|
skipAuthRegex []string
|
||||||
skipAuthPreflight bool
|
skipAuthPreflight bool
|
||||||
@ -163,7 +166,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh)
|
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh)
|
||||||
|
|
||||||
var cipher *cookie.Cipher
|
var cipher *cookie.Cipher
|
||||||
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
|
if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) {
|
||||||
var err error
|
var err error
|
||||||
cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
|
cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -194,6 +197,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
provider: opts.provider,
|
provider: opts.provider,
|
||||||
serveMux: serveMux,
|
serveMux: serveMux,
|
||||||
redirectURL: redirectURL,
|
redirectURL: redirectURL,
|
||||||
|
whitelistDomains: opts.WhitelistDomains,
|
||||||
skipAuthRegex: opts.SkipAuthRegex,
|
skipAuthRegex: opts.SkipAuthRegex,
|
||||||
skipAuthPreflight: opts.SkipAuthPreflight,
|
skipAuthPreflight: opts.SkipAuthPreflight,
|
||||||
compiledRegex: opts.CompiledRegex,
|
compiledRegex: opts.CompiledRegex,
|
||||||
@ -202,6 +206,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
|||||||
PassUserHeaders: opts.PassUserHeaders,
|
PassUserHeaders: opts.PassUserHeaders,
|
||||||
BasicAuthPassword: opts.BasicAuthPassword,
|
BasicAuthPassword: opts.BasicAuthPassword,
|
||||||
PassAccessToken: opts.PassAccessToken,
|
PassAccessToken: opts.PassAccessToken,
|
||||||
|
SetAuthorization: opts.SetAuthorization,
|
||||||
|
PassAuthorization: opts.PassAuthorization,
|
||||||
SkipProviderButton: opts.SkipProviderButton,
|
SkipProviderButton: opts.SkipProviderButton,
|
||||||
CookieCipher: cipher,
|
CookieCipher: cipher,
|
||||||
templates: loadTemplates(opts.CustomTemplatesDir),
|
templates: loadTemplates(opts.CustomTemplatesDir),
|
||||||
@ -254,15 +260,92 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie {
|
||||||
if value != "" {
|
if value != "" {
|
||||||
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
|
||||||
if len(value) > 4096 {
|
}
|
||||||
// Cookies cannot be larger than 4kb
|
c := p.makeCookie(req, p.CookieName, value, expiration, now)
|
||||||
log.Printf("WARNING - Cookie Size: %d bytes", len(value))
|
if len(c.Value) > 4096 {
|
||||||
|
return splitCookie(c)
|
||||||
|
}
|
||||||
|
return []*http.Cookie{c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyCookie(c *http.Cookie) *http.Cookie {
|
||||||
|
return &http.Cookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
Expires: c.Expires,
|
||||||
|
RawExpires: c.RawExpires,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HttpOnly: c.HttpOnly,
|
||||||
|
Raw: c.Raw,
|
||||||
|
Unparsed: c.Unparsed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitCookie(c *http.Cookie) []*http.Cookie {
|
||||||
|
if len(c.Value) < 3840 {
|
||||||
|
return []*http.Cookie{c}
|
||||||
|
}
|
||||||
|
cookies := []*http.Cookie{}
|
||||||
|
valueBytes := []byte(c.Value)
|
||||||
|
count := 0
|
||||||
|
for len(valueBytes) > 0 {
|
||||||
|
new := copyCookie(c)
|
||||||
|
new.Name = fmt.Sprintf("%s-%d", c.Name, count)
|
||||||
|
count++
|
||||||
|
if len(valueBytes) < 3840 {
|
||||||
|
new.Value = string(valueBytes)
|
||||||
|
valueBytes = []byte{}
|
||||||
|
} else {
|
||||||
|
newValue := valueBytes[:3840]
|
||||||
|
valueBytes = valueBytes[3840:]
|
||||||
|
new.Value = string(newValue)
|
||||||
|
}
|
||||||
|
cookies = append(cookies, new)
|
||||||
|
}
|
||||||
|
return cookies
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) {
|
||||||
|
if len(cookies) == 0 {
|
||||||
|
return nil, fmt.Errorf("Could not load cookie.")
|
||||||
|
}
|
||||||
|
if len(cookies) == 1 {
|
||||||
|
return cookies[0], nil
|
||||||
|
}
|
||||||
|
c := copyCookie(cookies[0])
|
||||||
|
for i := 1; i < len(cookies); i++ {
|
||||||
|
c.Value += cookies[i].Value
|
||||||
|
}
|
||||||
|
c.Name = strings.TrimRight(c.Name, "-0")
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
|
||||||
|
c, err := req.Cookie(cookieName)
|
||||||
|
if err == nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
cookies := []*http.Cookie{}
|
||||||
|
err = nil
|
||||||
|
count := 0
|
||||||
|
for err == nil {
|
||||||
|
var c *http.Cookie
|
||||||
|
c, err = req.Cookie(fmt.Sprintf("%s-%d", cookieName, count))
|
||||||
|
if err == nil {
|
||||||
|
cookies = append(cookies, c)
|
||||||
|
count++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return p.makeCookie(req, p.CookieName, value, expiration, now)
|
if len(cookies) == 0 {
|
||||||
|
return nil, fmt.Errorf("Could not find cookie %s", cookieName)
|
||||||
|
}
|
||||||
|
return joinCookies(cookies)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
|
||||||
@ -292,6 +375,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
|
||||||
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,24 +384,28 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
|
func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) {
|
||||||
clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
|
cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now())
|
||||||
http.SetCookie(rw, clr)
|
for _, clr := range cookies {
|
||||||
|
http.SetCookie(rw, clr)
|
||||||
|
}
|
||||||
|
|
||||||
// ugly hack because default domain changed
|
// ugly hack because default domain changed
|
||||||
if p.CookieDomain == "" {
|
if p.CookieDomain == "" && len(cookies) > 0 {
|
||||||
clr2 := *clr
|
clr2 := *cookies[0]
|
||||||
clr2.Domain = req.Host
|
clr2.Domain = req.Host
|
||||||
http.SetCookie(rw, &clr2)
|
http.SetCookie(rw, &clr2)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) {
|
||||||
http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()))
|
for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) {
|
||||||
|
http.SetCookie(rw, c)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
|
||||||
var age time.Duration
|
var age time.Duration
|
||||||
c, err := req.Cookie(p.CookieName)
|
c, err := loadCookie(req, p.CookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// always http.ErrNoCookie
|
// always http.ErrNoCookie
|
||||||
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
|
return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
|
||||||
@ -426,13 +514,33 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
redirect = req.Form.Get("rd")
|
redirect = req.Form.Get("rd")
|
||||||
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
if !p.IsValidRedirect(redirect) {
|
||||||
redirect = "/"
|
redirect = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//"):
|
||||||
|
return true
|
||||||
|
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
||||||
|
url, err := url.Parse(redirect)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, domain := range p.whitelistDomains {
|
||||||
|
if (url.Host == domain) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(url.Host, domain)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) {
|
func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) (ok bool) {
|
||||||
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
||||||
return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path)
|
return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path)
|
||||||
@ -562,7 +670,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
|
if !p.IsValidRedirect(redirect) {
|
||||||
redirect = "/"
|
redirect = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -698,6 +806,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
|
|||||||
if p.PassAccessToken && session.AccessToken != "" {
|
if p.PassAccessToken && session.AccessToken != "" {
|
||||||
req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
|
req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
|
||||||
}
|
}
|
||||||
|
if p.PassAuthorization && session.IdToken != "" {
|
||||||
|
req.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", session.IdToken)}
|
||||||
|
}
|
||||||
|
if p.SetAuthorization && session.IdToken != "" {
|
||||||
|
rw.Header().Set("Authorization", fmt.Sprintf("Bearer %s", session.IdToken))
|
||||||
|
}
|
||||||
if session.Email == "" {
|
if session.Email == "" {
|
||||||
rw.Header().Set("GAP-Auth", session.User)
|
rw.Header().Set("GAP-Auth", session.User)
|
||||||
} else {
|
} else {
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@ -92,6 +93,124 @@ func TestRobotsTxt(t *testing.T) {
|
|||||||
assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
|
assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestIsValidRedirect(t *testing.T) {
|
||||||
|
opts := NewOptions()
|
||||||
|
opts.ClientID = "bazquux"
|
||||||
|
opts.ClientSecret = "foobar"
|
||||||
|
opts.CookieSecret = "xyzzyplugh"
|
||||||
|
// Should match domains that are exactly foo.bar and any subdomain of bar.foo
|
||||||
|
opts.WhitelistDomains = []string{"foo.bar", ".bar.foo"}
|
||||||
|
opts.Validate()
|
||||||
|
|
||||||
|
proxy := NewOAuthProxy(opts, func(string) bool { return true })
|
||||||
|
|
||||||
|
noRD := proxy.IsValidRedirect("")
|
||||||
|
assert.Equal(t, false, noRD)
|
||||||
|
|
||||||
|
singleSlash := proxy.IsValidRedirect("/redirect")
|
||||||
|
assert.Equal(t, true, singleSlash)
|
||||||
|
|
||||||
|
doubleSlash := proxy.IsValidRedirect("//redirect")
|
||||||
|
assert.Equal(t, false, doubleSlash)
|
||||||
|
|
||||||
|
validHttp := proxy.IsValidRedirect("http://foo.bar/redirect")
|
||||||
|
assert.Equal(t, true, validHttp)
|
||||||
|
|
||||||
|
validHttps := proxy.IsValidRedirect("https://foo.bar/redirect")
|
||||||
|
assert.Equal(t, true, validHttps)
|
||||||
|
|
||||||
|
invalidHttpSubdomain := proxy.IsValidRedirect("http://baz.foo.bar/redirect")
|
||||||
|
assert.Equal(t, false, invalidHttpSubdomain)
|
||||||
|
|
||||||
|
invalidHttpsSubdomain := proxy.IsValidRedirect("https://baz.foo.bar/redirect")
|
||||||
|
assert.Equal(t, false, invalidHttpsSubdomain)
|
||||||
|
|
||||||
|
validHttpSubdomain := proxy.IsValidRedirect("http://baz.bar.foo/redirect")
|
||||||
|
assert.Equal(t, true, validHttpSubdomain)
|
||||||
|
|
||||||
|
validHttpsSubdomain := proxy.IsValidRedirect("https://baz.bar.foo/redirect")
|
||||||
|
assert.Equal(t, true, validHttpsSubdomain)
|
||||||
|
|
||||||
|
invalidHttp1 := proxy.IsValidRedirect("http://foo.bar.evil.corp/redirect")
|
||||||
|
assert.Equal(t, false, invalidHttp1)
|
||||||
|
|
||||||
|
invalidHttps1 := proxy.IsValidRedirect("https://foo.bar.evil.corp/redirect")
|
||||||
|
assert.Equal(t, false, invalidHttps1)
|
||||||
|
|
||||||
|
invalidHttp2 := proxy.IsValidRedirect("http://evil.corp/redirect?rd=foo.bar")
|
||||||
|
assert.Equal(t, false, invalidHttp2)
|
||||||
|
|
||||||
|
invalidHttps2 := proxy.IsValidRedirect("https://evil.corp/redirect?rd=foo.bar")
|
||||||
|
assert.Equal(t, false, invalidHttps2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomString(length int) string {
|
||||||
|
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
b := make([]byte, length)
|
||||||
|
for i := range b {
|
||||||
|
b[i] = charset[seededRand.Intn(len(charset))]
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitCookie(t *testing.T) {
|
||||||
|
c1 := &http.Cookie{
|
||||||
|
Name: "cookie-name",
|
||||||
|
Value: randomString(5120),
|
||||||
|
Path: "/",
|
||||||
|
Domain: "foo.bar",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: time.Now(),
|
||||||
|
}
|
||||||
|
cookies := splitCookie(c1)
|
||||||
|
assert.Equal(t, 2, len(cookies))
|
||||||
|
|
||||||
|
assert.Equal(t, c1.Name+"-0", cookies[0].Name)
|
||||||
|
assert.Equal(t, c1.Name+"-1", cookies[1].Name)
|
||||||
|
|
||||||
|
assert.Equal(t, 3840, len(cookies[0].Value))
|
||||||
|
assert.Equal(t, 5120-3840, len(cookies[1].Value))
|
||||||
|
|
||||||
|
c2 := &http.Cookie{
|
||||||
|
Name: "cookie-name",
|
||||||
|
Value: randomString(3000),
|
||||||
|
Path: "/",
|
||||||
|
Domain: "foo.bar",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies2 := splitCookie(c2)
|
||||||
|
assert.Equal(t, 1, len(cookies2))
|
||||||
|
|
||||||
|
assert.Equal(t, c2.Name, cookies2[0].Name)
|
||||||
|
assert.Equal(t, c2.Value, cookies2[0].Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinCookies(t *testing.T) {
|
||||||
|
c1 := &http.Cookie{
|
||||||
|
Name: "cookie-name",
|
||||||
|
Value: randomString(5120),
|
||||||
|
Path: "/",
|
||||||
|
Domain: "foo.bar",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
Expires: time.Now(),
|
||||||
|
}
|
||||||
|
// Split Cookies
|
||||||
|
cookies := splitCookie(c1)
|
||||||
|
assert.Equal(t, 2, len(cookies))
|
||||||
|
|
||||||
|
// join cookies should be the ivnerse
|
||||||
|
c2, _ := joinCookies(cookies)
|
||||||
|
|
||||||
|
assert.Equal(t, c1.Name, c2.Name)
|
||||||
|
assert.Equal(t, c1.Value, c2.Value)
|
||||||
|
}
|
||||||
|
|
||||||
type TestProvider struct {
|
type TestProvider struct {
|
||||||
*providers.ProviderData
|
*providers.ProviderData
|
||||||
EmailAddress string
|
EmailAddress string
|
||||||
@ -504,7 +623,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
|
func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie {
|
||||||
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -513,7 +632,9 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref))
|
for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) {
|
||||||
|
p.req.AddCookie(c)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -802,8 +923,9 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now())
|
for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) {
|
||||||
req.AddCookie(cookie)
|
req.AddCookie(c)
|
||||||
|
}
|
||||||
// This is used by the upstream to validate the signature.
|
// This is used by the upstream to validate the signature.
|
||||||
st.authenticator.auth = hmacauth.NewHmacAuth(
|
st.authenticator.auth = hmacauth.NewHmacAuth(
|
||||||
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
|
crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders)
|
||||||
|
@ -32,6 +32,7 @@ type Options struct {
|
|||||||
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"`
|
||||||
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
|
AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"`
|
||||||
EmailDomains []string `flag:"email-domain" cfg:"email_domains"`
|
EmailDomains []string `flag:"email-domain" cfg:"email_domains"`
|
||||||
|
WhitelistDomains []string `flag:"whitelist-domain" cfg:"whitelist_domains" env:"OAUTH2_PROXY_WHITELIST_DOMAINS"`
|
||||||
GitHubOrg string `flag:"github-org" cfg:"github_org"`
|
GitHubOrg string `flag:"github-org" cfg:"github_org"`
|
||||||
GitHubTeam string `flag:"github-team" cfg:"github_team"`
|
GitHubTeam string `flag:"github-team" cfg:"github_team"`
|
||||||
GoogleGroups []string `flag:"google-group" cfg:"google_group"`
|
GoogleGroups []string `flag:"google-group" cfg:"google_group"`
|
||||||
@ -60,6 +61,8 @@ type Options struct {
|
|||||||
PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"`
|
PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"`
|
||||||
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"`
|
SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"`
|
||||||
SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"`
|
SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"`
|
||||||
|
SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"`
|
||||||
|
PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"`
|
||||||
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"`
|
SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"`
|
||||||
|
|
||||||
// These options allow for other providers besides Google, with
|
// These options allow for other providers besides Google, with
|
||||||
@ -110,6 +113,8 @@ func NewOptions() *Options {
|
|||||||
PassUserHeaders: true,
|
PassUserHeaders: true,
|
||||||
PassAccessToken: false,
|
PassAccessToken: false,
|
||||||
PassHostHeader: true,
|
PassHostHeader: true,
|
||||||
|
SetAuthorization: false,
|
||||||
|
PassAuthorization: false,
|
||||||
ApprovalPrompt: "force",
|
ApprovalPrompt: "force",
|
||||||
RequestLogging: true,
|
RequestLogging: true,
|
||||||
RequestLoggingFormat: defaultRequestLoggingFormat,
|
RequestLoggingFormat: defaultRequestLoggingFormat,
|
||||||
|
@ -141,6 +141,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
|
|||||||
}
|
}
|
||||||
s = &SessionState{
|
s = &SessionState{
|
||||||
AccessToken: jsonResponse.AccessToken,
|
AccessToken: jsonResponse.AccessToken,
|
||||||
|
IdToken: jsonResponse.IdToken,
|
||||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
RefreshToken: jsonResponse.RefreshToken,
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
Email: email,
|
Email: email,
|
||||||
|
@ -35,7 +35,59 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("token exchange: %v", err)
|
return nil, fmt.Errorf("token exchange: %v", err)
|
||||||
}
|
}
|
||||||
|
s, err = p.createSessionState(token, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to update session: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
||||||
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
origExpiration := s.ExpiresOn
|
||||||
|
|
||||||
|
err := p.redeemRefreshToken(s)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("unable to redeem refresh token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("refreshed id token %s (expired on %s)\n", s, origExpiration)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) redeemRefreshToken(s *SessionState) (err error) {
|
||||||
|
c := oauth2.Config{
|
||||||
|
ClientID: p.ClientID,
|
||||||
|
ClientSecret: p.ClientSecret,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
TokenURL: p.RedeemURL.String(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
t := &oauth2.Token{
|
||||||
|
RefreshToken: s.RefreshToken,
|
||||||
|
Expiry: time.Now().Add(-time.Hour),
|
||||||
|
}
|
||||||
|
token, err := c.TokenSource(ctx, t).Token()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get token: %v", err)
|
||||||
|
}
|
||||||
|
newSession, err := p.createSessionState(token, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to update session: %v", err)
|
||||||
|
}
|
||||||
|
s.AccessToken = newSession.AccessToken
|
||||||
|
s.IdToken = newSession.IdToken
|
||||||
|
s.RefreshToken = newSession.RefreshToken
|
||||||
|
s.ExpiresOn = newSession.ExpiresOn
|
||||||
|
s.Email = newSession.Email
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OIDCProvider) createSessionState(token *oauth2.Token, ctx context.Context) (*SessionState, error) {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("token response did not contain an id_token")
|
return nil, fmt.Errorf("token response did not contain an id_token")
|
||||||
@ -63,23 +115,11 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
|
|||||||
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
return nil, fmt.Errorf("email in id_token (%s) isn't verified", claims.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
s = &SessionState{
|
return &SessionState{
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: token.AccessToken,
|
||||||
|
IdToken: rawIDToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: token.RefreshToken,
|
||||||
ExpiresOn: token.Expiry,
|
ExpiresOn: token.Expiry,
|
||||||
Email: claims.Email,
|
Email: claims.Email,
|
||||||
}
|
}, nil
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
|
|
||||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
origExpiration := s.ExpiresOn
|
|
||||||
s.ExpiresOn = time.Now().Add(time.Second).Truncate(time.Second)
|
|
||||||
fmt.Printf("refreshed access token %s (expired on %s)\n", s, origExpiration)
|
|
||||||
return false, nil
|
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
type SessionState struct {
|
type SessionState struct {
|
||||||
AccessToken string
|
AccessToken string
|
||||||
|
IdToken string
|
||||||
ExpiresOn time.Time
|
ExpiresOn time.Time
|
||||||
RefreshToken string
|
RefreshToken string
|
||||||
Email string
|
Email string
|
||||||
@ -29,6 +30,9 @@ func (s *SessionState) String() string {
|
|||||||
if s.AccessToken != "" {
|
if s.AccessToken != "" {
|
||||||
o += " token:true"
|
o += " token:true"
|
||||||
}
|
}
|
||||||
|
if s.IdToken != "" {
|
||||||
|
o += " id_token:true"
|
||||||
|
}
|
||||||
if !s.ExpiresOn.IsZero() {
|
if !s.ExpiresOn.IsZero() {
|
||||||
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
|
||||||
}
|
}
|
||||||
@ -60,13 +64,20 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
i := s.IdToken
|
||||||
|
if i != "" {
|
||||||
|
fmt.Printf("Encrytping ID Token")
|
||||||
|
if i, err = c.Encrypt(i); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
}
|
||||||
r := s.RefreshToken
|
r := s.RefreshToken
|
||||||
if r != "" {
|
if r != "" {
|
||||||
if r, err = c.Encrypt(r); err != nil {
|
if r, err = c.Encrypt(r); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil
|
return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
|
func decodeSessionStatePlain(v string) (s *SessionState, err error) {
|
||||||
@ -90,8 +101,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
|
|||||||
}
|
}
|
||||||
|
|
||||||
chunks := strings.Split(v, "|")
|
chunks := strings.Split(v, "|")
|
||||||
if len(chunks) != 4 {
|
if len(chunks) != 5 {
|
||||||
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
|
err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,11 +117,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ts, _ := strconv.Atoi(chunks[2])
|
if chunks[2] != "" {
|
||||||
|
if sessionState.IdToken, err = c.Decrypt(chunks[2]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, _ := strconv.Atoi(chunks[3])
|
||||||
sessionState.ExpiresOn = time.Unix(int64(ts), 0)
|
sessionState.ExpiresOn = time.Unix(int64(ts), 0)
|
||||||
|
|
||||||
if chunks[3] != "" {
|
if chunks[4] != "" {
|
||||||
if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil {
|
if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
s := &SessionState{
|
s := &SessionState{
|
||||||
Email: "user@domain.com",
|
Email: "user@domain.com",
|
||||||
AccessToken: "token1234",
|
AccessToken: "token1234",
|
||||||
|
IdToken: "rawtoken1234",
|
||||||
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
|
||||||
RefreshToken: "refresh4321",
|
RefreshToken: "refresh4321",
|
||||||
}
|
}
|
||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, 3, strings.Count(encoded, "|"))
|
assert.Equal(t, 4, strings.Count(encoded, "|"))
|
||||||
|
|
||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
@ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, "user", ss.User)
|
assert.Equal(t, "user", ss.User)
|
||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
assert.Equal(t, s.AccessToken, ss.AccessToken)
|
||||||
|
assert.Equal(t, s.IdToken, ss.IdToken)
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
|
||||||
|
|
||||||
@ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) {
|
|||||||
assert.Equal(t, s.Email, ss.Email)
|
assert.Equal(t, s.Email, ss.Email)
|
||||||
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
|
||||||
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
|
||||||
|
assert.NotEqual(t, s.IdToken, ss.IdToken)
|
||||||
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
|
|||||||
}
|
}
|
||||||
encoded, err := s.EncodeSessionState(c)
|
encoded, err := s.EncodeSessionState(c)
|
||||||
assert.Equal(t, nil, err)
|
assert.Equal(t, nil, err)
|
||||||
assert.Equal(t, 3, strings.Count(encoded, "|"))
|
assert.Equal(t, 4, strings.Count(encoded, "|"))
|
||||||
|
|
||||||
ss, err := DecodeSessionState(encoded, c)
|
ss, err := DecodeSessionState(encoded, c)
|
||||||
t.Logf("%#v", ss)
|
t.Logf("%#v", ss)
|
||||||
|
Loading…
Reference in New Issue
Block a user