SessionState refactoring; improve token renewal and cookie refresh

* New SessionState to consolidate email, access token and refresh token
* split ServeHttp into individual methods
* log on session renewal
* log on access token refresh
* refactor cookie encription/decription and session state serialization
This commit is contained in:
Jehiah Czebotar 2015-06-23 07:23:39 -04:00
parent b9ae5dc8d7
commit d49c3e167f
21 changed files with 883 additions and 597 deletions

View File

@ -3,6 +3,7 @@ package api
import ( import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"github.com/bitly/go-simplejson" "github.com/bitly/go-simplejson"
@ -11,10 +12,12 @@ import (
func Request(req *http.Request) (*simplejson.Json, error) { func Request(req *http.Request) (*simplejson.Json, error) {
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
log.Printf("%s %s %s", req.Method, req.URL, err)
return nil, err return nil, err
} }
body, err := ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
log.Printf("%d %s %s %s", resp.StatusCode, req.Method, req.URL, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }

128
cookie/cookies.go Normal file
View File

@ -0,0 +1,128 @@
package cookie
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
// cookies are stored in a 3 part (value + timestamp + signature) to enforce that the values are as originally set.
// additionally, the 'value' is encrypted so it's opaque to the browser
// Validate ensures a cookie is properly signed
func Validate(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
// value, timestamp, sig
parts := strings.Split(cookie.Value, "|")
if len(parts) != 3 {
return
}
sig := cookieSignature(seed, cookie.Name, parts[0], parts[1])
if checkHmac(parts[2], sig) {
ts, err := strconv.Atoi(parts[1])
if err != nil {
return
}
// The expiration timestamp set when the cookie was created
// isn't sent back by the browser. Hence, we check whether the
// creation timestamp stored in the cookie falls within the
// window defined by (Now()-expiration, Now()].
t = time.Unix(int64(ts), 0)
if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
// it's a valid cookie. now get the contents
rawValue, err := base64.URLEncoding.DecodeString(parts[0])
if err == nil {
value = string(rawValue)
ok = true
return
}
}
}
return
}
// SignedValue returns a cookie that is signed and can later be checked with Validate
func SignedValue(seed string, key string, value string, now time.Time) string {
encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
timeStr := fmt.Sprintf("%d", now.Unix())
sig := cookieSignature(seed, key, encodedValue, timeStr)
cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
return cookieVal
}
func cookieSignature(args ...string) string {
h := hmac.New(sha1.New, []byte(args[0]))
for _, arg := range args[1:] {
h.Write([]byte(arg))
}
var b []byte
b = h.Sum(b)
return base64.URLEncoding.EncodeToString(b)
}
func checkHmac(input, expected string) bool {
inputMAC, err1 := base64.URLEncoding.DecodeString(input)
if err1 == nil {
expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
if err2 == nil {
return hmac.Equal(inputMAC, expectedMAC)
}
}
return false
}
// Cipher provides methods to encrypt and decrypt cookie values
type Cipher struct {
cipher.Block
}
// NewCipher returns a new aes Cipher for encrypting cookie values
func NewCipher(secret string) (*Cipher, error) {
c, err := aes.NewCipher([]byte(secret))
if err != nil {
return nil, err
}
return &Cipher{Block: c}, err
}
// Encrypt a value for use in a cookie
func (c *Cipher) Encrypt(value string) (string, error) {
ciphertext := make([]byte, aes.BlockSize+len(value))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", fmt.Errorf("failed to create initialization vector %s", err)
}
stream := cipher.NewCFBEncrypter(c.Block, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(value))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt a value from a cookie to it's original string
func (c *Cipher) Decrypt(s string) (string, error) {
encrypted, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return "", fmt.Errorf("failed to decrypt cookie value %s", err)
}
if len(encrypted) < aes.BlockSize {
return "", fmt.Errorf("encrypted cookie value should be "+
"at least %d bytes, but is only %d bytes",
aes.BlockSize, len(encrypted))
}
iv := encrypted[:aes.BlockSize]
encrypted = encrypted[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(c.Block, iv)
stream.XORKeyStream(encrypted, encrypted)
return string(encrypted), nil
}

23
cookie/cookies_test.go Normal file
View File

@ -0,0 +1,23 @@
package cookie
import (
"testing"
"github.com/bmizerany/assert"
)
func TestEncodeAndDecodeAccessToken(t *testing.T) {
const secret = "0123456789abcdefghijklmnopqrstuv"
const token = "my access token"
c, err := NewCipher(secret)
assert.Equal(t, nil, err)
encoded, err := c.Encrypt(token)
assert.Equal(t, nil, err)
decoded, err := c.Decrypt(encoded)
assert.Equal(t, nil, err)
assert.NotEqual(t, token, encoded)
assert.Equal(t, token, decoded)
}

View File

@ -1,140 +0,0 @@
package main
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
)
func validateCookie(cookie *http.Cookie, seed string, expiration time.Duration) (value string, t time.Time, ok bool) {
// value, timestamp, sig
parts := strings.Split(cookie.Value, "|")
if len(parts) != 3 {
return
}
sig := cookieSignature(seed, cookie.Name, parts[0], parts[1])
if checkHmac(parts[2], sig) {
ts, err := strconv.Atoi(parts[1])
if err != nil {
return
}
// The expiration timestamp set when the cookie was created
// isn't sent back by the browser. Hence, we check whether the
// creation timestamp stored in the cookie falls within the
// window defined by (Now()-expiration, Now()].
t = time.Unix(int64(ts), 0)
if t.After(time.Now().Add(expiration*-1)) && t.Before(time.Now().Add(time.Minute*5)) {
// it's a valid cookie. now get the contents
rawValue, err := base64.URLEncoding.DecodeString(parts[0])
if err == nil {
value = string(rawValue)
ok = true
return
}
}
}
return
}
func signedCookieValue(seed string, key string, value string, now time.Time) string {
encodedValue := base64.URLEncoding.EncodeToString([]byte(value))
timeStr := fmt.Sprintf("%d", now.Unix())
sig := cookieSignature(seed, key, encodedValue, timeStr)
cookieVal := fmt.Sprintf("%s|%s|%s", encodedValue, timeStr, sig)
return cookieVal
}
func cookieSignature(args ...string) string {
h := hmac.New(sha1.New, []byte(args[0]))
for _, arg := range args[1:] {
h.Write([]byte(arg))
}
var b []byte
b = h.Sum(b)
return base64.URLEncoding.EncodeToString(b)
}
func checkHmac(input, expected string) bool {
inputMAC, err1 := base64.URLEncoding.DecodeString(input)
if err1 == nil {
expectedMAC, err2 := base64.URLEncoding.DecodeString(expected)
if err2 == nil {
return hmac.Equal(inputMAC, expectedMAC)
}
}
return false
}
func encodeAccessToken(aes_cipher cipher.Block, access_token string) (string, error) {
ciphertext := make([]byte, aes.BlockSize+len(access_token))
iv := ciphertext[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", fmt.Errorf("failed to create access code initialization vector")
}
stream := cipher.NewCFBEncrypter(aes_cipher, iv)
stream.XORKeyStream(ciphertext[aes.BlockSize:], []byte(access_token))
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
func decodeAccessToken(aes_cipher cipher.Block, encoded_access_token string) (string, error) {
encrypted_access_token, err := base64.StdEncoding.DecodeString(
encoded_access_token)
if err != nil {
return "", fmt.Errorf("failed to decode access token")
}
if len(encrypted_access_token) < aes.BlockSize {
return "", fmt.Errorf("encrypted access token should be "+
"at least %d bytes, but is only %d bytes",
aes.BlockSize, len(encrypted_access_token))
}
iv := encrypted_access_token[:aes.BlockSize]
encrypted_access_token = encrypted_access_token[aes.BlockSize:]
stream := cipher.NewCFBDecrypter(aes_cipher, iv)
stream.XORKeyStream(encrypted_access_token, encrypted_access_token)
return string(encrypted_access_token), nil
}
func buildCookieValue(email string, aes_cipher cipher.Block,
access_token string) (string, error) {
if aes_cipher == nil {
return email, nil
}
encoded_token, err := encodeAccessToken(aes_cipher, access_token)
if err != nil {
return email, fmt.Errorf(
"error encoding access token for %s: %s", email, err)
}
return email + "|" + encoded_token, nil
}
func parseCookieValue(value string, aes_cipher cipher.Block) (email, user,
access_token string, err error) {
components := strings.Split(value, "|")
email = components[0]
user = strings.Split(email, "@")[0]
if aes_cipher != nil && len(components) == 2 {
access_token, err = decodeAccessToken(aes_cipher, components[1])
if err != nil {
err = fmt.Errorf(
"error decoding access token for %s: %s",
email, err)
}
}
return email, user, access_token, err
}

View File

@ -1,75 +0,0 @@
package main
import (
"crypto/aes"
"github.com/bmizerany/assert"
"strings"
"testing"
)
func TestEncodeAndDecodeAccessToken(t *testing.T) {
const key = "0123456789abcdefghijklmnopqrstuv"
const access_token = "my access token"
c, err := aes.NewCipher([]byte(key))
assert.Equal(t, nil, err)
encoded_token, err := encodeAccessToken(c, access_token)
assert.Equal(t, nil, err)
decoded_token, err := decodeAccessToken(c, encoded_token)
assert.Equal(t, nil, err)
assert.NotEqual(t, access_token, encoded_token)
assert.Equal(t, access_token, decoded_token)
}
func TestBuildCookieValueWithoutAccessToken(t *testing.T) {
value, err := buildCookieValue("michael.bland@gsa.gov", nil, "")
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", value)
}
func TestBuildCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
value, err := buildCookieValue("michael.bland@gsa.gov", nil,
"access token")
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", value)
}
func TestParseCookieValueWithoutAccessToken(t *testing.T) {
email, user, access_token, err := parseCookieValue(
"michael.bland@gsa.gov", nil)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "", access_token)
}
func TestParseCookieValueWithAccessTokenAndNilCipher(t *testing.T) {
email, user, access_token, err := parseCookieValue(
"michael.bland@gsa.gov|access_token", nil)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "", access_token)
}
func TestBuildAndParseCookieValueWithAccessToken(t *testing.T) {
aes_cipher, err := aes.NewCipher([]byte("0123456789abcdef"))
assert.Equal(t, nil, err)
value, err := buildCookieValue("michael.bland@gsa.gov", aes_cipher,
"access_token")
assert.Equal(t, nil, err)
prefix := "michael.bland@gsa.gov|"
if !strings.HasPrefix(value, prefix) {
t.Fatal("cookie value does not start with \"%s\": %s",
prefix, value)
}
email, user, access_token, err := parseCookieValue(value, aes_cipher)
assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, "michael.bland", user)
assert.Equal(t, "access_token", access_token)
}

