Self code review changes

This commit is contained in:
MisterWil 2019-02-10 09:01:13 -08:00
parent b46e34be72
commit 2e5c877dd1
2 changed files with 15 additions and 15 deletions

View File

@ -19,7 +19,7 @@ func main() {
flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError) flagSet := flag.NewFlagSet("oauth2_proxy", flag.ExitOnError)
emailDomains := StringArray{} emailDomains := StringArray{}
whitelistandardomains := StringArray{} whitelistDomains := StringArray{}
upstreams := StringArray{} upstreams := StringArray{}
skipAuthRegex := StringArray{} skipAuthRegex := StringArray{}
googleGroups := StringArray{} googleGroups := StringArray{}
@ -48,7 +48,7 @@ func main() {
flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses") flagSet.Duration("flush-interval", time.Duration(1)*time.Second, "period between response flushing when streaming responses")
flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email")
flagSet.Var(&whitelistandardomains, "whitelist-domain", "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)") flagSet.Var(&whitelistDomains, "whitelist-domain", "allowed domains for redirection after authentication. Prefix domain with a . to allow subdomains (eg .example.com)")
flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.")
flagSet.String("github-org", "", "restrict logins to members of this organisation") flagSet.String("github-org", "", "restrict logins to members of this organisation")
flagSet.String("github-team", "", "restrict logins to members of this team") flagSet.String("github-team", "", "restrict logins to members of this team")

View File

@ -553,10 +553,10 @@ func (p *OAuthProxy) ManualSignIn(rw http.ResponseWriter, req *http.Request) (st
} }
// check auth // check auth
if p.HtpasswdFile.Validate(user, passwd) { if p.HtpasswdFile.Validate(user, passwd) {
logger.PrintAuthf(user, req, logger.AuthSuccess, "Successful authentication via HtpasswdFile") logger.PrintAuthf(user, req, logger.AuthSuccess, "Authenticated via HtpasswdFile")
return user, true return user, true
} }
logger.PrintAuthf(user, req, logger.AuthFailure, "Failed authentication via HtpasswdFile; unauthorized") logger.PrintAuthf(user, req, logger.AuthFailure, "Invalid authentication via HtpasswdFile")
return "", false return "", false
} }
@ -704,27 +704,27 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
// finish the oauth cycle // finish the oauth cycle
err := req.ParseForm() err := req.ParseForm()
if err != nil { if err != nil {
logger.Printf("Error while parsing OAuth callback: %s" + err.Error()) logger.Printf("Error while parsing OAuth2 callback: %s" + err.Error())
p.ErrorPage(rw, 500, "Internal Error", err.Error()) p.ErrorPage(rw, 500, "Internal Error", err.Error())
return return
} }
errorString := req.Form.Get("error") errorString := req.Form.Get("error")
if errorString != "" { if errorString != "" {
logger.Printf("Error while parsing OAuth callback: %s ", errorString) logger.Printf("Error while parsing OAuth2 callback: %s ", errorString)
p.ErrorPage(rw, 403, "Permission Denied", errorString) p.ErrorPage(rw, 403, "Permission Denied", errorString)
return return
} }
session, err := p.redeemCode(req.Host, req.Form.Get("code")) session, err := p.redeemCode(req.Host, req.Form.Get("code"))
if err != nil { if err != nil {
logger.Printf("Error while parsing OAuth callback: %s ", errorString) logger.Printf("Error redeeming code during OAuth2 callback: %s ", errorString)
p.ErrorPage(rw, 500, "Internal Error", "Internal Error") p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return return
} }
s := strings.SplitN(req.Form.Get("state"), ":", 2) s := strings.SplitN(req.Form.Get("state"), ":", 2)
if len(s) != 2 { if len(s) != 2 {
logger.Printf("Error while parsing OAuth state; invalid length") logger.Printf("Error while parsing OAuth2 state; invalid length")
p.ErrorPage(rw, 500, "Internal Error", "Invalid State") p.ErrorPage(rw, 500, "Internal Error", "Invalid State")
return return
} }
@ -732,13 +732,13 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
redirect := s[1] redirect := s[1]
c, err := req.Cookie(p.CSRFCookieName) c, err := req.Cookie(p.CSRFCookieName)
if err != nil { if err != nil {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Failed authentication via oauth2; unable too obtain CSRF cookie") logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2; unable too obtain CSRF cookie")
p.ErrorPage(rw, 403, "Permission Denied", err.Error()) p.ErrorPage(rw, 403, "Permission Denied", err.Error())
return return
} }
p.ClearCSRFCookie(rw, req) p.ClearCSRFCookie(rw, req)
if c.Value != nonce { if c.Value != nonce {
logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Failed authentication via oauth2; csrf token mismatch, potential attack") logger.PrintAuthf(session.Email, req, logger.AuthFailure, "Invalid authentication via OAuth2; csrf token mismatch, potential attack")
p.ErrorPage(rw, 403, "Permission Denied", "csrf failed") p.ErrorPage(rw, 403, "Permission Denied", "csrf failed")
return return
} }
@ -749,7 +749,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
// set cookie, or deny // set cookie, or deny
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Successful authentication via oauth2; %s", session) logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2; %s", session)
err := p.SaveSession(rw, req, session) err := p.SaveSession(rw, req, session)
if err != nil { if err != nil {
logger.Printf("%s %s", remoteAddr, err) logger.Printf("%s %s", remoteAddr, err)
@ -758,7 +758,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
} }
http.Redirect(rw, req, redirect, 302) http.Redirect(rw, req, redirect, 302)
} else { } else {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Failed authentication via oauth2; unauthorized") logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Invalid authentication via OAuth2; unauthorized")
p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account")
} }
} }
@ -834,7 +834,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int
} }
if session != nil && session.Email != "" && !p.Validator(session.Email) { if session != nil && session.Email != "" && !p.Validator(session.Email) {
logger.Printf(session.Email, req, logger.AuthFailure, "Failed authentication via session; removing session %s", session) logger.Printf(session.Email, req, logger.AuthFailure, "Invalid authentication via session; removing session %s", session)
session = nil session = nil
saveSession = false saveSession = false
clearSession = true clearSession = true
@ -925,10 +925,10 @@ func (p *OAuthProxy) CheckBasicAuth(req *http.Request) (*providers.SessionState,
return nil, fmt.Errorf("invalid format %s", b) return nil, fmt.Errorf("invalid format %s", b)
} }
if p.HtpasswdFile.Validate(pair[0], pair[1]) { if p.HtpasswdFile.Validate(pair[0], pair[1]) {
logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Successful authentication via basic auth") logger.PrintAuthf(pair[0], req, logger.AuthSuccess, "Authenticated via basic auth and HTpasswd File")
return &providers.SessionState{User: pair[0]}, nil return &providers.SessionState{User: pair[0]}, nil
} }
logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Failed authentication via basic auth; not in Htpasswd file") logger.PrintAuthf(pair[0], req, logger.AuthFailure, "Invalid authentication via basic auth; not in Htpasswd File")
return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0]) return nil, fmt.Errorf("%s not in HtpasswdFile", pair[0])
} }