oauth2_proxy/oauthproxy.go

651 lines
17 KiB
Go

package main
import (
b64 "encoding/base64"
"errors"
"fmt"
"html/template"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"strings"
"time"
"github.com/18F/hmacauth"
"github.com/bitly/oauth2_proxy/cookie"
"github.com/bitly/oauth2_proxy/providers"
)
// SignatureHeader is the header name handed to hmacauth.NewHmacAuth as the
// signature header when request signing is configured.
const SignatureHeader = "GAP-Signature"

// SignatureHeaders is the ordered list of request header names handed to
// hmacauth.NewHmacAuth together with SignatureHeader.
var SignatureHeaders = []string{
	"Content-Length",
	"Content-Md5",
	"Content-Type",
	"Date",
	"Authorization",
	"X-Forwarded-User",
	"X-Forwarded-Email",
	"X-Forwarded-Access-Token",
	"Cookie",
	"Gap-Auth",
}
// OAuthProxy is the http.Handler that authenticates requests (session cookie,
// OAuth provider, or htpasswd credentials) and forwards authenticated ones to
// the upstream handlers registered on serveMux.
type OAuthProxy struct {
	// CookieSeed is the secret passed to cookie.SignedValue and
	// cookie.Validate when signing and checking the session cookie.
	CookieSeed string
	// CookieName is the name of the session cookie.
	CookieName string
	// CookieDomain, when non-empty, replaces the request host as the cookie
	// domain.
	CookieDomain string
	// CookieSecure is copied to the cookie's Secure flag; it also selects
	// "https" when GetRedirectURI has to pick a scheme.
	CookieSecure bool
	// CookieHttpOnly is copied to the cookie's HttpOnly flag.
	CookieHttpOnly bool
	// CookieExpire is the cookie lifetime; it is also passed to
	// cookie.Validate.
	CookieExpire time.Duration
	// CookieRefresh is the session age after which Authenticate re-saves the
	// session; zero disables refreshing.
	CookieRefresh time.Duration
	// Validator is called with the session email to decide whether the
	// account is allowed.
	Validator func(string) bool
	// The *Path fields below are compared with req.URL.Path in ServeHTTP to
	// route requests to the proxy's own endpoints.
	RobotsPath string
	PingPath string
	SignInPath string
	SignOutPath string
	OAuthStartPath string
	OAuthCallbackPath string
	AuthOnlyPath string
	redirectURL *url.URL // the URL the provider redirects back to (OAuth callback)
	provider providers.Provider
	// ProxyPrefix is the path prefix of the proxy's own endpoints; it is also
	// passed to the templates.
	ProxyPrefix string
	// SignInMessage is passed to the sign_in.html template.
	SignInMessage string
	// HtpasswdFile, when non-nil, enables the manual sign-in form handling and
	// HTTP basic auth checks.
	HtpasswdFile *HtpasswdFile
	// DisplayHtpasswdForm controls whether the sign-in page shows the custom
	// login form (only when HtpasswdFile is also set).
	DisplayHtpasswdForm bool
	// serveMux serves whitelisted and authenticated requests.
	serveMux http.Handler
	// PassBasicAuth makes Authenticate set basic auth and X-Forwarded-User /
	// X-Forwarded-Email on the request.
	PassBasicAuth bool
	// SkipProviderButton makes Proxy start the OAuth flow directly instead of
	// rendering the sign-in page.
	SkipProviderButton bool
	// BasicAuthPassword is the password used with req.SetBasicAuth when
	// PassBasicAuth is set.
	BasicAuthPassword string
	// PassAccessToken makes Authenticate set X-Forwarded-Access-Token when the
	// session has an access token.
	PassAccessToken bool
	// CookieCipher is passed to the provider when encoding/decoding sessions.
	// NewOAuthProxy leaves it nil unless PassAccessToken or CookieRefresh is set.
	CookieCipher *cookie.Cipher
	// skipAuthRegex holds the skip-auth patterns as configured; compiledRegex
	// holds the compiled forms consulted by IsWhitelistedPath.
	skipAuthRegex []string
	compiledRegex []*regexp.Regexp
	// templates must provide "error.html" and "sign_in.html".
	templates *template.Template
	// Footer is passed to sign_in.html as template.HTML, i.e. not escaped.
	Footer string
}
// UpstreamProxy wraps a single upstream handler. Its ServeHTTP method tags the
// response with the upstream address and, when auth is set, signs the request
// before delegating to handler.
//
// NOTE: NewOAuthProxy builds this struct with an unkeyed literal, so the field
// order is significant.
type UpstreamProxy struct {
	// upstream is the value reported in the GAP-Upstream-Address response header.
	upstream string
	// handler serves the request after the headers have been prepared.
	handler http.Handler
	// auth signs outgoing requests; nil disables signing.
	auth hmacauth.HmacAuth
}
// ServeHTTP records the upstream address in the GAP-Upstream-Address response
// header, then hands the request to the wrapped handler. When a signer is
// configured, the GAP-Auth value currently set on the response header is copied
// onto the request and the request is signed first.
func (u *UpstreamProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	respHeader := w.Header()
	respHeader.Set("GAP-Upstream-Address", u.upstream)

	if signer := u.auth; signer != nil {
		authValue := respHeader.Get("GAP-Auth")
		r.Header.Set("GAP-Auth", authValue)
		signer.SignRequest(r)
	}

	u.handler.ServeHTTP(w, r)
}
// NewReverseProxy builds a single-host reverse proxy for target using
// httputil.NewSingleHostReverseProxy.
func NewReverseProxy(target *url.URL) (proxy *httputil.ReverseProxy) {
	proxy = httputil.NewSingleHostReverseProxy(target)
	return proxy
}
// setProxyUpstreamHostHeader wraps proxy.Director so that, after the existing
// director has run, the outgoing request's Host is replaced with target.Host
// and the URL is sent opaquely as the original RequestURI.
func setProxyUpstreamHostHeader(proxy *httputil.ReverseProxy, target *url.URL) {
	inner := proxy.Director
	proxy.Director = func(outgoing *http.Request) {
		inner(outgoing)
		outgoing.Host = target.Host
		// Forward the raw RequestURI so encoded slashes in the request path
		// are not unescaped; the query is already part of RequestURI.
		outgoing.URL.Opaque, outgoing.URL.RawQuery = outgoing.RequestURI, ""
	}
}
// setProxyDirector wraps proxy.Director so that, after the existing director
// has run, the outgoing URL is sent opaquely as the original RequestURI. The
// request's Host is left untouched.
func setProxyDirector(proxy *httputil.ReverseProxy) {
	inner := proxy.Director
	proxy.Director = func(outgoing *http.Request) {
		inner(outgoing)
		// Forward the raw RequestURI so encoded slashes in the request path
		// are not unescaped; the query is already part of RequestURI.
		outgoing.URL.Opaque, outgoing.URL.RawQuery = outgoing.RequestURI, ""
	}
}
// NewFileServer returns a handler that serves files from filesystemPath for
// requests whose URL path starts with path; the prefix is stripped before the
// file lookup.
func NewFileServer(path string, filesystemPath string) (proxy http.Handler) {
	root := http.Dir(filesystemPath)
	proxy = http.StripPrefix(path, http.FileServer(root))
	return proxy
}
// NewOAuthProxy builds an OAuthProxy from opts, using validator to authorize
// session emails.
//
// It registers one handler per entry of opts.proxyURLs on a fresh ServeMux
// (reverse proxy for http/https, file server for file) and panics on any other
// scheme. It logs the resulting configuration, and exits via log.Fatal if a
// cookie cipher is required but cannot be created from opts.CookieSecret.
//
// Note that the URLs held by opts are modified in place: each http/https
// upstream has its Path cleared and opts.redirectURL has its Path set to the
// callback path.
//
// SignInMessage, HtpasswdFile and DisplayHtpasswdForm are not set here.
func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
	serveMux := http.NewServeMux()
	// auth stays nil (no request signing) unless signature data is configured.
	var auth hmacauth.HmacAuth
	if sigData := opts.signatureData; sigData != nil {
		auth = hmacauth.NewHmacAuth(sigData.hash, []byte(sigData.key),
			SignatureHeader, SignatureHeaders)
	}
	for _, u := range opts.proxyURLs {
		// Remember the mount path before the URL is altered below.
		path := u.Path
		switch u.Scheme {
		case "http", "https":
			// The upstream URL itself carries no path; the mount path is only
			// used as the ServeMux pattern.
			u.Path = ""
			log.Printf("mapping path %q => upstream %q", path, u)
			proxy := NewReverseProxy(u)
			if !opts.PassHostHeader {
				// Rewrite the Host header to the upstream's host.
				setProxyUpstreamHostHeader(proxy, u)
			} else {
				// Keep the incoming Host header.
				setProxyDirector(proxy)
			}
			serveMux.Handle(path,
				&UpstreamProxy{u.Host, proxy, auth})
		case "file":
			// For file upstreams the URL fragment, when present, overrides the
			// mount path; u.Path is the directory on disk.
			if u.Fragment != "" {
				path = u.Fragment
			}
			log.Printf("mapping path %q => file system %q", path, u.Path)
			proxy := NewFileServer(path, u.Path)
			// File upstreams are never signed.
			serveMux.Handle(path, &UpstreamProxy{path, proxy, nil})
		default:
			panic(fmt.Sprintf("unknown upstream protocol %s", u.Scheme))
		}
	}
	for _, u := range opts.CompiledRegex {
		log.Printf("compiled skip-auth-regex => %q", u)
	}
	// redirectURL aliases opts.redirectURL, so the Path assignment below is
	// visible through opts as well.
	redirectURL := opts.redirectURL
	redirectURL.Path = fmt.Sprintf("%s/callback", opts.ProxyPrefix)
	log.Printf("OAuthProxy configured for %s Client ID: %s", opts.provider.Data().ProviderName, opts.ClientID)
	// domain and refresh are only used for the log line below.
	domain := opts.CookieDomain
	if domain == "" {
		domain = "<default>"
	}
	refresh := "disabled"
	if opts.CookieRefresh != time.Duration(0) {
		refresh = fmt.Sprintf("after %s", opts.CookieRefresh)
	}
	log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, domain, refresh)
	// A cipher is only created when access tokens are passed upstream or
	// cookie refresh is enabled; otherwise CookieCipher stays nil.
	var cipher *cookie.Cipher
	if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) {
		var err error
		cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret))
		if err != nil {
			log.Fatal("cookie-secret error: ", err)
		}
	}
	return &OAuthProxy{
		CookieName: opts.CookieName,
		CookieSeed: opts.CookieSecret,
		CookieDomain: opts.CookieDomain,
		CookieSecure: opts.CookieSecure,
		CookieHttpOnly: opts.CookieHttpOnly,
		CookieExpire: opts.CookieExpire,
		CookieRefresh: opts.CookieRefresh,
		Validator: validator,
		RobotsPath: "/robots.txt",
		PingPath: "/ping",
		SignInPath: fmt.Sprintf("%s/sign_in", opts.ProxyPrefix),
		SignOutPath: fmt.Sprintf("%s/sign_out", opts.ProxyPrefix),
		OAuthStartPath: fmt.Sprintf("%s/start", opts.ProxyPrefix),
		OAuthCallbackPath: fmt.Sprintf("%s/callback", opts.ProxyPrefix),
		AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix),
		ProxyPrefix: opts.ProxyPrefix,
		provider: opts.provider,
		serveMux: serveMux,
		redirectURL: redirectURL,
		skipAuthRegex: opts.SkipAuthRegex,
		compiledRegex: opts.CompiledRegex,
		PassBasicAuth: opts.PassBasicAuth,
		BasicAuthPassword: opts.BasicAuthPassword,
		PassAccessToken: opts.PassAccessToken,
		SkipProviderButton: opts.SkipProviderButton,
		CookieCipher: cipher,
		templates: loadTemplates(opts.CustomTemplatesDir),
		Footer: opts.Footer,
	}
}
// GetRedirectURI returns the OAuth callback URL as a string. When the
// configured redirect URL already has a host it is returned unchanged;
// otherwise a copy is completed with the given request host and, if no scheme
// is configured, with "https" or "http" depending on CookieSecure.
func (p *OAuthProxy) GetRedirectURI(host string) string {
	configured := p.redirectURL
	if configured.Host != "" {
		return configured.String()
	}

	// Work on a copy so the configured URL is never modified.
	derived := new(url.URL)
	*derived = *configured
	derived.Host = host
	if derived.Scheme == "" {
		derived.Scheme = "http"
		if p.CookieSecure {
			derived.Scheme = "https"
		}
	}
	return derived.String()
}
// displayCustomLoginForm reports whether the sign-in page should show the
// username/password form: an htpasswd file must be configured and
// DisplayHtpasswdForm must be enabled.
func (p *OAuthProxy) displayCustomLoginForm() bool {
	if p.HtpasswdFile == nil {
		return false
	}
	return p.DisplayHtpasswdForm
}
// redeemCode exchanges an OAuth authorization code for a session via the
// provider. host is used to build the redirect URI sent with the exchange.
// An empty code is rejected. If the provider returns a session without an
// email, the email is looked up with a second provider call.
func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, err error) {
	if code == "" {
		return nil, errors.New("missing code")
	}

	s, err = p.provider.Redeem(p.GetRedirectURI(host), code)
	if err != nil {
		return s, err
	}
	if s.Email != "" {
		return s, nil
	}

	s.Email, err = p.provider.GetEmailAddress(s)
	return s, err
}
// MakeCookie builds the session cookie for req. The domain is the request host
// without its port unless CookieDomain is configured, in which case that wins
// (with a logged warning when the request host does not end with it). A
// non-empty value is signed with the cookie seed and now; an empty value is
// left empty. The cookie expires at now plus expiration.
func (p *OAuthProxy) MakeCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie {
	cookieDomain := req.Host
	if hostOnly, _, splitErr := net.SplitHostPort(cookieDomain); splitErr == nil {
		cookieDomain = hostOnly
	}

	if configured := p.CookieDomain; configured != "" {
		if !strings.HasSuffix(cookieDomain, configured) {
			log.Printf("Warning: request host is %q but using configured cookie domain of %q", cookieDomain, configured)
		}
		cookieDomain = configured
	}

	signed := value
	if signed != "" {
		signed = cookie.SignedValue(p.CookieSeed, p.CookieName, signed, now)
	}

	result := &http.Cookie{
		Name:     p.CookieName,
		Value:    signed,
		Path:     "/",
		Domain:   cookieDomain,
		HttpOnly: p.CookieHttpOnly,
		Secure:   p.CookieSecure,
		Expires:  now.Add(expiration),
	}
	return result
}
// ClearCookie overwrites the session cookie with an empty value that expired
// one hour ago.
func (p *OAuthProxy) ClearCookie(rw http.ResponseWriter, req *http.Request) {
	expired := p.MakeCookie(req, "", -time.Hour, time.Now())
	http.SetCookie(rw, expired)
}
// SetCookie writes the session cookie carrying val, valid for CookieExpire
// from now.
func (p *OAuthProxy) SetCookie(rw http.ResponseWriter, req *http.Request, val string) {
	fresh := p.MakeCookie(req, val, p.CookieExpire, time.Now())
	http.SetCookie(rw, fresh)
}
// LoadCookiedSession reads the session cookie from req, validates it and
// decodes it into a session via the provider. It returns the session together
// with its age, measured from the cookie timestamp to the current time
// truncated to the second. On any failure the session is nil and the age zero.
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) {
	sessionCookie, err := req.Cookie(p.CookieName)
	if err != nil {
		// req.Cookie only ever fails with http.ErrNoCookie.
		return nil, 0, fmt.Errorf("Cookie %q not present", p.CookieName)
	}

	val, timestamp, ok := cookie.Validate(sessionCookie, p.CookieSeed, p.CookieExpire)
	if !ok {
		return nil, 0, errors.New("Cookie Signature not valid")
	}

	session, err := p.provider.SessionFromCookie(val, p.CookieCipher)
	if err != nil {
		return nil, 0, err
	}

	return session, time.Now().Truncate(time.Second).Sub(timestamp), nil
}
// SaveSession encodes s into a cookie value via the provider and writes it as
// the session cookie. Nothing is written when encoding fails; the encoding
// error is returned.
func (p *OAuthProxy) SaveSession(rw http.ResponseWriter, req *http.Request, s *providers.SessionState) error {
	encoded, err := p.provider.CookieForSession(s, p.CookieCipher)
	if err == nil {
		p.SetCookie(rw, req, encoded)
	}
	return err
}
// RobotsTxt responds with 200 and a robots.txt body that disallows all paths
// for every user agent.
func (p *OAuthProxy) RobotsTxt(rw http.ResponseWriter) {
	const body = "User-agent: *\nDisallow: /"
	rw.WriteHeader(http.StatusOK)
	fmt.Fprint(rw, body)
}
// PingPage responds with 200 and the body "OK".
func (p *OAuthProxy) PingPage(rw http.ResponseWriter) {
	const body = "OK"
	rw.WriteHeader(http.StatusOK)
	fmt.Fprint(rw, body)
}
// ErrorPage logs the error, writes the HTTP status code and renders the
// "error.html" template with the title (prefixed by the code), the message and
// the proxy prefix.
//
// A template rendering failure cannot be reported to the client any more (the
// status line is already written), so it is logged instead of being silently
// dropped.
func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, code int, title string, message string) {
	log.Printf("ErrorPage %d %s %s", code, title, message)
	rw.WriteHeader(code)
	t := struct {
		Title       string
		Message     string
		ProxyPrefix string
	}{
		Title:       fmt.Sprintf("%d %s", code, title),
		Message:     message,
		ProxyPrefix: p.ProxyPrefix,
	}
	if err := p.templates.ExecuteTemplate(rw, "error.html", t); err != nil {
		log.Printf("error rendering %q template: %s", "error.html", err)
	}
}
// SignInPage clears the session cookie, writes the given status code and
// renders the "sign_in.html" template.
//
// The template receives the request URI as the post-login redirect target,
// except that a request for the sign-in path itself redirects to "/". The
// footer is passed as template.HTML and is therefore rendered unescaped.
//
// A template rendering failure cannot be reported to the client any more (the
// status line is already written), so it is logged instead of being silently
// dropped.
func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code int) {
	p.ClearCookie(rw, req)
	rw.WriteHeader(code)

	redirectURL := req.URL.RequestURI()
	if redirectURL == p.SignInPath {
		redirectURL = "/"
	}

	t := struct {
		ProviderName  string
		SignInMessage string
		CustomLogin   bool
		Redirect      string
		Version       string
		ProxyPrefix   string
		Footer        template.HTML
	}{
		ProviderName:  p.provider.Data().ProviderName,
		SignInMessage: p.SignInMessage,
		CustomLogin:   p.displayCustomLoginForm(),
		Redirect:      redirectURL,
		Version:       VERSION,
		ProxyPrefix:   p.ProxyPrefix,
		Footer:        template.HTML(p.Footer),
	}
	if err := p.templates.ExecuteTemplate(rw, "sign_in.html", t); err != nil {
		log.Printf("error rendering %q template: %s", "sign_in.html", err)
	}
}
// ManualSignIn handles a username/password form submission. It returns the
// user name and true only for a POST request, with an htpasswd file
// configured, whose "username" and "password" form values validate against
// that file. Every other case yields ("", false).
func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (string, bool) {
	if p.HtpasswdFile == nil || req.Method != "POST" {
		return "", false
	}

	user, passwd := req.FormValue("username"), req.FormValue("password")
	if user == "" || !p.HtpasswdFile.Validate(user, passwd) {
		return "", false
	}

	log.Printf("authenticated %q via HtpasswdFile", user)
	return user, true
}
// GetRedirect returns the post-login redirect target taken from the "rd" form
// value of req, or an error if the form cannot be parsed.
//
// Only local paths are accepted: the value must start with a single "/".
// Anything else (empty, absolute URLs, scheme-relative "//host" values, or the
// "/\host" variant some browsers treat the same way) falls back to "/". The
// value is client-controlled and is used directly as a redirect location by
// the sign-in handler, so accepting arbitrary values would be an open redirect.
// This mirrors the check the OAuth callback applies to its "state" value.
func (p *OAuthProxy) GetRedirect(req *http.Request) (string, error) {
	if err := req.ParseForm(); err != nil {
		return "", err
	}

	redirect := req.FormValue("rd")
	isLocal := strings.HasPrefix(redirect, "/") &&
		!strings.HasPrefix(redirect, "//") &&
		!strings.HasPrefix(redirect, "/\\")
	if !isLocal {
		redirect = "/"
	}
	return redirect, nil
}
// IsWhitelistedPath reports whether path matches at least one of the compiled
// skip-auth regular expressions. With no expressions configured it is false.
func (p *OAuthProxy) IsWhitelistedPath(path string) (ok bool) {
	for _, pattern := range p.compiledRegex {
		if pattern.MatchString(path) {
			return true
		}
	}
	return false
}
// getRemoteAddr returns the request's remote address for logging. When an
// X-Real-IP header is present, its quoted value is appended in parentheses.
func getRemoteAddr(req *http.Request) (s string) {
	s = req.RemoteAddr
	if realIP := req.Header.Get("X-Real-IP"); realIP != "" {
		s = fmt.Sprintf("%s (%q)", s, realIP)
	}
	return s
}
// ServeHTTP routes a request by its URL path. Robots and ping are answered
// directly, whitelisted paths go straight to the upstream mux without
// authentication, the proxy's own endpoints are dispatched to their handlers,
// and everything else is proxied after authentication.
//
// The whitelist check deliberately comes before the proxy's own endpoints, so
// a skip-auth pattern that matches one of them takes precedence.
func (p *OAuthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	path := req.URL.Path

	if path == p.RobotsPath {
		p.RobotsTxt(rw)
		return
	}
	if path == p.PingPath {
		p.PingPage(rw)
		return
	}
	if p.IsWhitelistedPath(path) {
		p.serveMux.ServeHTTP(rw, req)
		return
	}

	switch path {
	case p.SignInPath:
		p.SignIn(rw, req)
	case p.SignOutPath:
		p.SignOut(rw, req)
	case p.OAuthStartPath:
		p.OAuthStart(rw, req)
	case p.OAuthCallbackPath:
		p.OAuthCallback(rw, req)
	case p.AuthOnlyPath:
		p.AuthenticateOnly(rw, req)
	default:
		p.Proxy(rw, req)
	}
}
// SignIn serves the sign-in endpoint. A successful manual (htpasswd) login
// stores a session for the user and redirects to the requested target; in
// every other case the sign-in page is rendered with status 200.
//
// A failure to store the session is logged and answered with a 500 error page
// rather than redirecting the user as if the login had succeeded.
func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
	redirect, err := p.GetRedirect(req)
	if err != nil {
		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", err.Error())
		return
	}

	user, ok := p.ManualSignIn(rw, req)
	if !ok {
		p.SignInPage(rw, req, http.StatusOK)
		return
	}

	session := &providers.SessionState{User: user}
	if err := p.SaveSession(rw, req, session); err != nil {
		log.Printf("%s %s", getRemoteAddr(req), err)
		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error")
		return
	}
	http.Redirect(rw, req, redirect, http.StatusFound)
}
// SignOut clears the session cookie and redirects the client to "/" with a
// 302 response.
func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
	const landing = "/"
	p.ClearCookie(rw, req)
	http.Redirect(rw, req, landing, http.StatusFound)
}
// OAuthStart begins the OAuth flow: it redirects the client (302) to the
// provider's login URL, passing the callback URI for the request host and the
// post-login redirect target. A form parsing failure yields a 500 error page.
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
	redirect, err := p.GetRedirect(req)
	if err != nil {
		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", err.Error())
		return
	}

	callbackURI := p.GetRedirectURI(req.Host)
	loginURL := p.provider.GetLoginURL(callbackURI, redirect)
	http.Redirect(rw, req, loginURL, http.StatusFound)
}
// OAuthCallback completes the OAuth flow. It redeems the "code" form value for
// a session, checks the session email against the validator and the provider's
// group check, stores the session and redirects to the "state" value.
//
// Responses: 500 when the form cannot be parsed, the code cannot be redeemed or
// the session cannot be saved; 403 when the provider reported an error or the
// account is not authorized. A "state" that is not a local path (must start
// with a single "/") is replaced by "/".
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
	remoteAddr := getRemoteAddr(req)

	if parseErr := req.ParseForm(); parseErr != nil {
		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", parseErr.Error())
		return
	}
	if providerErr := req.Form.Get("error"); providerErr != "" {
		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", providerErr)
		return
	}

	session, err := p.redeemCode(req.Host, req.Form.Get("code"))
	if err != nil {
		log.Printf("%s error redeeming code %s", remoteAddr, err)
		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error")
		return
	}

	target := req.Form.Get("state")
	if localPath := strings.HasPrefix(target, "/") && !strings.HasPrefix(target, "//"); !localPath {
		target = "/"
	}

	// Deny unless both the validator and the provider's group check accept.
	if !p.Validator(session.Email) || !p.provider.ValidateGroup(session.Email) {
		log.Printf("%s Permission Denied: %q is unauthorized", remoteAddr, session.Email)
		p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "Invalid Account")
		return
	}

	log.Printf("%s authentication complete %s", remoteAddr, session)
	if saveErr := p.SaveSession(rw, req, session); saveErr != nil {
		log.Printf("%s %s", remoteAddr, saveErr)
		p.ErrorPage(rw, http.StatusInternalServerError, "Internal Error", "Internal Error")
		return
	}
	http.Redirect(rw, req, target, http.StatusFound)
}
// AuthenticateOnly authenticates the request without proxying it: it answers
// 202 when Authenticate accepts the request and 401 with the body
// "unauthorized request" otherwise.
func (p *OAuthProxy) AuthenticateOnly(rw http.ResponseWriter, req *http.Request) {
	if p.Authenticate(rw, req) != http.StatusAccepted {
		http.Error(rw, "unauthorized request", http.StatusUnauthorized)
		return
	}
	rw.WriteHeader(http.StatusAccepted)
}
// Proxy authenticates the request and acts on the outcome: a 500 renders the
// error page, a 403 either starts the OAuth flow directly (SkipProviderButton)
// or renders the sign-in page with status 403, and any other result forwards
// the request to the upstream mux.
func (p *OAuthProxy) Proxy(rw http.ResponseWriter, req *http.Request) {
	switch p.Authenticate(rw, req) {
	case http.StatusInternalServerError:
		p.ErrorPage(rw, http.StatusInternalServerError,
			"Internal Error", "Internal Error")
	case http.StatusForbidden:
		if p.SkipProviderButton {
			p.OAuthStart(rw, req)
			return
		}
		p.SignInPage(rw, req, http.StatusForbidden)
	default:
		p.serveMux.ServeHTTP(rw, req)
	}
}
// Authenticate establishes whether req belongs to an authenticated user, first
// from the session cookie and, failing that, from HTTP basic auth.
//
// It returns http.StatusAccepted when a session is available,
// http.StatusForbidden when none is, and http.StatusInternalServerError when a
// session that needs re-saving cannot be saved.
//
// Side effects: it may rewrite or clear the session cookie on rw, and on
// success it sets the GAP-Auth response header and, depending on
// PassBasicAuth / PassAccessToken, basic auth and X-Forwarded-* headers on req.
func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int {
	// saveSession: the cookie must be rewritten; clearSession: the cookie must
	// be removed; revalidated: the provider already refreshed the session, so
	// the extra ValidateSessionState call below is skipped.
	var saveSession, clearSession, revalidated bool
	remoteAddr := getRemoteAddr(req)
	// A load failure is only logged; session is nil in that case and the
	// basic auth fallback further down gets its chance.
	session, sessionAge, err := p.LoadCookiedSession(req)
	if err != nil {
		log.Printf("%s %s", remoteAddr, err)
	}
	// Cookie refresh is enabled (non-zero) and the session is older than the
	// refresh interval: schedule a re-save.
	if session != nil && sessionAge > p.CookieRefresh && p.CookieRefresh != time.Duration(0) {
		log.Printf("%s refreshing %s old session cookie for %s (refresh after %s)", remoteAddr, sessionAge, session, p.CookieRefresh)
		saveSession = true
	}
	// NOTE(review): session may be nil here; this relies on the provider's
	// RefreshSessionIfNeeded tolerating a nil session — confirm.
	if ok, err := p.provider.RefreshSessionIfNeeded(session); err != nil {
		log.Printf("%s removing session. error refreshing access token %s %s", remoteAddr, err, session)
		clearSession = true
		session = nil
	} else if ok {
		saveSession = true
		revalidated = true
	}
	if session != nil && session.IsExpired() {
		log.Printf("%s removing session. token expired %s", remoteAddr, session)
		session = nil
		saveSession = false
		clearSession = true
	}
	// Before re-saving a session that carries an access token and was not just
	// refreshed by the provider, ask the provider whether it is still valid.
	if saveSession && !revalidated && session != nil && session.AccessToken != "" {
		if !p.provider.ValidateSessionState(session) {
			log.Printf("%s removing session. error validating %s", remoteAddr, session)
			saveSession = false
			session = nil
			clearSession = true
		}
	}
	// Re-check the email against the validator on every request; sessions
	// without an email (e.g. htpasswd users) skip this check.
	if session != nil && session.Email != "" && !p.Validator(session.Email) {
		log.Printf("%s Permission Denied: removing session %s", remoteAddr, session)
		session = nil
		saveSession = false
		clearSession = true
	}
	if saveSession && session != nil {
		err := p.SaveSession(rw, req, session)
		if err != nil {
			log.Printf("%s %s", remoteAddr, err)
			return http.StatusInternalServerError
		}
	}
	if clearSession {
		p.ClearCookie(rw, req)
	}
	// No usable cookie session: fall back to HTTP basic auth. A failed check
	// is logged and leaves session nil.
	if session == nil {
		session, err = p.CheckBasicAuth(req)
		if err != nil {
			log.Printf("%s %s", remoteAddr, err)
		}
	}
	if session == nil {
		return http.StatusForbidden
	}
	// At this point, the user is authenticated. proxy normally
	if p.PassBasicAuth {
		req.SetBasicAuth(session.User, p.BasicAuthPassword)
		req.Header["X-Forwarded-User"] = []string{session.User}
		if session.Email != "" {
			req.Header["X-Forwarded-Email"] = []string{session.Email}
		}
	}
	if p.PassAccessToken && session.AccessToken != "" {
		req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken}
	}
	// GAP-Auth is set on the response header; UpstreamProxy.ServeHTTP copies it
	// onto the request when signing is enabled. The email is preferred, with
	// the user name as fallback.
	if session.Email == "" {
		rw.Header().Set("GAP-Auth", session.User)
	} else {
		rw.Header().Set("GAP-Auth", session.Email)
	}
	return http.StatusAccepted
}
// CheckBasicAuth authenticates req from its HTTP basic auth credentials
// against the htpasswd file.
//
// It returns (nil, nil) when no htpasswd file is configured or the request has
// no Authorization header, a session holding the user name when the
// credentials validate, and an error when the header is malformed or the
// credentials are rejected.
//
// The returned errors are logged by the caller, so they deliberately never
// contain the Authorization header value or the decoded credentials; only the
// user name of a rejected login is reported.
func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState, error) {
	if p.HtpasswdFile == nil {
		return nil, nil
	}
	auth := req.Header.Get("Authorization")
	if auth == "" {
		return nil, nil
	}

	s := strings.SplitN(auth, " ", 2)
	if len(s) != 2 || s[0] != "Basic" {
		return nil, errors.New("invalid Authorization header")
	}
	b, err := b64.StdEncoding.DecodeString(s[1])
	if err != nil {
		return nil, err
	}
	// Credentials are "user:password"; only the first colon separates them.
	pair := strings.SplitN(string(b), ":", 2)
	if len(pair) != 2 {
		return nil, errors.New("invalid basic auth credentials format")
	}

	if p.HtpasswdFile.Validate(pair[0], pair[1]) {
		log.Printf("authenticated %q via basic auth", pair[0])
		return &providers.SessionState{User: pair[0]}, nil
	}
	return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0])
}