View File

@ -1,8 +1,6 @@
package main package main
import ( import (
"crypto/aes"
"crypto/cipher"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@ -16,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/bitly/oauth2_proxy/providers" "github.com/bitly/oauth2_proxy/providers"
) )
@ -44,7 +43,7 @@ type OauthProxy struct {
serveMux http.Handler serveMux http.Handler
PassBasicAuth bool PassBasicAuth bool
PassAccessToken bool PassAccessToken bool
AesCipher cipher.Block CookieCipher *cookie.Cipher
skipAuthRegex []string skipAuthRegex []string
compiledRegex []*regexp.Regexp compiledRegex []*regexp.Regexp
templates *template.Template templates *template.Template
@ -116,10 +115,10 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh) log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh)
var aes_cipher cipher.Block var cipher *cookie.Cipher
if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
var err error var err error
aes_cipher, err = aes.NewCipher([]byte(opts.CookieSecret)) cipher, err = cookie.NewCipher(opts.CookieSecret)
if err != nil { if err != nil {
log.Fatal("error creating AES cipher with "+ log.Fatal("error creating AES cipher with "+
"cookie-secret ", opts.CookieSecret, ": ", err) "cookie-secret ", opts.CookieSecret, ": ", err)
@ -150,7 +149,7 @@ func NewOauthProxy(opts *Options, validator func(string) bool) *OauthProxy {
compiledRegex: opts.CompiledRegex, compiledRegex: opts.CompiledRegex,
PassBasicAuth: opts.PassBasicAuth, PassBasicAuth: opts.PassBasicAuth,
PassAccessToken: opts.PassAccessToken, PassAccessToken: opts.PassAccessToken,
AesCipher: aes_cipher, CookieCipher: cipher,
templates: loadTemplates(opts.CustomTemplatesDir), templates: loadTemplates(opts.CustomTemplatesDir),
} }
} }
@ -177,22 +176,20 @@ func (p *OauthProxy) displayCustomLoginForm() bool {
return p.HtpasswdFile != nil && p.DisplayHtpasswdForm return p.HtpasswdFile != nil && p.DisplayHtpasswdForm
} }
func (p *OauthProxy) redeemCode(host, code string) (string, string, error) { func (p *OauthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
if code == "" { if code == "" {
return "", "", errors.New("missing code") return nil, errors.New("missing code")
} }
redirectUri := p.GetRedirectURI(host) redirectUri := p.GetRedirectURI(host)
body, access_token, err := p.provider.Redeem(redirectUri, code) s, err = p.provider.Redeem(redirectUri, code)
if err != nil { if err != nil {
return "", "", err return
} }
email, err := p.provider.GetEmailAddress(body, access_token) if s.Email == "" {
if err != nil { s.Email, err = p.provider.GetEmailAddress(s)
return "", "", err
} }
return
return access_token, email, nil
} }
func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
@ -208,9 +205,8 @@ func (p *OauthProxy) MakeCookie(req *http.Request, value string, expiration time
} }
if value != "" { if value != "" {
value = signedCookieValue(p.CookieSeed, p.CookieName, value, now) value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now)
} }
return &http.Cookie{ return &http.Cookie{
Name: p.CookieName, Name: p.CookieName,
Value: value, Value: value,
@ -230,35 +226,34 @@ func (p *OauthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val st
http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now())) http.SetCookie(rw, p.MakeCookie(req, val, p.CookieExpire, time.Now()))
} }
func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (email, user, access_token string, ok bool) { func (p *OauthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
var value string var age time.Duration
var timestamp time.Time c, err := req.Cookie(p.CookieName)
cookie, err := req.Cookie(p.CookieName)
if err == nil {
value, timestamp, ok = validateCookie(cookie, p.CookieSeed, p.CookieExpire)
if ok {
email, user, access_token, err = parseCookieValue(value, p.AesCipher)
}
}
if err != nil { if err != nil {
log.Printf(err.Error()) // always http.ErrNoCookie
ok = false return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName)
} else if ok && p.CookieRefresh != time.Duration(0) { }
refresh := timestamp.Add(p.CookieRefresh) val, timestamp, ok := cookie.Validate(c, p.CookieSeed, p.CookieExpire)
if refresh.Before(time.Now()) { if !ok {
log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh) return nil, age, errors.New("Cookie Signature not valid")
ok = p.Validator(email) }
log.Printf("re-validating %s valid:%v", email, ok)
if ok { session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
ok = p.provider.ValidateToken(access_token) if err != nil {
log.Printf("re-validating access token. valid:%v", ok) return nil, age, err
}
age = time.Now().Truncate(time.Second).Sub(timestamp)
return session, age, nil
}
func (p *OauthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
value, err := p.provider.CookieForSession(s, p.CookieCipher)
if err != nil {
return err
} }
if ok {
p.SetCookie(rw, req, value) p.SetCookie(rw, req, value)
} return nil
}
}
return
} }
func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) { func (p *OauthProxy) RobotsTxt(rw http.ResponseWriter) {
@ -344,54 +339,61 @@ func (p *OauthProxy) GetRedirect(req *http.Request) (string, error) {
return redirect, err return redirect, err
} }
func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *OauthProxy) IsWhitelistedPath(path string) (ok bool) {
// check if this is a redirect back at the end of oauth
remoteAddr := req.RemoteAddr
if req.Header.Get("X-Real-IP") != "" {
remoteAddr += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
}
var ok bool
var user string
var email string
var access_token string
if req.URL.Path == p.RobotsPath {
p.RobotsTxt(rw)
return
}
if req.URL.Path == p.PingPath {
p.PingPage(rw)
return
}
for _, u := range p.compiledRegex { for _, u := range p.compiledRegex {
match := u.MatchString(req.URL.Path) ok = u.MatchString(path)
if match { if ok {
p.serveMux.ServeHTTP(rw, req) return
}
}
return return
} }
func getRemoteAddr(req *http.Request) (s string) {
s = req.RemoteAddr
if req.Header.Get("X-Real-IP") != "" {
s += fmt.Sprintf(" (%q)", req.Header.Get("X-Real-IP"))
}
return
} }
if req.URL.Path == p.SignInPath { func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
switch path := req.URL.Path; {
case path == p.RobotsPath:
p.RobotsTxt(rw)
case path == p.PingPath:
p.PingPage(rw)
case p.IsWhitelistedPath(path):
p.serveMux.ServeHTTP(rw, req)
case path == p.SignInPath:
p.SignIn(rw, req)
case path == p.OauthStartPath:
p.OauthStart(rw, req)
case path == p.OauthCallbackPath:
p.OauthCallback(rw, req)
default:
p.Proxy(rw, req)
}
}
func (p *OauthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req) redirect, err := p.GetRedirect(req)
if err != nil { if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error()) p.ErrorPage(rw, 500, "Internal Error", err.Error())
return return
} }
user, ok = p.ManualSignIn(rw, req) user, ok := p.ManualSignIn(rw, req)
if ok { if ok {
p.SetCookie(rw, req, user) session := &providers.SessionState{User: user}
p.SaveSession(rw, req, session)
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, 302)
} else { } else {
p.SignInPage(rw, req, 200) p.SignInPage(rw, req, 200)
} }
return
} }
if req.URL.Path == p.OauthStartPath {
func (p *OauthProxy) OauthStart(rw http.ResponseWriter, req *http.Request) {
redirect, err := p.GetRedirect(req) redirect, err := p.GetRedirect(req)
if err != nil { if err != nil {
p.ErrorPage(rw, 500, "Internal Error", err.Error()) p.ErrorPage(rw, 500, "Internal Error", err.Error())
@ -399,9 +401,11 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
redirectURI := p.GetRedirectURI(req.Host) redirectURI := p.GetRedirectURI(req.Host)
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, redirect), 302)
return
} }
if req.URL.Path == p.OauthCallbackPath {
func (p *OauthProxy) OauthCallback(rw http.ResponseWriter, req *http.Request) {
remoteAddr := getRemoteAddr(req)
// finish the oauth cycle // finish the oauth cycle
err := req.ParseForm() err := req.ParseForm()
if err != nil { if err != nil {
@ -414,10 +418,10 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return return
} }
access_token, email, err = p.redeemCode(req.Host, req.Form.Get("code")) session, err := p.redeemCode(req.Host, req.Form.Get("code"))
if err != nil { if err != nil {
log.Printf("%s error redeeming code %s", remoteAddr, err) log.Printf("%s error redeeming code %s", remoteAddr, err)
p.ErrorPage(rw, 500, "Internal Error", err.Error()) p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return return
} }
@ -427,73 +431,134 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
} }
// set cookie, or deny // set cookie, or deny
if p.Validator(email) { if p.Validator(session.Email) {
log.Printf("%s authenticating %s completed", remoteAddr, email) log.Printf("%s authentication complete %s", remoteAddr, session)
value, err := buildCookieValue( err := p.SaveSession(rw, req, session)
email, p.AesCipher, access_token)
if err != nil { if err != nil {
log.Printf("%s", err) log.Printf("%s %s", remoteAddr, err)
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return
} }
p.SetCookie(rw, req, value)
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, 302)
return
} else { } else {
log.Printf("validating: %s is unauthorized") log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email)
p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
}
}
func (p *OauthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
var saveSession, clearSession, revalidated bool
remoteAddr := getRemoteAddr(req)
session, sessionAge, err := p.LoadCookiedSession(req)
if err != nil {
log.Printf("%s %s", remoteAddr, err)
}
if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh)
saveSession = true
}
if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil {
log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
clearSession = true
session = nil
} else if ok {
saveSession = true
revalidated = true
}
if session != nil && session.IsExpired() {
log.Printf("%s removing session. token expired %s", remoteAddr, session)
session = nil
saveSession = false
clearSession = true
}
if saveSession && !revalidated && session.AccessToken != "" {
if !p.provider.ValidateSessionState(session) {
log.Printf("%s removing session. error validating %s", remoteAddr, session)
saveSession = false
session = nil
clearSession = true
}
}
if saveSession && session.Email != "" && !p.Validator(session.Email) {
log.Printf("%s Permission Denied: removing session %s", remoteAddr, session)
session = nil
saveSession = false
clearSession = true
}
if saveSession {
err := p.SaveSession(rw, req, session)
if err != nil {
log.Printf("%s %s", remoteAddr, err)
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return return
} }
} }
if !ok { if clearSession {
email, user, access_token, ok = p.ProcessCookie(rw, req) p.ClearCookie(rw, req)
} }
if !ok { if session == nil {
user, ok = p.CheckBasicAuth(req) session, err = p.CheckBasicAuth(req)
if err != nil {
log.Printf("%s %s", remoteAddr, err)
}
} }
if !ok { if session == nil {
p.SignInPage(rw, req, 403) p.SignInPage(rw, req, 403)
return return
} }
// At this point, the user is authenticated. proxy normally // At this point, the user is authenticated. proxy normally
if p.PassBasicAuth { if p.PassBasicAuth {
req.SetBasicAuth(user, "") req.SetBasicAuth(session.User, "")
req.Header["X-Forwarded-User"] = []string{user} req.Header["X-Forwarded-User"] = []string{session.User}
req.Header["X-Forwarded-Email"] = []string{email} if session.Email != "" {
req.Header["X-Forwarded-Email"] = []string{session.Email}
} }
if p.PassAccessToken {
req.Header["X-Forwarded-Access-Token"] = []string{access_token}
} }
if email == "" { if p.PassAccessToken && session.AccessToken != "" {
rw.Header().Set("GAP-Auth", user) req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
}
if session.Email == "" {
rw.Header().Set("GAP-Auth", session.User)
} else { } else {
rw.Header().Set("GAP-Auth", email) rw.Header().Set("GAP-Auth", session.Email)
} }
p.serveMux.ServeHTTP(rw, req) p.serveMux.ServeHTTP(rw, req)
} }
func (p *OauthProxy) CheckBasicAuth(req *http.Request) (string, bool) { func (p *OauthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
if p.HtpasswdFile == nil { if p.HtpasswdFile == nil {
return "", false return nil, nil
} }
s := strings.SplitN(req.Header.Get("Authorization"), " ", 2) auth := req.Header.Get("Authorization")
if auth == "" {
return nil, nil
}
s := strings.SplitN(auth, " ", 2)
if len(s) != 2 || s[0] != "Basic" { if len(s) != 2 || s[0] != "Basic" {
return "", false return nil, fmt.Errorf("invalid Authorization header %s", req.Header.Get("Authorization"))
} }
b, err := base64.StdEncoding.DecodeString(s[1]) b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil { if err != nil {
return "", false return nil, err
} }
pair := strings.SplitN(string(b), ":", 2) pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 { if len(pair) != 2 {
return "", false return nil, fmt.Errorf("invalid format %s", b)
} }
if p.HtpasswdFile.Validate(pair[0], pair[1]) { if p.HtpasswdFile.Validate(pair[0], pair[1]) {
log.Printf("authenticated %q via basic auth", pair[0]) log.Printf("authenticated %q via basic auth", pair[0])
return pair[0], true return &providers.SessionState{User: pair[0]}, nil
} }
return "", false return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0])
} }

View File

@ -94,11 +94,11 @@ type TestProvider struct {
ValidToken bool ValidToken bool
} }
func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { func (tp *TestProvider) GetEmailAddress(session *providers.SessionState) (string, error) {
return tp.EmailAddress, nil return tp.EmailAddress, nil
} }
func (tp *TestProvider) ValidateToken(access_token string) bool { func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bool {
return tp.ValidToken return tp.ValidToken
} }
@ -378,97 +378,73 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest {
}) })
} }
func (p *ProcessCookieTest) MakeCookie(value, access_token string, ref time.Time) *http.Cookie { func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie {
cookie_value, _ := buildCookieValue(value, p.proxy.AesCipher, access_token) return p.proxy.MakeCookie(p.req, value, p.opts.CookieExpire, ref)
return p.proxy.MakeCookie(p.req, cookie_value, p.opts.CookieExpire, ref)
} }
func (p *ProcessCookieTest) AddCookie(value, access_token string) { func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time) error {
p.req.AddCookie(p.MakeCookie(value, access_token, time.Now())) value, err := p.proxy.provider.CookieForSession(s, p.proxy.CookieCipher)
if err != nil {
return err
}
p.req.AddCookie(p.proxy.MakeCookie(p.req, value, p.proxy.CookieExpire, ref))
return nil
} }
func (p *ProcessCookieTest) ProcessCookie() (email, user, access_token string, ok bool) { func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time.Duration, error) {
return p.proxy.ProcessCookie(p.rw, p.req) return p.proxy.LoadCookiedSession(p.req)
} }
func TestProcessCookie(t *testing.T) { func TestLoadCookiedSession(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pc_test := NewProcessCookieTestWithDefaults()
pc_test.AddCookie("michael.bland@gsa.gov", "my_access_token") startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
email, user, access_token, ok := pc_test.ProcessCookie() pc_test.SaveSession(startSession, time.Now())
assert.Equal(t, true, ok)
assert.Equal(t, "michael.bland@gsa.gov", email) session, _, err := pc_test.LoadCookiedSession()
assert.Equal(t, "michael.bland", user) assert.Equal(t, nil, err)
assert.Equal(t, "my_access_token", access_token) assert.Equal(t, startSession.Email, session.Email)
assert.Equal(t, "michael.bland", session.User)
assert.Equal(t, startSession.AccessToken, session.AccessToken)
} }
func TestProcessCookieNoCookieError(t *testing.T) { func TestProcessCookieNoCookieError(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pc_test := NewProcessCookieTestWithDefaults()
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
}
func TestProcessCookieFailIfParsingCookieValueFails(t *testing.T) { session, _, err := pc_test.LoadCookiedSession()
pc_test := NewProcessCookieTestWithDefaults() assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error())
value, _ := buildCookieValue("michael.bland@gsa.gov", if session != nil {
pc_test.proxy.AesCipher, "my_access_token") t.Errorf("expected nil session. got %#v", session)
pc_test.req.AddCookie(pc_test.proxy.MakeCookie( }
pc_test.req, value+"some bogus bytes",
pc_test.opts.CookieExpire, time.Now()))
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
} }
func TestProcessCookieRefreshNotSet(t *testing.T) { func TestProcessCookieRefreshNotSet(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour) reference := time.Now().Add(time.Duration(-2) * time.Hour)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "", reference)
pc_test.req.AddCookie(cookie)
_, _, _, ok := pc_test.ProcessCookie() startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
assert.Equal(t, true, ok) pc_test.SaveSession(startSession, reference)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
session, age, err := pc_test.LoadCookiedSession()
assert.Equal(t, nil, err)
if age < time.Duration(-2)*time.Hour {
t.Errorf("cookie too young %v", age)
} }
assert.Equal(t, startSession.Email, session.Email)
func TestProcessCookieRefresh(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, true, ok)
assert.NotEqual(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}
func TestProcessCookieRefreshThresholdNotCrossed(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-30) * time.Minute)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, true, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
} }
func TestProcessCookieFailIfCookieExpired(t *testing.T) { func TestProcessCookieFailIfCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.req.AddCookie(cookie) pc_test.SaveSession(startSession, reference)
if _, _, _, ok := pc_test.ProcessCookie(); ok { session, _, err := pc_test.LoadCookiedSession()
t.Error("ProcessCookie() should have failed") assert.NotEqual(t, nil, err)
} if session != nil {
if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil { t.Errorf("expected nil session %#v", session)
t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie)
} }
} }
@ -476,44 +452,13 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults() pc_test := NewProcessCookieTestWithDefaults()
pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour
reference := time.Now().Add(time.Duration(25) * time.Hour * -1) reference := time.Now().Add(time.Duration(25) * time.Hour * -1)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"}
pc_test.req.AddCookie(cookie) pc_test.SaveSession(startSession, reference)
pc_test.proxy.CookieRefresh = time.Hour pc_test.proxy.CookieRefresh = time.Hour
if _, _, _, ok := pc_test.ProcessCookie(); ok { session, _, err := pc_test.LoadCookiedSession()
t.Error("ProcessCookie() should have failed") assert.NotEqual(t, nil, err)
} if session != nil {
if set_cookie := pc_test.rw.HeaderMap["Set-Cookie"]; set_cookie != nil { t.Errorf("expected nil session %#v", session)
t.Error("expected Set-Cookie to be nil, instead was: ", set_cookie)
} }
} }
func TestProcessCookieFailIfRefreshSetAndTokenNoLongerValid(t *testing.T) {
pc_test := NewProcessCookieTest(ProcessCookieTestOpts{
provider_validate_cookie_response: false,
})
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-24) * time.Hour)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}
func TestProcessCookieFailIfRefreshSetAndUserNoLongerValid(t *testing.T) {
pc_test := NewProcessCookieTestWithDefaults()
pc_test.validate_user = false
pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour
reference := time.Now().Add(time.Duration(-2) * time.Hour)
cookie := pc_test.MakeCookie("michael.bland@gsa.gov", "my_access_token", reference)
pc_test.req.AddCookie(cookie)
pc_test.proxy.CookieRefresh = time.Hour
_, _, _, ok := pc_test.ProcessCookie()
assert.Equal(t, false, ok)
assert.Equal(t, []string(nil), pc_test.rw.HeaderMap["Set-Cookie"])
}

View File

@ -2,8 +2,10 @@ package providers
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"net/url" "net/url"
) )
@ -138,7 +140,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
return false, nil return false, nil
} }
func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (string, error) { func (p *GitHubProvider) GetEmailAddress(s *SessionState) (string, error) {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`
@ -148,31 +150,34 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
// if we require an Org or Team, check that first // if we require an Org or Team, check that first
if p.Org != "" { if p.Org != "" {
if p.Team != "" { if p.Team != "" {
if ok, err := p.hasOrgAndTeam(access_token); err != nil || !ok { if ok, err := p.hasOrgAndTeam(s.AccessToken); err != nil || !ok {
return "", err return "", err
} }
} else { } else {
if ok, err := p.hasOrg(access_token); err != nil || !ok { if ok, err := p.hasOrg(s.AccessToken); err != nil || !ok {
return "", err return "", err
} }
} }
} }
params := url.Values{ params := url.Values{
"access_token": {access_token}, "access_token": {s.AccessToken},
} }
endpoint := "https://api.github.com/user/emails?" + params.Encode() endpoint := "https://api.github.com/user/emails?" + params.Encode()
resp, err := http.DefaultClient.Get(endpoint) resp, err := http.DefaultClient.Get(endpoint)
if err != nil { if err != nil {
return "", err return "", err
} }
body, err = ioutil.ReadAll(resp.Body) body, err := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
return "", err return "", err
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body) return "", fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
} else {
log.Printf("got %d from %q %s", resp.StatusCode, endpoint, body)
} }
if err := json.Unmarshal(body, &emails); err != nil { if err := json.Unmarshal(body, &emails); err != nil {
@ -185,9 +190,5 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
} }
} }
return "", nil return "", errors.New("no email address found")
}
func (p *GitHubProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
} }

View File

@ -7,9 +7,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time"
) )
type GoogleProvider struct { type GoogleProvider struct {
@ -43,18 +45,11 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider {
return &GoogleProvider{ProviderData: p} return &GoogleProvider{ProviderData: p}
} }
func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (string, error) { func emailFromIdToken(idToken string) (string, error) {
var response struct {
IdToken string `json:"id_token"`
}
if err := json.Unmarshal(body, &response); err != nil {
return "", err
}
// id_token is a base64 encode ID token payload // id_token is a base64 encode ID token payload
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
jwt := strings.Split(response.IdToken, ".") jwt := strings.Split(idToken, ".")
b, err := jwtDecodeSegment(jwt[1]) b, err := jwtDecodeSegment(jwt[1])
if err != nil { if err != nil {
return "", err return "", err
@ -62,6 +57,7 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri
var email struct { var email struct {
Email string `json:"email"` Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
} }
err = json.Unmarshal(b, &email) err = json.Unmarshal(b, &email)
if err != nil { if err != nil {
@ -70,6 +66,9 @@ func (s *GoogleProvider) GetEmailAddress(body []byte, access_token string) (stri
if email.Email == "" { if email.Email == "" {
return "", errors.New("missing email") return "", errors.New("missing email")
} }
if !email.EmailVerified {
return "", fmt.Errorf("email %s not listed as verified", email.Email)
}
return email.Email, nil return email.Email, nil
} }
@ -81,11 +80,7 @@ func jwtDecodeSegment(seg string) ([]byte, error) {
return base64.URLEncoding.DecodeString(seg) return base64.URLEncoding.DecodeString(seg)
} }
func (p *GoogleProvider) ValidateToken(access_token string) bool { func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err error) {
return validateToken(p, access_token, nil)
}
func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -108,6 +103,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st
if err != nil { if err != nil {
return return
} }
var body []byte
body, err = ioutil.ReadAll(resp.Body) body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
@ -122,17 +118,44 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st
var jsonResponse struct { var jsonResponse struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"`
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &jsonResponse)
if err != nil { if err != nil {
return return
} }
var email string
token, err = p.redeemRefreshToken(jsonResponse.RefreshToken) email, err = emailFromIdToken(jsonResponse.IdToken)
if err != nil {
return
}
s = &SessionState{
AccessToken: jsonResponse.AccessToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
RefreshToken: jsonResponse.RefreshToken,
Email: email,
}
return return
} }
func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, err error) { func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
newToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
if err != nil {
return false, err
}
origExpiration := s.ExpiresOn
s.AccessToken = newToken
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
log.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
return true, nil
}
func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
params := url.Values{} params := url.Values{}
params.Add("client_id", p.ClientID) params.Add("client_id", p.ClientID)
@ -162,12 +185,15 @@ func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string,
return return
} }
var jsonResponse struct { var data struct {
AccessToken string `json:"access_token"` AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &data)
if err != nil { if err != nil {
return return
} }
return jsonResponse.AccessToken, nil token = data.AccessToken
expires = time.Duration(data.ExpiresIn) * time.Second
return
} }

View File

@ -3,11 +3,22 @@ package providers
import ( import (
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"github.com/bmizerany/assert" "net/http"
"net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"github.com/bmizerany/assert"
) )
func newRedeemServer(body []byte) (*url.URL, *httptest.Server) {
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write(body)
}))
u, _ := url.Parse(s.URL)
return u, s
}
func newGoogleProvider() *GoogleProvider { func newGoogleProvider() *GoogleProvider {
return NewGoogleProvider( return NewGoogleProvider(
&ProviderData{ &ProviderData{
@ -66,63 +77,88 @@ func TestGoogleProviderOverrides(t *testing.T) {
assert.Equal(t, "profile", p.Data().Scope) assert.Equal(t, "profile", p.Data().Scope)
} }
func TestGoogleProviderGetEmailAddress(t *testing.T) { type redeemResponse struct {
p := newGoogleProvider() AccessToken string `json:"access_token"`
body, err := json.Marshal( RefreshToken string `json:"refresh_token"`
struct { ExpiresIn int64 `json:"expires_in"`
IdToken string `json:"id_token"` IdToken string `json:"id_token"`
}{
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov"}`)),
},
)
assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token")
assert.Equal(t, "michael.bland@gsa.gov", email)
assert.Equal(t, nil, err)
} }
func TestGoogleProviderGetEmailAddress(t *testing.T) {
p := newGoogleProvider()
body, err := json.Marshal(redeemResponse{
AccessToken: "a1234",
ExpiresIn: 10,
RefreshToken: "refresh12345",
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)),
})
assert.Equal(t, nil, err)
var server *httptest.Server
p.RedeemUrl, server = newRedeemServer(body)
defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "michael.bland@gsa.gov", session.Email)
assert.Equal(t, "a1234", session.AccessToken)
assert.Equal(t, "refresh12345", session.RefreshToken)
}
//
func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal( body, err := json.Marshal(redeemResponse{
struct { AccessToken: "a1234",
IdToken string `json:"id_token"`
}{
IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`,
}, })
)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token") var server *httptest.Server
assert.Equal(t, "", email) p.RedeemUrl, server = newRedeemServer(body)
defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expect nill session %#v", session)
}
} }
func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal( body, err := json.Marshal(redeemResponse{
struct { AccessToken: "a1234",
IdToken string `json:"id_token"`
}{
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)),
}, })
)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token") var server *httptest.Server
assert.Equal(t, "", email) p.RedeemUrl, server = newRedeemServer(body)
defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expect nill session %#v", session)
}
} }
func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
p := newGoogleProvider() p := newGoogleProvider()
body, err := json.Marshal( body, err := json.Marshal(redeemResponse{
struct { AccessToken: "a1234",
IdToken string `json:"id_token"`
}{
IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)),
}, })
)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
email, err := p.GetEmailAddress(body, "ignored access_token") var server *httptest.Server
assert.Equal(t, "", email) p.RedeemUrl, server = newRedeemServer(body)
defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
if session != nil {
t.Errorf("expect nill session %#v", session)
}
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/bitly/oauth2_proxy/api" "github.com/bitly/oauth2_proxy/api"
) )
// validateToken returns true if token is valid
func validateToken(p Provider, access_token string, header http.Header) bool { func validateToken(p Provider, access_token string, header http.Header) bool {
if access_token == "" || p.Data().ValidateUrl == nil { if access_token == "" || p.Data().ValidateUrl == nil {
return false return false
@ -20,12 +21,15 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
} }
resp, err := api.RequestUnparsedResponse(endpoint, header) resp, err := api.RequestUnparsedResponse(endpoint, header)
if err != nil { if err != nil {
log.Printf("GET %s", endpoint)
log.Printf("token validation request failed: %s", err) log.Printf("token validation request failed: %s", err)
return false return false
} }
body, _ := ioutil.ReadAll(resp.Body) body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
log.Printf("%d GET %s %s", resp.StatusCode, endpoint, body)
if resp.StatusCode == 200 { if resp.StatusCode == 200 {
return true return true
} }

View File

@ -1,36 +1,38 @@
package providers package providers
import ( import (
"github.com/bmizerany/assert" "errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"github.com/bmizerany/assert"
) )
type ValidateTokenTestProvider struct { type ValidateSessionStateTestProvider struct {
*ProviderData *ProviderData
} }
func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token string) (string, error) { func (tp *ValidateSessionStateTestProvider) GetEmailAddress(s *SessionState) (string, error) {
return "", nil return "", errors.New("not implemented")
} }
// Note that we're testing the internal validateToken() used to implement // Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateToken() implementations // several Provider's ValidateSessionState() implementations
func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool { func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState) bool {
return false return false
} }
type ValidateTokenTest struct { type ValidateSessionStateTest struct {
backend *httptest.Server backend *httptest.Server
response_code int response_code int
provider *ValidateTokenTestProvider provider *ValidateSessionStateTestProvider
header http.Header header http.Header
} }
func NewValidateTokenTest() *ValidateTokenTest { func NewValidateSessionStateTest() *ValidateSessionStateTest {
var vt_test ValidateTokenTest var vt_test ValidateSessionStateTest
vt_test.backend = httptest.NewServer( vt_test.backend = httptest.NewServer(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -59,7 +61,7 @@ func NewValidateTokenTest() *ValidateTokenTest {
})) }))
backend_url, _ := url.Parse(vt_test.backend.URL) backend_url, _ := url.Parse(vt_test.backend.URL)
vt_test.provider = &ValidateTokenTestProvider{ vt_test.provider = &ValidateSessionStateTestProvider{
ProviderData: &ProviderData{ ProviderData: &ProviderData{
ValidateUrl: &url.URL{ ValidateUrl: &url.URL{
Scheme: "http", Scheme: "http",
@ -72,18 +74,18 @@ func NewValidateTokenTest() *ValidateTokenTest {
return &vt_test return &vt_test
} }
func (vt_test *ValidateTokenTest) Close() { func (vt_test *ValidateSessionStateTest) Close() {
vt_test.backend.Close() vt_test.backend.Close()
} }
func TestValidateTokenValidToken(t *testing.T) { func TestValidateSessionStateValidToken(t *testing.T) {
vt_test := NewValidateTokenTest() vt_test := NewValidateSessionStateTest()
defer vt_test.Close() defer vt_test.Close()
assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil))
} }
func TestValidateTokenValidTokenWithHeaders(t *testing.T) { func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) {
vt_test := NewValidateTokenTest() vt_test := NewValidateSessionStateTest()
defer vt_test.Close() defer vt_test.Close()
vt_test.header = make(http.Header) vt_test.header = make(http.Header)
vt_test.header.Set("Authorization", "Bearer foobar") vt_test.header.Set("Authorization", "Bearer foobar")
@ -91,28 +93,28 @@ func TestValidateTokenValidTokenWithHeaders(t *testing.T) {
validateToken(vt_test.provider, "foobar", vt_test.header)) validateToken(vt_test.provider, "foobar", vt_test.header))
} }
func TestValidateTokenEmptyToken(t *testing.T) { func TestValidateSessionStateEmptyToken(t *testing.T) {
vt_test := NewValidateTokenTest() vt_test := NewValidateSessionStateTest()
defer vt_test.Close() defer vt_test.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) assert.Equal(t, false, validateToken(vt_test.provider, "", nil))
} }
func TestValidateTokenEmptyValidateUrl(t *testing.T) { func TestValidateSessionStateEmptyValidateUrl(t *testing.T) {
vt_test := NewValidateTokenTest() vt_test := NewValidateSessionStateTest()
defer vt_test.Close() defer vt_test.Close()
vt_test.provider.Data().ValidateUrl = nil vt_test.provider.Data().ValidateUrl = nil
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
} }
func TestValidateTokenRequestNetworkFailure(t *testing.T) { func TestValidateSessionStateRequestNetworkFailure(t *testing.T) {
vt_test := NewValidateTokenTest() vt_test := NewValidateSessionStateTest()
// Close immediately to simulate a network failure // Close immediately to simulate a network failure
vt_test.Close() vt_test.Close()
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))
} }
func TestValidateTokenExpiredToken(t *testing.T) { func TestValidateSessionStateExpiredToken(t *testing.T) {
vt_test := NewValidateTokenTest() vt_test := NewValidateSessionStateTest()
defer vt_test.Close() defer vt_test.Close()
vt_test.response_code = 401 vt_test.response_code = 401
assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil))

View File

@ -1,7 +1,6 @@
package providers package providers
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"log" "log"
@ -49,16 +48,15 @@ func getLinkedInHeader(access_token string) http.Header {
return header return header
} }
func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (string, error) { func (p *LinkedInProvider) GetEmailAddress(s *SessionState) (string, error) {
if access_token == "" { if s.AccessToken == "" {
return "", errors.New("missing access token") return "", errors.New("missing access token")
} }
params := url.Values{} req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", nil)
req, err := http.NewRequest("GET", p.ProfileUrl.String()+"?format=json", bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return "", err return "", err
} }
req.Header = getLinkedInHeader(access_token) req.Header = getLinkedInHeader(s.AccessToken)
json, err := api.Request(req) json, err := api.Request(req)
if err != nil { if err != nil {
@ -74,6 +72,6 @@ func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (st
return email, nil return email, nil
} }
func (p *LinkedInProvider) ValidateToken(access_token string) bool { func (p *LinkedInProvider) ValidateSessionState(s *SessionState) bool {
return validateToken(p, access_token, getLinkedInHeader(access_token)) return validateToken(p, s.AccessToken, getLinkedInHeader(s.AccessToken))
} }

View File

@ -97,8 +97,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(b_url.Host)
email, err := p.GetEmailAddress([]byte{}, session := &SessionState{AccessToken: "imaginary_access_token"}
"imaginary_access_token") email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "user@linkedin.com", email) assert.Equal(t, "user@linkedin.com", email)
} }
@ -113,7 +113,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -125,7 +126,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testLinkedInProvider(b_url.Host) p := testLinkedInProvider(b_url.Host)
email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -42,9 +42,9 @@ func NewMyUsaProvider(p *ProviderData) *MyUsaProvider {
return &MyUsaProvider{ProviderData: p} return &MyUsaProvider{ProviderData: p}
} }
func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (string, error) { func (p *MyUsaProvider) GetEmailAddress(s *SessionState) (string, error) {
req, err := http.NewRequest("GET", req, err := http.NewRequest("GET",
p.ProfileUrl.String()+"?access_token="+access_token, nil) p.ProfileUrl.String()+"?access_token="+s.AccessToken, nil)
if err != nil { if err != nil {
log.Printf("failed building request %s", err) log.Printf("failed building request %s", err)
return "", err return "", err
@ -56,7 +56,3 @@ func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (strin
} }
return json.Get("email").String() return json.Get("email").String()
} }
func (p *MyUsaProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
}

View File

@ -1,11 +1,12 @@
package providers package providers
import ( import (
"github.com/bmizerany/assert"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"github.com/bmizerany/assert"
) )
func updateUrl(url *url.URL, hostname string) { func updateUrl(url *url.URL, hostname string) {
@ -102,7 +103,8 @@ func TestMyUsaProviderGetEmailAddress(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testMyUsaProvider(b_url.Host) p := testMyUsaProvider(b_url.Host)
email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
assert.Equal(t, "michael.bland@gsa.gov", email) assert.Equal(t, "michael.bland@gsa.gov", email)
} }
@ -119,7 +121,8 @@ func TestMyUsaProviderGetEmailAddressFailedRequest(t *testing.T) {
// We'll trigger a request failure by using an unexpected access // We'll trigger a request failure by using an unexpected access
// token. Alternatively, we could allow the parsing of the payload as // token. Alternatively, we could allow the parsing of the payload as
// JSON to fail. // JSON to fail.
email, err := p.GetEmailAddress([]byte{}, "unexpected_access_token") session := &SessionState{AccessToken: "unexpected_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }
@ -131,7 +134,8 @@ func TestMyUsaProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) {
b_url, _ := url.Parse(b.URL) b_url, _ := url.Parse(b.URL)
p := testMyUsaProvider(b_url.Host) p := testMyUsaProvider(b_url.Host)
email, err := p.GetEmailAddress([]byte{}, "imaginary_access_token") session := &SessionState{AccessToken: "imaginary_access_token"}
email, err := p.GetEmailAddress(session)
assert.NotEqual(t, nil, err) assert.NotEqual(t, nil, err)
assert.Equal(t, "", email) assert.Equal(t, "", email)
} }

View File

@ -9,9 +9,11 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/bitly/oauth2_proxy/cookie"
) )
func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token string, err error) { func (p *ProviderData) Redeem(redirectUrl, code string) (s *SessionState, err error) {
if code == "" { if code == "" {
err = errors.New("missing code") err = errors.New("missing code")
return return
@ -23,24 +25,28 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri
params.Add("client_secret", p.ClientSecret) params.Add("client_secret", p.ClientSecret)
params.Add("code", code) params.Add("code", code)
params.Add("grant_type", "authorization_code") params.Add("grant_type", "authorization_code")
req, err := http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) var req *http.Request
req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode()))
if err != nil { if err != nil {
return nil, "", err return
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req) var resp *http.Response
resp, err = http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, "", err return nil, err
} }
var body []byte
body, err = ioutil.ReadAll(resp.Body) body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close() resp.Body.Close()
if err != nil { if err != nil {
return nil, "", err return
} }
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
return body, "", fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body)
return
} }
// blindly try json and x-www-form-urlencoded // blindly try json and x-www-form-urlencoded
@ -49,11 +55,23 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri
} }
err = json.Unmarshal(body, &jsonResponse) err = json.Unmarshal(body, &jsonResponse)
if err == nil { if err == nil {
return body, jsonResponse.AccessToken, nil s = &SessionState{
AccessToken: jsonResponse.AccessToken,
}
return
} }
v, err := url.ParseQuery(string(body)) var v url.Values
return body, v.Get("access_token"), err v, err = url.ParseQuery(string(body))
if err != nil {
return
}
if a := v.Get("access_token"); a != "" {
s = &SessionState{AccessToken: a}
} else {
err = fmt.Errorf("no access token found %s", body)
}
return
} }
// GetLoginURL with typical oauth parameters // GetLoginURL with typical oauth parameters
@ -72,3 +90,26 @@ func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string {
a.RawQuery = params.Encode() a.RawQuery = params.Encode()
return a.String() return a.String()
} }
// CookieForSession serializes a session state for storage in a cookie
func (p *ProviderData) CookieForSession(s *SessionState, c *cookie.Cipher) (string, error) {
return s.EncodeSessionState(c)
}
// SessionFromCookie deserializes a session from a cookie value
func (p *ProviderData) SessionFromCookie(v string, c *cookie.Cipher) (s *SessionState, err error) {
return DecodeSessionState(v, c)
}
func (p *ProviderData) GetEmailAddress(s *SessionState) (string, error) {
return "", errors.New("not implemented")
}
func (p *ProviderData) ValidateSessionState(s *SessionState) bool {
return validateToken(p, s.AccessToken, nil)
}
// RefreshSessionIfNeeded
func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
return false, nil
}

View File

@ -0,0 +1,17 @@
package providers
import (
"testing"
"time"
"github.com/bmizerany/assert"
)
func TestRefresh(t *testing.T) {
p := &ProviderData{}
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
})
assert.Equal(t, false, refreshed)
assert.Equal(t, nil, err)
}

View File

@ -1,11 +1,18 @@
package providers package providers
import (
"github.com/bitly/oauth2_proxy/cookie"
)
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(body []byte, access_token string) (string, error) GetEmailAddress(*SessionState) (string, error)
Redeem(string, string) ([]byte, string, error) Redeem(string, string) (*SessionState, error)
ValidateToken(access_token string) bool ValidateSessionState(*SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*SessionState) (bool, error)
SessionFromCookie(string, *cookie.Cipher) (*SessionState, error)
CookieForSession(*SessionState, *cookie.Cipher) (string, error)
} }
func New(provider string, p *ProviderData) Provider { func New(provider string, p *ProviderData) Provider {

115
providers/session_state.go Normal file
View File

@ -0,0 +1,115 @@
package providers
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/bitly/oauth2_proxy/cookie"
)
type SessionState struct {
AccessToken string
ExpiresOn time.Time
RefreshToken string
Email string
User string
}
func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
}
return false
}
func (s *SessionState) String() string {
o := fmt.Sprintf("Session{%s", s.userOrEmail())
if s.AccessToken != "" {
o += " token:true"
}
if !s.ExpiresOn.IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
}
if s.RefreshToken != "" {
o += " refresh_token:true"
}
return o + "}"
}
func (s *SessionState) EncodeSessionState(c *cookie.Cipher) (string, error) {
if c == nil || s.AccessToken == "" {
return s.userOrEmail(), nil
}
return s.EncryptedString(c)
}
func (s *SessionState) userOrEmail() string {
u := s.User
if s.Email != "" {
u = s.Email
}
return u
}
func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) {
var err error
if c == nil {
panic("error. missing cipher")
}
a := s.AccessToken
if a != "" {
a, err = c.Encrypt(a)
if err != nil {
return "", err
}
}
r := s.RefreshToken
if r != "" {
r, err = c.Encrypt(r)
if err != nil {
return "", err
}
}
return fmt.Sprintf("%s|%s|%d|%s", s.userOrEmail(), a, s.ExpiresOn.Unix(), r), nil
}
func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) {
chunks := strings.Split(v, "|")
if len(chunks) == 1 {
if strings.Contains(chunks[0], "@") {
u := strings.Split(v, "@")[0]
return &SessionState{Email: v, User: u}, nil
}
return &SessionState{User: v}, nil
}
if len(chunks) != 4 {
err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks))
return
}
s = &SessionState{}
if c != nil && chunks[1] != "" {
s.AccessToken, err = c.Decrypt(chunks[1])
if err != nil {
return nil, err
}
}
if c != nil && chunks[3] != "" {
s.RefreshToken, err = c.Decrypt(chunks[3])
if err != nil {
return nil, err
}
}
if u := chunks[0]; strings.Contains(u, "@") {
s.Email = u
s.User = strings.Split(u, "@")[0]
} else {
s.User = u
}
ts, _ := strconv.Atoi(chunks[2])
s.ExpiresOn = time.Unix(int64(ts), 0)
return
}

View File

@ -0,0 +1,88 @@
package providers
import (
"strings"
"testing"
"time"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/bmizerany/assert"
)
const secret = "0123456789abcdefghijklmnopqrstuv"
const altSecret = "0000000000abcdefghijklmnopqrstuv"
func TestSessionStateSerialization(t *testing.T) {
c, err := cookie.NewCipher(secret)
assert.Equal(t, nil, err)
c2, err := cookie.NewCipher(altSecret)
assert.Equal(t, nil, err)
s := &SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
assert.Equal(t, nil, err)
assert.Equal(t, 3, strings.Count(encoded, "|"))
ss, err := DecodeSessionState(encoded, c)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)
// ensure a different cipher can't decode properly (ie: it gets gibberish)
ss, err = DecodeSessionState(encoded, c2)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{
Email: "user@domain.com",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
assert.Equal(t, nil, err)
assert.Equal(t, s.Email, encoded)
// only email should have been serialized
ss, err := DecodeSessionState(encoded, nil)
assert.Equal(t, nil, err)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, "", ss.AccessToken)
assert.Equal(t, "", ss.RefreshToken)
}
func TestSessionStateUserOrEmail(t *testing.T) {
s := &SessionState{
Email: "user@domain.com",
User: "just-user",
}
assert.Equal(t, "user@domain.com", s.userOrEmail())
s.Email = ""
assert.Equal(t, "just-user", s.userOrEmail())
}
func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
assert.Equal(t, true, s.IsExpired())
s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
assert.Equal(t, false, s.IsExpired())
s = &SessionState{}
assert.Equal(t, false, s.IsExpired())
}