diff --git a/Gopkg.toml b/Gopkg.toml index e064786..253f154 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -3,10 +3,6 @@ # for detailed Gopkg.toml documentation. # -[[constraint]] - name = "github.com/18F/hmacauth" - version = "~1.0.1" - [[constraint]] name = "github.com/BurntSushi/toml" version = "~0.3.0" diff --git a/api/api.go b/api/api.go index e8378ff..8b02934 100644 --- a/api/api.go +++ b/api/api.go @@ -32,7 +32,7 @@ func Request(req *http.Request) (*simplejson.Json, error) { return data, nil } -func RequestJson(req *http.Request, v interface{}) error { +func RequestJSON(req *http.Request, v interface{}) error { resp, err := http.DefaultClient.Do(req) if err != nil { log.Printf("%s %s %s", req.Method, req.URL, err) diff --git a/api/api_test.go b/api/api_test.go index 4f9ae2a..7bdf1b7 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -1,20 +1,21 @@ package api import ( - "github.com/bitly/go-simplejson" "io/ioutil" "net/http" "net/http/httptest" "strings" "testing" + "github.com/bitly/go-simplejson" + "github.com/stretchr/testify/assert" ) -func testBackend(response_code int, payload string) *httptest.Server { +func testBackend(responseCode int, payload string) *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(response_code) + w.WriteHeader(responseCode) w.Write([]byte(payload)) })) } diff --git a/cookie/cookies_test.go b/cookie/cookies_test.go index 74e78fb..500550e 100644 --- a/cookie/cookies_test.go +++ b/cookie/cookies_test.go @@ -24,10 +24,11 @@ func TestEncodeAndDecodeAccessToken(t *testing.T) { } func TestEncodeAndDecodeAccessTokenB64(t *testing.T) { - const secret_b64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" + const secretBase64 = "A3Xbr6fu6Al0HkgrP1ztjb-mYiwmxgNPP-XbNsz1WBk=" const token = "my access token" - secret, err := base64.URLEncoding.DecodeString(secret_b64) + secret, err := base64.URLEncoding.DecodeString(secretBase64) + assert.Equal(t, nil, err) c, err := NewCipher([]byte(secret)) assert.Equal(t, nil, err) diff --git a/cookie/nonce.go b/cookie/nonce.go index 3012ce2..6def148 100644 --- a/cookie/nonce.go +++ b/cookie/nonce.go @@ -5,6 +5,7 @@ import ( "fmt" ) +// Nonce generates a random 16 byte string to be used as a nonce func Nonce() (nonce string, err error) { b := make([]byte, 16) _, err = rand.Read(b) diff --git a/htpasswd.go b/htpasswd.go index c68e558..9a0c504 100644 --- a/htpasswd.go +++ b/htpasswd.go @@ -28,12 +28,12 @@ func NewHtpasswdFromFile(path string) (*HtpasswdFile, error) { } func NewHtpasswd(file io.Reader) (*HtpasswdFile, error) { - csv_reader := csv.NewReader(file) - csv_reader.Comma = ':' - csv_reader.Comment = '#' - csv_reader.TrimLeadingSpace = true + csvReader := csv.NewReader(file) + csvReader.Comma = ':' + csvReader.Comment = '#' + csvReader.TrimLeadingSpace = true - records, err := csv_reader.ReadAll() + records, err := csvReader.ReadAll() if err != nil { return nil, err } diff --git a/htpasswd_test.go b/htpasswd_test.go index ebfd503..7a043e4 100644 --- a/htpasswd_test.go +++ b/htpasswd_test.go @@ -20,6 +20,7 @@ func TestSHA(t *testing.T) { func TestBcrypt(t *testing.T) { hash1, err := bcrypt.GenerateFromPassword([]byte("password"), 1) + assert.Equal(t, err, nil) hash2, err := bcrypt.GenerateFromPassword([]byte("top-secret"), 2) assert.Equal(t, err, nil) diff --git a/http.go b/http.go index aa764c8..6b0012c 100644 --- a/http.go +++ b/http.go @@ -23,12 +23,12 @@ func (s *Server) ListenAndServe() { } func (s *Server) ServeHTTP() { - httpAddress := s.Opts.HttpAddress + HTTPAddress := s.Opts.HTTPAddress scheme := "" - i := strings.Index(httpAddress, "://") + i := strings.Index(HTTPAddress, "://") if i > -1 { - scheme = httpAddress[0:i] + scheme = HTTPAddress[0:i] } var networkType string @@ -39,7 +39,7 @@ func (s *Server) ServeHTTP() { networkType = scheme } - slice := strings.SplitN(httpAddress, "//", 2) + slice := strings.SplitN(HTTPAddress, "//", 2) listenAddr := slice[len(slice)-1] listener, err := net.Listen(networkType, listenAddr) @@ -58,7 +58,7 @@ func (s *Server) ServeHTTP() { } func (s *Server) ServeHTTPS() { - addr := s.Opts.HttpsAddress + addr := s.Opts.HTTPSAddress config := &tls.Config{ MinVersion: tls.VersionTLS12, MaxVersion: tls.VersionTLS12, diff --git a/oauthproxy.go b/oauthproxy.go index fb94c1e..9972d83 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -14,14 +14,19 @@ import ( "strings" "time" + "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/cookie" "github.com/pusher/oauth2_proxy/providers" - "github.com/mbland/hmacauth" ) -const SignatureHeader = "GAP-Signature" +const ( + SignatureHeader = "GAP-Signature" -var SignatureHeaders []string = []string{ + httpScheme = "http" + httpsScheme = "https" +) + +var SignatureHeaders = []string{ "Content-Length", "Content-Md5", "Content-Type", @@ -40,7 +45,7 @@ type OAuthProxy struct { CSRFCookieName string CookieDomain string CookieSecure bool - CookieHttpOnly bool + CookieHTTPOnly bool CookieExpire time.Duration CookieRefresh time.Duration Validator func(string) bool @@ -125,7 +130,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { for _, u := range opts.proxyURLs { path := u.Path switch u.Scheme { - case "http", "https": + case httpScheme, httpsScheme: u.Path = "" log.Printf("mapping path %q => upstream %q", path, u) proxy := NewReverseProxy(u) @@ -160,7 +165,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { refresh = fmt.Sprintf("after %s", opts.CookieRefresh) } - log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) + log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHTTPOnly, opts.CookieExpire, opts.CookieDomain, refresh) var cipher *cookie.Cipher if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { @@ -177,7 +182,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { CookieSeed: opts.CookieSecret, CookieDomain: opts.CookieDomain, CookieSecure: opts.CookieSecure, - CookieHttpOnly: opts.CookieHttpOnly, + CookieHTTPOnly: opts.CookieHTTPOnly, CookieExpire: opts.CookieExpire, CookieRefresh: opts.CookieRefresh, Validator: validator, @@ -218,9 +223,9 @@ func (p *OAuthProxy) GetRedirectURI(host string) string { u = *p.redirectURL if u.Scheme == "" { if p.CookieSecure { - u.Scheme = "https" + u.Scheme = httpsScheme } else { - u.Scheme = "http" + u.Scheme = httpScheme } } u.Host = host @@ -285,7 +290,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex Value: value, Path: "/", Domain: p.CookieDomain, - HttpOnly: p.CookieHttpOnly, + HttpOnly: p.CookieHTTPOnly, Secure: p.CookieSecure, Expires: now.Add(expiration), } @@ -374,12 +379,12 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code p.ClearSessionCookie(rw, req) rw.WriteHeader(code) - redirect_url := req.URL.RequestURI() + redirecURL := req.URL.RequestURI() if req.Header.Get("X-Auth-Request-Redirect") != "" { - redirect_url = req.Header.Get("X-Auth-Request-Redirect") + redirecURL = req.Header.Get("X-Auth-Request-Redirect") } - if redirect_url == p.SignInPath { - redirect_url = "/" + if redirecURL == p.SignInPath { + redirecURL = "/" } t := struct { @@ -394,7 +399,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code ProviderName: p.provider.Data().ProviderName, SignInMessage: p.SignInMessage, CustomLogin: p.displayCustomLoginForm(), - Redirect: redirect_url, + Redirect: redirecURL, Version: VERSION, ProxyPrefix: p.ProxyPrefix, Footer: template.HTML(p.Footer), @@ -653,7 +658,7 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int } if saveSession && session != nil { - err := p.SaveSession(rw, req, session) + err = p.SaveSession(rw, req, session) if err != nil { log.Printf("%s %s", remoteAddr, err) return http.StatusInternalServerError diff --git a/oauthproxy_test.go b/oauthproxy_test.go index ccfbf09..0421b56 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -15,8 +15,8 @@ import ( "testing" "time" - "github.com/pusher/oauth2_proxy/providers" "github.com/mbland/hmacauth" + "github.com/pusher/oauth2_proxy/providers" "github.com/stretchr/testify/assert" ) @@ -98,28 +98,28 @@ type TestProvider struct { ValidToken bool } -func NewTestProvider(provider_url *url.URL, email_address string) *TestProvider { +func NewTestProvider(providerURL *url.URL, emailAddress string) *TestProvider { return &TestProvider{ ProviderData: &providers.ProviderData{ ProviderName: "Test Provider", LoginURL: &url.URL{ Scheme: "http", - Host: provider_url.Host, + Host: providerURL.Host, Path: "/oauth/authorize", }, RedeemURL: &url.URL{ Scheme: "http", - Host: provider_url.Host, + Host: providerURL.Host, Path: "/oauth/token", }, ProfileURL: &url.URL{ Scheme: "http", - Host: provider_url.Host, + Host: providerURL.Host, Path: "/api/v1/profile", }, Scope: "profile.email", }, - EmailAddress: email_address, + EmailAddress: emailAddress, } } @@ -132,11 +132,10 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo } func TestBasicAuthPassword(t *testing.T) { - provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("%#v", r) - url := r.URL - payload := "" - switch url.Path { + var payload string + switch r.URL.Path { case "/oauth/token": payload = `{"access_token": "my_auth_token"}` default: @@ -149,7 +148,7 @@ func TestBasicAuthPassword(t *testing.T) { w.Write([]byte(payload)) })) opts := NewOptions() - opts.Upstreams = append(opts.Upstreams, provider_server.URL) + opts.Upstreams = append(opts.Upstreams, providerServer.URL) // The CookieSecret must be 32 bytes in order to create the AES // cipher. opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" @@ -161,13 +160,13 @@ func TestBasicAuthPassword(t *testing.T) { opts.BasicAuthPassword = "This is a secure password" opts.Validate() - provider_url, _ := url.Parse(provider_server.URL) - const email_address = "michael.bland@gsa.gov" - const user_name = "michael.bland" + providerURL, _ := url.Parse(providerServer.URL) + const emailAddress = "michael.bland@gsa.gov" + const username = "michael.bland" - opts.provider = NewTestProvider(provider_url, email_address) + opts.provider = NewTestProvider(providerURL, emailAddress) proxy := NewOAuthProxy(opts, func(email string) bool { - return email == email_address + return email == emailAddress }) rw := httptest.NewRecorder() @@ -182,10 +181,10 @@ func TestBasicAuthPassword(t *testing.T) { cookieName := proxy.CookieName var value string - key_prefix := cookieName + "=" + keyPrefix := cookieName + "=" for _, field := range strings.Split(cookie, "; ") { - value = strings.TrimPrefix(field, key_prefix) + value = strings.TrimPrefix(field, keyPrefix) if value != field { break } else { @@ -206,15 +205,15 @@ func TestBasicAuthPassword(t *testing.T) { rw = httptest.NewRecorder() proxy.ServeHTTP(rw, req) - expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(user_name+":"+opts.BasicAuthPassword)) + expectedHeader := "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+opts.BasicAuthPassword)) assert.Equal(t, expectedHeader, rw.Body.String()) - provider_server.Close() + providerServer.Close() } type PassAccessTokenTest struct { - provider_server *httptest.Server - proxy *OAuthProxy - opts *Options + providerServer *httptest.Server + proxy *OAuthProxy + opts *Options } type PassAccessTokenTestOptions struct { @@ -224,12 +223,11 @@ type PassAccessTokenTestOptions struct { func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTest { t := &PassAccessTokenTest{} - t.provider_server = httptest.NewServer( + t.providerServer = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("%#v", r) - url := r.URL - payload := "" - switch url.Path { + var payload string + switch r.URL.Path { case "/oauth/token": payload = `{"access_token": "my_auth_token"}` default: @@ -243,7 +241,7 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes })) t.opts = NewOptions() - t.opts.Upstreams = append(t.opts.Upstreams, t.provider_server.URL) + t.opts.Upstreams = append(t.opts.Upstreams, t.providerServer.URL) // The CookieSecret must be 32 bytes in order to create the AES // cipher. t.opts.CookieSecret = "xyzzyplughxyzzyplughxyzzyplughxp" @@ -253,21 +251,21 @@ func NewPassAccessTokenTest(opts PassAccessTokenTestOptions) *PassAccessTokenTes t.opts.PassAccessToken = opts.PassAccessToken t.opts.Validate() - provider_url, _ := url.Parse(t.provider_server.URL) - const email_address = "michael.bland@gsa.gov" + providerURL, _ := url.Parse(t.providerServer.URL) + const emailAddress = "michael.bland@gsa.gov" - t.opts.provider = NewTestProvider(provider_url, email_address) + t.opts.provider = NewTestProvider(providerURL, emailAddress) t.proxy = NewOAuthProxy(t.opts, func(email string) bool { - return email == email_address + return email == emailAddress }) return t } -func (pat_test *PassAccessTokenTest) Close() { - pat_test.provider_server.Close() +func (patTest *PassAccessTokenTest) Close() { + patTest.providerServer.Close() } -func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, +func (patTest *PassAccessTokenTest) getCallbackEndpoint() (httpCode int, cookie string) { rw := httptest.NewRecorder() req, err := http.NewRequest("GET", "/oauth2/callback?code=callback_code&state=nonce:", @@ -275,18 +273,18 @@ func (pat_test *PassAccessTokenTest) getCallbackEndpoint() (http_code int, if err != nil { return 0, "" } - req.AddCookie(pat_test.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) - pat_test.proxy.ServeHTTP(rw, req) + req.AddCookie(patTest.proxy.MakeCSRFCookie(req, "nonce", time.Hour, time.Now())) + patTest.proxy.ServeHTTP(rw, req) return rw.Code, rw.HeaderMap["Set-Cookie"][1] } -func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code int, access_token string) { - cookieName := pat_test.proxy.CookieName +func (patTest *PassAccessTokenTest) getRootEndpoint(cookie string) (httpCode int, accessToken string) { + cookieName := patTest.proxy.CookieName var value string - key_prefix := cookieName + "=" + keyPrefix := cookieName + "=" for _, field := range strings.Split(cookie, "; ") { - value = strings.TrimPrefix(field, key_prefix) + value = strings.TrimPrefix(field, keyPrefix) if value != field { break } else { @@ -310,18 +308,18 @@ func (pat_test *PassAccessTokenTest) getRootEndpoint(cookie string) (http_code i }) rw := httptest.NewRecorder() - pat_test.proxy.ServeHTTP(rw, req) + patTest.proxy.ServeHTTP(rw, req) return rw.Code, rw.Body.String() } func TestForwardAccessTokenUpstream(t *testing.T) { - pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: true, }) - defer pat_test.Close() + defer patTest.Close() // A successful validation will redirect and set the auth cookie. - code, cookie := pat_test.getCallbackEndpoint() + code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } @@ -330,7 +328,7 @@ func TestForwardAccessTokenUpstream(t *testing.T) { // Now we make a regular request; the access_token from the cookie is // forwarded as the "X-Forwarded-Access-Token" header. The token is // read by the test provider server and written in the response body. - code, payload := pat_test.getRootEndpoint(cookie) + code, payload := patTest.getRootEndpoint(cookie) if code != 200 { t.Fatalf("expected 200; got %d", code) } @@ -338,13 +336,13 @@ func TestForwardAccessTokenUpstream(t *testing.T) { } func TestDoNotForwardAccessTokenUpstream(t *testing.T) { - pat_test := NewPassAccessTokenTest(PassAccessTokenTestOptions{ + patTest := NewPassAccessTokenTest(PassAccessTokenTestOptions{ PassAccessToken: false, }) - defer pat_test.Close() + defer patTest.Close() // A successful validation will redirect and set the auth cookie. - code, cookie := pat_test.getCallbackEndpoint() + code, cookie := patTest.getCallbackEndpoint() if code != 302 { t.Fatalf("expected 302; got %d", code) } @@ -352,7 +350,7 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { // Now we make a regular request, but the access token header should // not be present. - code, payload := pat_test.getRootEndpoint(cookie) + code, payload := patTest.getRootEndpoint(cookie) if code != 200 { t.Fatalf("expected 200; got %d", code) } @@ -360,49 +358,49 @@ func TestDoNotForwardAccessTokenUpstream(t *testing.T) { } type SignInPageTest struct { - opts *Options - proxy *OAuthProxy - sign_in_regexp *regexp.Regexp - sign_in_provider_regexp *regexp.Regexp + opts *Options + proxy *OAuthProxy + signInRegexp *regexp.Regexp + signInProviderRegexp *regexp.Regexp } const signInRedirectPattern = `` const signInSkipProvider = `>Found<` func NewSignInPageTest(skipProvider bool) *SignInPageTest { - var sip_test SignInPageTest + var sipTest SignInPageTest - sip_test.opts = NewOptions() - sip_test.opts.CookieSecret = "foobar" - sip_test.opts.ClientID = "bazquux" - sip_test.opts.ClientSecret = "xyzzyplugh" - sip_test.opts.SkipProviderButton = skipProvider - sip_test.opts.Validate() + sipTest.opts = NewOptions() + sipTest.opts.CookieSecret = "foobar" + sipTest.opts.ClientID = "bazquux" + sipTest.opts.ClientSecret = "xyzzyplugh" + sipTest.opts.SkipProviderButton = skipProvider + sipTest.opts.Validate() - sip_test.proxy = NewOAuthProxy(sip_test.opts, func(email string) bool { + sipTest.proxy = NewOAuthProxy(sipTest.opts, func(email string) bool { return true }) - sip_test.sign_in_regexp = regexp.MustCompile(signInRedirectPattern) - sip_test.sign_in_provider_regexp = regexp.MustCompile(signInSkipProvider) + sipTest.signInRegexp = regexp.MustCompile(signInRedirectPattern) + sipTest.signInProviderRegexp = regexp.MustCompile(signInSkipProvider) - return &sip_test + return &sipTest } -func (sip_test *SignInPageTest) GetEndpoint(endpoint string) (int, string) { +func (sipTest *SignInPageTest) GetEndpoint(endpoint string) (int, string) { rw := httptest.NewRecorder() req, _ := http.NewRequest("GET", endpoint, strings.NewReader("")) - sip_test.proxy.ServeHTTP(rw, req) + sipTest.proxy.ServeHTTP(rw, req) return rw.Code, rw.Body.String() } func TestSignInPageIncludesTargetRedirect(t *testing.T) { - sip_test := NewSignInPageTest(false) + sipTest := NewSignInPageTest(false) const endpoint = "/some/random/endpoint" - code, body := sip_test.GetEndpoint(endpoint) + code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 403, code) - match := sip_test.sign_in_regexp.FindStringSubmatch(body) + match := sipTest.signInRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInRedirectPattern + "\nBody:\n" + body) @@ -414,11 +412,11 @@ func TestSignInPageIncludesTargetRedirect(t *testing.T) { } func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { - sip_test := NewSignInPageTest(false) - code, body := sip_test.GetEndpoint("/oauth2/sign_in") + sipTest := NewSignInPageTest(false) + code, body := sipTest.GetEndpoint("/oauth2/sign_in") assert.Equal(t, 200, code) - match := sip_test.sign_in_regexp.FindStringSubmatch(body) + match := sipTest.signInRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInRedirectPattern + "\nBody:\n" + body) @@ -429,13 +427,13 @@ func TestSignInPageDirectAccessRedirectsToRoot(t *testing.T) { } func TestSignInPageSkipProvider(t *testing.T) { - sip_test := NewSignInPageTest(true) + sipTest := NewSignInPageTest(true) const endpoint = "/some/random/endpoint" - code, body := sip_test.GetEndpoint(endpoint) + code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) - match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) + match := sipTest.signInProviderRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInSkipProvider + "\nBody:\n" + body) @@ -443,13 +441,13 @@ func TestSignInPageSkipProvider(t *testing.T) { } func TestSignInPageSkipProviderDirect(t *testing.T) { - sip_test := NewSignInPageTest(true) + sipTest := NewSignInPageTest(true) const endpoint = "/sign_in" - code, body := sip_test.GetEndpoint(endpoint) + code, body := sipTest.GetEndpoint(endpoint) assert.Equal(t, 302, code) - match := sip_test.sign_in_provider_regexp.FindStringSubmatch(body) + match := sipTest.signInProviderRegexp.FindStringSubmatch(body) if match == nil { t.Fatal("Did not find pattern in body: " + signInSkipProvider + "\nBody:\n" + body) @@ -457,50 +455,50 @@ func TestSignInPageSkipProviderDirect(t *testing.T) { } type ProcessCookieTest struct { - opts *Options - proxy *OAuthProxy - rw *httptest.ResponseRecorder - req *http.Request - provider TestProvider - response_code int - validate_user bool + opts *Options + proxy *OAuthProxy + rw *httptest.ResponseRecorder + req *http.Request + provider TestProvider + responseCode int + validateUser bool } type ProcessCookieTestOpts struct { - provider_validate_cookie_response bool + providerValidateCookieResponse bool } func NewProcessCookieTest(opts ProcessCookieTestOpts) *ProcessCookieTest { - var pc_test ProcessCookieTest + var pcTest ProcessCookieTest - pc_test.opts = NewOptions() - pc_test.opts.ClientID = "bazquux" - pc_test.opts.ClientSecret = "xyzzyplugh" - pc_test.opts.CookieSecret = "0123456789abcdefabcd" + pcTest.opts = NewOptions() + pcTest.opts.ClientID = "bazquux" + pcTest.opts.ClientSecret = "xyzzyplugh" + pcTest.opts.CookieSecret = "0123456789abcdefabcd" // First, set the CookieRefresh option so proxy.AesCipher is created, // needed to encrypt the access_token. - pc_test.opts.CookieRefresh = time.Hour - pc_test.opts.Validate() + pcTest.opts.CookieRefresh = time.Hour + pcTest.opts.Validate() - pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { - return pc_test.validate_user + pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { + return pcTest.validateUser }) - pc_test.proxy.provider = &TestProvider{ - ValidToken: opts.provider_validate_cookie_response, + pcTest.proxy.provider = &TestProvider{ + ValidToken: opts.providerValidateCookieResponse, } // Now, zero-out proxy.CookieRefresh for the cases that don't involve // access_token validation. - pc_test.proxy.CookieRefresh = time.Duration(0) - pc_test.rw = httptest.NewRecorder() - pc_test.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) - pc_test.validate_user = true - return &pc_test + pcTest.proxy.CookieRefresh = time.Duration(0) + pcTest.rw = httptest.NewRecorder() + pcTest.req, _ = http.NewRequest("GET", "/", strings.NewReader("")) + pcTest.validateUser = true + return &pcTest } func NewProcessCookieTestWithDefaults() *ProcessCookieTest { return NewProcessCookieTest(ProcessCookieTestOpts{ - provider_validate_cookie_response: true, + providerValidateCookieResponse: true, }) } @@ -522,12 +520,12 @@ func (p *ProcessCookieTest) LoadCookiedSession() (*providers.SessionState, time. } func TestLoadCookiedSession(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() + pcTest := NewProcessCookieTestWithDefaults() startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pc_test.SaveSession(startSession, time.Now()) + pcTest.SaveSession(startSession, time.Now()) - session, _, err := pc_test.LoadCookiedSession() + session, _, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) assert.Equal(t, startSession.Email, session.Email) assert.Equal(t, "michael.bland", session.User) @@ -535,9 +533,9 @@ func TestLoadCookiedSession(t *testing.T) { } func TestProcessCookieNoCookieError(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() + pcTest := NewProcessCookieTestWithDefaults() - session, _, err := pc_test.LoadCookiedSession() + session, _, err := pcTest.LoadCookiedSession() assert.Equal(t, "Cookie \"_oauth2_proxy\" not present", err.Error()) if session != nil { t.Errorf("expected nil session. got %#v", session) @@ -545,14 +543,14 @@ func TestProcessCookieNoCookieError(t *testing.T) { } func TestProcessCookieRefreshNotSet(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.proxy.CookieExpire = time.Duration(23) * time.Hour + pcTest := NewProcessCookieTestWithDefaults() + pcTest.proxy.CookieExpire = time.Duration(23) * time.Hour reference := time.Now().Add(time.Duration(-2) * time.Hour) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pc_test.SaveSession(startSession, reference) + pcTest.SaveSession(startSession, reference) - session, age, err := pc_test.LoadCookiedSession() + session, age, err := pcTest.LoadCookiedSession() assert.Equal(t, nil, err) if age < time.Duration(-2)*time.Hour { t.Errorf("cookie too young %v", age) @@ -561,13 +559,13 @@ func TestProcessCookieRefreshNotSet(t *testing.T) { } func TestProcessCookieFailIfCookieExpired(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour + pcTest := NewProcessCookieTestWithDefaults() + pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pc_test.SaveSession(startSession, reference) + pcTest.SaveSession(startSession, reference) - session, _, err := pc_test.LoadCookiedSession() + session, _, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) @@ -575,14 +573,14 @@ func TestProcessCookieFailIfCookieExpired(t *testing.T) { } func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.proxy.CookieExpire = time.Duration(24) * time.Hour + pcTest := NewProcessCookieTestWithDefaults() + pcTest.proxy.CookieExpire = time.Duration(24) * time.Hour reference := time.Now().Add(time.Duration(25) * time.Hour * -1) startSession := &providers.SessionState{Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} - pc_test.SaveSession(startSession, reference) + pcTest.SaveSession(startSession, reference) - pc_test.proxy.CookieRefresh = time.Hour - session, _, err := pc_test.LoadCookiedSession() + pcTest.proxy.CookieRefresh = time.Hour + session, _, err := pcTest.LoadCookiedSession() assert.NotEqual(t, nil, err) if session != nil { t.Errorf("expected nil session %#v", session) @@ -590,10 +588,10 @@ func TestProcessCookieFailIfRefreshSetAndCookieExpired(t *testing.T) { } func NewAuthOnlyEndpointTest() *ProcessCookieTest { - pc_test := NewProcessCookieTestWithDefaults() - pc_test.req, _ = http.NewRequest("GET", - pc_test.opts.ProxyPrefix+"/auth", nil) - return pc_test + pcTest := NewProcessCookieTestWithDefaults() + pcTest.req, _ = http.NewRequest("GET", + pcTest.opts.ProxyPrefix+"/auth", nil) + return pcTest } func TestAuthOnlyEndpointAccepted(t *testing.T) { @@ -636,7 +634,7 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { startSession := &providers.SessionState{ Email: "michael.bland@gsa.gov", AccessToken: "my_access_token"} test.SaveSession(startSession, time.Now()) - test.validate_user = false + test.validateUser = false test.proxy.ServeHTTP(test.rw, test.req) assert.Equal(t, http.StatusUnauthorized, test.rw.Code) @@ -645,33 +643,33 @@ func TestAuthOnlyEndpointUnauthorizedOnEmailValidationFailure(t *testing.T) { } func TestAuthOnlyEndpointSetXAuthRequestHeaders(t *testing.T) { - var pc_test ProcessCookieTest + var pcTest ProcessCookieTest - pc_test.opts = NewOptions() - pc_test.opts.SetXAuthRequest = true - pc_test.opts.Validate() + pcTest.opts = NewOptions() + pcTest.opts.SetXAuthRequest = true + pcTest.opts.Validate() - pc_test.proxy = NewOAuthProxy(pc_test.opts, func(email string) bool { - return pc_test.validate_user + pcTest.proxy = NewOAuthProxy(pcTest.opts, func(email string) bool { + return pcTest.validateUser }) - pc_test.proxy.provider = &TestProvider{ + pcTest.proxy.provider = &TestProvider{ ValidToken: true, } - pc_test.validate_user = true + pcTest.validateUser = true - pc_test.rw = httptest.NewRecorder() - pc_test.req, _ = http.NewRequest("GET", - pc_test.opts.ProxyPrefix+"/auth", nil) + pcTest.rw = httptest.NewRecorder() + pcTest.req, _ = http.NewRequest("GET", + pcTest.opts.ProxyPrefix+"/auth", nil) startSession := &providers.SessionState{ User: "oauth_user", Email: "oauth_user@example.com", AccessToken: "oauth_token"} - pc_test.SaveSession(startSession, time.Now()) + pcTest.SaveSession(startSession, time.Now()) - pc_test.proxy.ServeHTTP(pc_test.rw, pc_test.req) - assert.Equal(t, http.StatusAccepted, pc_test.rw.Code) - assert.Equal(t, "oauth_user", pc_test.rw.HeaderMap["X-Auth-Request-User"][0]) - assert.Equal(t, "oauth_user@example.com", pc_test.rw.HeaderMap["X-Auth-Request-Email"][0]) + pcTest.proxy.ServeHTTP(pcTest.rw, pcTest.req) + assert.Equal(t, http.StatusAccepted, pcTest.rw.Code) + assert.Equal(t, "oauth_user", pcTest.rw.HeaderMap["X-Auth-Request-User"][0]) + assert.Equal(t, "oauth_user@example.com", pcTest.rw.HeaderMap["X-Auth-Request-Email"][0]) } func TestAuthSkippedForPreflightRequests(t *testing.T) { @@ -689,8 +687,8 @@ func TestAuthSkippedForPreflightRequests(t *testing.T) { opts.SkipAuthPreflight = true opts.Validate() - upstream_url, _ := url.Parse(upstream.URL) - opts.provider = NewTestProvider(upstream_url, "") + upstreamURL, _ := url.Parse(upstream.URL) + opts.provider = NewTestProvider(upstreamURL, "") proxy := NewOAuthProxy(opts, func(string) bool { return false }) rw := httptest.NewRecorder() @@ -723,7 +721,7 @@ func (v *SignatureAuthenticator) Authenticate(w http.ResponseWriter, r *http.Req type SignatureTest struct { opts *Options upstream *httptest.Server - upstream_host string + upstreamHost string provider *httptest.Server header http.Header rw *httptest.ResponseRecorder @@ -740,20 +738,20 @@ func NewSignatureTest() *SignatureTest { authenticator := &SignatureAuthenticator{} upstream := httptest.NewServer( http.HandlerFunc(authenticator.Authenticate)) - upstream_url, _ := url.Parse(upstream.URL) + upstreamURL, _ := url.Parse(upstream.URL) opts.Upstreams = append(opts.Upstreams, upstream.URL) providerHandler := func(w http.ResponseWriter, r *http.Request) { w.Write([]byte(`{"access_token": "my_auth_token"}`)) } provider := httptest.NewServer(http.HandlerFunc(providerHandler)) - provider_url, _ := url.Parse(provider.URL) - opts.provider = NewTestProvider(provider_url, "mbland@acm.org") + providerURL, _ := url.Parse(provider.URL) + opts.provider = NewTestProvider(providerURL, "mbland@acm.org") return &SignatureTest{ opts, upstream, - upstream_url.Host, + upstreamURL.Host, provider, make(http.Header), httptest.NewRecorder(), diff --git a/options.go b/options.go index 34c4ca8..4120587 100644 --- a/options.go +++ b/options.go @@ -13,16 +13,17 @@ import ( "strings" "time" - "github.com/pusher/oauth2_proxy/providers" oidc "github.com/coreos/go-oidc" "github.com/mbland/hmacauth" + "github.com/pusher/oauth2_proxy/providers" ) -// Configuration Options that can be set by Command Line Flag, or Config File +// Options holds Configuration Options that can be set by Command Line Flag, +// or Config File type Options struct { ProxyPrefix string `flag:"proxy-prefix" cfg:"proxy-prefix"` - HttpAddress string `flag:"http-address" cfg:"http_address"` - HttpsAddress string `flag:"https-address" cfg:"https_address"` + HTTPAddress string `flag:"http-address" cfg:"http_address"` + HTTPSAddress string `flag:"https-address" cfg:"https_address"` RedirectURL string `flag:"redirect-url" cfg:"redirect_url"` ClientID string `flag:"client-id" cfg:"client_id" env:"OAUTH2_PROXY_CLIENT_ID"` ClientSecret string `flag:"client-secret" cfg:"client_secret" env:"OAUTH2_PROXY_CLIENT_SECRET"` @@ -48,7 +49,7 @@ type Options struct { CookieExpire time.Duration `flag:"cookie-expire" cfg:"cookie_expire" env:"OAUTH2_PROXY_COOKIE_EXPIRE"` CookieRefresh time.Duration `flag:"cookie-refresh" cfg:"cookie_refresh" env:"OAUTH2_PROXY_COOKIE_REFRESH"` CookieSecure bool `flag:"cookie-secure" cfg:"cookie_secure"` - CookieHttpOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` + CookieHTTPOnly bool `flag:"cookie-httponly" cfg:"cookie_httponly"` Upstreams []string `flag:"upstream" cfg:"upstreams"` SkipAuthRegex []string `flag:"skip-auth-regex" cfg:"skip_auth_regex"` @@ -96,12 +97,12 @@ type SignatureData struct { func NewOptions() *Options { return &Options{ ProxyPrefix: "/oauth2", - HttpAddress: "127.0.0.1:4180", - HttpsAddress: ":443", + HTTPAddress: "127.0.0.1:4180", + HTTPSAddress: ":443", DisplayHtpasswdForm: true, CookieName: "_oauth2_proxy", CookieSecure: true, - CookieHttpOnly: true, + CookieHTTPOnly: true, CookieExpire: time.Duration(168) * time.Hour, CookieRefresh: time.Duration(0), SetXAuthRequest: false, @@ -116,11 +117,11 @@ func NewOptions() *Options { } } -func parseURL(to_parse string, urltype string, msgs []string) (*url.URL, []string) { - parsed, err := url.Parse(to_parse) +func parseURL(toParse string, urltype string, msgs []string) (*url.URL, []string) { + parsed, err := url.Parse(toParse) if err != nil { return nil, append(msgs, fmt.Sprintf( - "error parsing %s-url=%q %s", urltype, to_parse, err)) + "error parsing %s-url=%q %s", urltype, toParse, err)) } return parsed, msgs } @@ -190,17 +191,17 @@ func (o *Options) Validate() error { msgs = parseProviderInfo(o, msgs) if o.PassAccessToken || (o.CookieRefresh != time.Duration(0)) { - valid_cookie_secret_size := false + validCookieSecretSize := false for _, i := range []int{16, 24, 32} { if len(secretBytes(o.CookieSecret)) == i { - valid_cookie_secret_size = true + validCookieSecretSize = true } } var decoded bool if string(secretBytes(o.CookieSecret)) != o.CookieSecret { decoded = true } - if valid_cookie_secret_size == false { + if validCookieSecretSize == false { var suffix string if decoded { suffix = fmt.Sprintf(" note: cookie secret was base64 decoded from %q", o.CookieSecret) @@ -294,12 +295,13 @@ func parseSignatureKey(o *Options, msgs []string) []string { } algorithm, secretKey := components[0], components[1] - if hash, err := hmacauth.DigestNameToCryptoHash(algorithm); err != nil { + var hash crypto.Hash + var err error + if hash, err = hmacauth.DigestNameToCryptoHash(algorithm); err != nil { return append(msgs, "unsupported signature hash algorithm: "+ o.SignatureKey) - } else { - o.signatureData = &SignatureData{hash, secretKey} } + o.signatureData = &SignatureData{hash, secretKey} return msgs } diff --git a/options_test.go b/options_test.go index 23ae9fb..fd1489b 100644 --- a/options_test.go +++ b/options_test.go @@ -88,9 +88,9 @@ func TestProxyURLs(t *testing.T) { o.Upstreams = append(o.Upstreams, "http://127.0.0.1:8081") assert.Equal(t, nil, o.Validate()) expected := []*url.URL{ - &url.URL{Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, + {Scheme: "http", Host: "127.0.0.1:8080", Path: "/"}, // note the '/' was added - &url.URL{Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, + {Scheme: "http", Host: "127.0.0.1:8081", Path: "/"}, } assert.Equal(t, expected, o.proxyURLs) } diff --git a/providers/azure.go b/providers/azure.go index 0c925bb..4ec760c 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -3,11 +3,12 @@ package providers import ( "errors" "fmt" - "github.com/bitly/go-simplejson" - "github.com/pusher/oauth2_proxy/api" "log" "net/http" "net/url" + + "github.com/bitly/go-simplejson" + "github.com/pusher/oauth2_proxy/api" ) type AzureProvider struct { @@ -60,9 +61,9 @@ func (p *AzureProvider) Configure(tenant string) { } } -func getAzureHeader(access_token string) http.Header { +func getAzureHeader(accessToken string) http.Header { header := make(http.Header) - header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) + header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) return header } diff --git a/providers/azure_test.go b/providers/azure_test.go index f2cf353..469f2d1 100644 --- a/providers/azure_test.go +++ b/providers/azure_test.go @@ -110,8 +110,7 @@ func testAzureBackend(payload string) *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - url := r.URL - if url.Path != path || url.RawQuery != query { + if r.URL.Path != path || r.URL.RawQuery != query { w.WriteHeader(404) } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { w.WriteHeader(403) diff --git a/providers/facebook.go b/providers/facebook.go index a322c69..42d3ce3 100644 --- a/providers/facebook.go +++ b/providers/facebook.go @@ -43,11 +43,11 @@ func NewFacebookProvider(p *ProviderData) *FacebookProvider { return &FacebookProvider{ProviderData: p} } -func getFacebookHeader(access_token string) http.Header { +func getFacebookHeader(accessToken string) http.Header { header := make(http.Header) header.Set("Accept", "application/json") header.Set("x-li-format", "json") - header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) + header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) return header } @@ -65,7 +65,7 @@ func (p *FacebookProvider) GetEmailAddress(s *SessionState) (string, error) { Email string } var r result - err = api.RequestJson(req, &r) + err = api.RequestJSON(req, &r) if err != nil { return "", err } diff --git a/providers/github.go b/providers/github.go index 26526ce..a307658 100644 --- a/providers/github.go +++ b/providers/github.go @@ -106,7 +106,7 @@ func (p *GitHubProvider) hasOrg(accessToken string) (bool, error) { } orgs = append(orgs, op...) - pn += 1 + pn++ } var presentOrgs []string @@ -186,7 +186,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) { log.Printf("Missing Team:%q from Org:%q in teams: %v", p.Team, p.Org, presentTeams) } else { var allOrgs []string - for org, _ := range presentOrgs { + for org := range presentOrgs { allOrgs = append(allOrgs, org) } log.Printf("Missing Organization:%q in %#v", p.Org, allOrgs) diff --git a/providers/github_test.go b/providers/github_test.go index 4810182..c96877c 100644 --- a/providers/github_test.go +++ b/providers/github_test.go @@ -29,19 +29,18 @@ func testGitHubProvider(hostname string) *GitHubProvider { func testGitHubBackend(payload []string) *httptest.Server { pathToQueryMap := map[string][]string{ - "/user": []string{""}, - "/user/emails": []string{""}, - "/user/orgs": []string{"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, + "/user": {""}, + "/user/emails": {""}, + "/user/orgs": {"limit=200&page=1", "limit=200&page=2", "limit=200&page=3"}, } return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - url := r.URL - query, ok := pathToQueryMap[url.Path] + query, ok := pathToQueryMap[r.URL.Path] validQuery := false index := 0 for i, q := range query { - if q == url.RawQuery { + if q == r.URL.RawQuery { validQuery = true index = i } diff --git a/providers/gitlab_test.go b/providers/gitlab_test.go index 0eec5aa..19038e1 100644 --- a/providers/gitlab_test.go +++ b/providers/gitlab_test.go @@ -33,8 +33,7 @@ func testGitLabBackend(payload string) *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - url := r.URL - if url.Path != path || url.RawQuery != query { + if r.URL.Path != path || r.URL.RawQuery != query { w.WriteHeader(404) } else { w.WriteHeader(200) @@ -87,8 +86,8 @@ func TestGitLabProviderGetEmailAddress(t *testing.T) { b := testGitLabBackend("{\"email\": \"michael.bland@gsa.gov\"}") defer b.Close() - b_url, _ := url.Parse(b.URL) - p := testGitLabProvider(b_url.Host) + bURL, _ := url.Parse(b.URL) + p := testGitLabProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) @@ -102,8 +101,8 @@ func TestGitLabProviderGetEmailAddressFailedRequest(t *testing.T) { b := testGitLabBackend("unused payload") defer b.Close() - b_url, _ := url.Parse(b.URL) - p := testGitLabProvider(b_url.Host) + bURL, _ := url.Parse(b.URL) + p := testGitLabProvider(bURL.Host) // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as @@ -118,8 +117,8 @@ func TestGitLabProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b := testGitLabBackend("{\"foo\": \"bar\"}") defer b.Close() - b_url, _ := url.Parse(b.URL) - p := testGitLabProvider(b_url.Host) + bURL, _ := url.Parse(b.URL) + p := testGitLabProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) diff --git a/providers/google.go b/providers/google.go index 66406bd..113a691 100644 --- a/providers/google.go +++ b/providers/google.go @@ -62,7 +62,7 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { } } -func emailFromIdToken(idToken string) (string, error) { +func emailFromIDToken(idToken string) (string, error) { // id_token is a base64 encode ID token payload // https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo @@ -129,14 +129,14 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` - IdToken string `json:"id_token"` + IDToken string `json:"id_token"` } err = json.Unmarshal(body, &jsonResponse) if err != nil { return } var email string - email, err = emailFromIdToken(jsonResponse.IdToken) + email, err = emailFromIDToken(jsonResponse.IDToken) if err != nil { return } diff --git a/providers/google_test.go b/providers/google_test.go index fedd8da..25b375a 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -81,7 +81,7 @@ type redeemResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` - IdToken string `json:"id_token"` + IDToken string `json:"id_token"` } func TestGoogleProviderGetEmailAddress(t *testing.T) { @@ -90,7 +90,7 @@ func TestGoogleProviderGetEmailAddress(t *testing.T) { AccessToken: "a1234", ExpiresIn: 10, RefreshToken: "refresh12345", - IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), + IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": "michael.bland@gsa.gov", "email_verified":true}`)), }) assert.Equal(t, nil, err) var server *httptest.Server @@ -127,7 +127,7 @@ func TestGoogleProviderGetEmailAddressInvalidEncoding(t *testing.T) { p := newGoogleProvider() body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", - IdToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, + IDToken: "ignored prefix." + `{"email": "michael.bland@gsa.gov"}`, }) assert.Equal(t, nil, err) var server *httptest.Server @@ -146,7 +146,7 @@ func TestGoogleProviderGetEmailAddressInvalidJson(t *testing.T) { body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", - IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), + IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"email": michael.bland@gsa.gov}`)), }) assert.Equal(t, nil, err) var server *httptest.Server @@ -165,7 +165,7 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { p := newGoogleProvider() body, err := json.Marshal(redeemResponse{ AccessToken: "a1234", - IdToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), + IDToken: "ignored prefix." + base64.URLEncoding.EncodeToString([]byte(`{"not_email": "missing"}`)), }) assert.Equal(t, nil, err) var server *httptest.Server diff --git a/providers/internal_util.go b/providers/internal_util.go index b2d3b71..a2098d3 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -46,13 +46,13 @@ func stripParam(param, endpoint string) string { } // validateToken returns true if token is valid -func validateToken(p Provider, access_token string, header http.Header) bool { - if access_token == "" || p.Data().ValidateURL == nil { +func validateToken(p Provider, accessToken string, header http.Header) bool { + if accessToken == "" || p.Data().ValidateURL == nil { return false } endpoint := p.Data().ValidateURL.String() if len(header) == 0 { - params := url.Values{"access_token": {access_token}} + params := url.Values{"access_token": {accessToken}} endpoint = endpoint + "?" + params.Encode() } resp, err := api.RequestUnparsedResponse(endpoint, header) @@ -72,8 +72,3 @@ func validateToken(p Provider, access_token string, header http.Header) bool { log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) return false } - -func updateURL(url *url.URL, hostname string) { - url.Scheme = "http" - url.Host = hostname -} diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 5fe0e8e..1a03fc5 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -10,6 +10,11 @@ import ( "github.com/stretchr/testify/assert" ) +func updateURL(url *url.URL, hostname string) { + url.Scheme = "http" + url.Host = hostname +} + type ValidateSessionStateTestProvider struct { *ProviderData } @@ -25,28 +30,28 @@ func (tp *ValidateSessionStateTestProvider) ValidateSessionState(s *SessionState } type ValidateSessionStateTest struct { - backend *httptest.Server - response_code int - provider *ValidateSessionStateTestProvider - header http.Header + backend *httptest.Server + responseCode int + provider *ValidateSessionStateTestProvider + header http.Header } func NewValidateSessionStateTest() *ValidateSessionStateTest { - var vt_test ValidateSessionStateTest + var vtTest ValidateSessionStateTest - vt_test.backend = httptest.NewServer( + vtTest.backend = httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/oauth/tokeninfo" { w.WriteHeader(500) w.Write([]byte("unknown URL")) } - token_param := r.FormValue("access_token") - if token_param == "" { + tokenParam := r.FormValue("access_token") + if tokenParam == "" { missing := false - received_headers := r.Header - for k, _ := range vt_test.header { - received := received_headers.Get(k) - expected := vt_test.header.Get(k) + receivedHeaders := r.Header + for k := range vtTest.header { + received := receivedHeaders.Get(k) + expected := vtTest.header.Get(k) if received == "" || received != expected { missing = true } @@ -56,68 +61,68 @@ func NewValidateSessionStateTest() *ValidateSessionStateTest { w.Write([]byte("no token param and missing or incorrect headers")) } } - w.WriteHeader(vt_test.response_code) + w.WriteHeader(vtTest.responseCode) w.Write([]byte("only code matters; contents disregarded")) })) - backend_url, _ := url.Parse(vt_test.backend.URL) - vt_test.provider = &ValidateSessionStateTestProvider{ + backendURL, _ := url.Parse(vtTest.backend.URL) + vtTest.provider = &ValidateSessionStateTestProvider{ ProviderData: &ProviderData{ ValidateURL: &url.URL{ Scheme: "http", - Host: backend_url.Host, + Host: backendURL.Host, Path: "/oauth/tokeninfo", }, }, } - vt_test.response_code = 200 - return &vt_test + vtTest.responseCode = 200 + return &vtTest } -func (vt_test *ValidateSessionStateTest) Close() { - vt_test.backend.Close() +func (vtTest *ValidateSessionStateTest) Close() { + vtTest.backend.Close() } func TestValidateSessionStateValidToken(t *testing.T) { - vt_test := NewValidateSessionStateTest() - defer vt_test.Close() - assert.Equal(t, true, validateToken(vt_test.provider, "foobar", nil)) + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + assert.Equal(t, true, validateToken(vtTest.provider, "foobar", nil)) } func TestValidateSessionStateValidTokenWithHeaders(t *testing.T) { - vt_test := NewValidateSessionStateTest() - defer vt_test.Close() - vt_test.header = make(http.Header) - vt_test.header.Set("Authorization", "Bearer foobar") + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + vtTest.header = make(http.Header) + vtTest.header.Set("Authorization", "Bearer foobar") assert.Equal(t, true, - validateToken(vt_test.provider, "foobar", vt_test.header)) + validateToken(vtTest.provider, "foobar", vtTest.header)) } func TestValidateSessionStateEmptyToken(t *testing.T) { - vt_test := NewValidateSessionStateTest() - defer vt_test.Close() - assert.Equal(t, false, validateToken(vt_test.provider, "", nil)) + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + assert.Equal(t, false, validateToken(vtTest.provider, "", nil)) } func TestValidateSessionStateEmptyValidateURL(t *testing.T) { - vt_test := NewValidateSessionStateTest() - defer vt_test.Close() - vt_test.provider.Data().ValidateURL = nil - assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + vtTest.provider.Data().ValidateURL = nil + assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) } func TestValidateSessionStateRequestNetworkFailure(t *testing.T) { - vt_test := NewValidateSessionStateTest() + vtTest := NewValidateSessionStateTest() // Close immediately to simulate a network failure - vt_test.Close() - assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) + vtTest.Close() + assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) } func TestValidateSessionStateExpiredToken(t *testing.T) { - vt_test := NewValidateSessionStateTest() - defer vt_test.Close() - vt_test.response_code = 401 - assert.Equal(t, false, validateToken(vt_test.provider, "foobar", nil)) + vtTest := NewValidateSessionStateTest() + defer vtTest.Close() + vtTest.responseCode = 401 + assert.Equal(t, false, validateToken(vtTest.provider, "foobar", nil)) } func TestStripTokenNotPresent(t *testing.T) { diff --git a/providers/linkedin.go b/providers/linkedin.go index 8d02e95..e6de34e 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -39,11 +39,11 @@ func NewLinkedInProvider(p *ProviderData) *LinkedInProvider { return &LinkedInProvider{ProviderData: p} } -func getLinkedInHeader(access_token string) http.Header { +func getLinkedInHeader(accessToken string) http.Header { header := make(http.Header) header.Set("Accept", "application/json") header.Set("x-li-format", "json") - header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) + header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken)) return header } diff --git a/providers/linkedin_test.go b/providers/linkedin_test.go index a0d255b..7911522 100644 --- a/providers/linkedin_test.go +++ b/providers/linkedin_test.go @@ -31,8 +31,7 @@ func testLinkedInBackend(payload string) *httptest.Server { return httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { - url := r.URL - if url.Path != path { + if r.URL.Path != path { w.WriteHeader(404) } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { w.WriteHeader(403) @@ -95,8 +94,8 @@ func TestLinkedInProviderGetEmailAddress(t *testing.T) { b := testLinkedInBackend(`"user@linkedin.com"`) defer b.Close() - b_url, _ := url.Parse(b.URL) - p := testLinkedInProvider(b_url.Host) + bURL, _ := url.Parse(b.URL) + p := testLinkedInProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) @@ -108,8 +107,8 @@ func TestLinkedInProviderGetEmailAddressFailedRequest(t *testing.T) { b := testLinkedInBackend("unused payload") defer b.Close() - b_url, _ := url.Parse(b.URL) - p := testLinkedInProvider(b_url.Host) + bURL, _ := url.Parse(b.URL) + p := testLinkedInProvider(bURL.Host) // We'll trigger a request failure by using an unexpected access // token. Alternatively, we could allow the parsing of the payload as @@ -124,8 +123,8 @@ func TestLinkedInProviderGetEmailAddressEmailNotPresentInPayload(t *testing.T) { b := testLinkedInBackend("{\"foo\": \"bar\"}") defer b.Close() - b_url, _ := url.Parse(b.URL) - p := testLinkedInProvider(b_url.Host) + bURL, _ := url.Parse(b.URL) + p := testLinkedInProvider(bURL.Host) session := &SessionState{AccessToken: "imaginary_access_token"} email, err := p.GetEmailAddress(session) diff --git a/providers/provider_default.go b/providers/provider_default.go index 6fc8638..01a2060 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -121,7 +121,8 @@ func (p *ProviderData) ValidateSessionState(s *SessionState) bool { return validateToken(p, s.AccessToken, nil) } -// RefreshSessionIfNeeded +// RefreshSessionIfNeeded should refresh the user's session if required and +// do nothing if a refresh is not required func (p *ProviderData) RefreshSessionIfNeeded(s *SessionState) (bool, error) { return false, nil } diff --git a/validator.go b/validator.go index 1b04923..df25f35 100644 --- a/validator.go +++ b/validator.go @@ -42,11 +42,11 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() { log.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) } defer r.Close() - csv_reader := csv.NewReader(r) - csv_reader.Comma = ',' - csv_reader.Comment = '#' - csv_reader.TrimLeadingSpace = true - records, err := csv_reader.ReadAll() + csvReader := csv.NewReader(r) + csvReader.Comma = ',' + csvReader.Comment = '#' + csvReader.TrimLeadingSpace = true + records, err := csvReader.ReadAll() if err != nil { log.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) return diff --git a/validator_test.go b/validator_test.go index f91f41c..6e72cdb 100644 --- a/validator_test.go +++ b/validator_test.go @@ -8,15 +8,15 @@ import ( ) type ValidatorTest struct { - auth_email_file *os.File - done chan bool - update_seen bool + authEmailFile *os.File + done chan bool + updateSeen bool } func NewValidatorTest(t *testing.T) *ValidatorTest { vt := &ValidatorTest{} var err error - vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") + vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file: " + err.Error()) } @@ -26,27 +26,27 @@ func NewValidatorTest(t *testing.T) *ValidatorTest { func (vt *ValidatorTest) TearDown() { vt.done <- true - os.Remove(vt.auth_email_file.Name()) + os.Remove(vt.authEmailFile.Name()) } func (vt *ValidatorTest) NewValidator(domains []string, updated chan<- bool) func(string) bool { - return newValidatorImpl(domains, vt.auth_email_file.Name(), + return newValidatorImpl(domains, vt.authEmailFile.Name(), vt.done, func() { - if vt.update_seen == false { + if vt.updateSeen == false { updated <- true - vt.update_seen = true + vt.updateSeen = true } }) } -// This will close vt.auth_email_file. +// This will close vt.authEmailFile. func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { - defer vt.auth_email_file.Close() - vt.auth_email_file.WriteString(strings.Join(emails, "\n")) - if err := vt.auth_email_file.Close(); err != nil { + defer vt.authEmailFile.Close() + vt.authEmailFile.WriteString(strings.Join(emails, "\n")) + if err := vt.authEmailFile.Close(); err != nil { t.Fatal("failed to close temp file " + - vt.auth_email_file.Name() + ": " + err.Error()) + vt.authEmailFile.Name() + ": " + err.Error()) } } diff --git a/validator_watcher_copy_test.go b/validator_watcher_copy_test.go index 68c4cb7..15ed6fa 100644 --- a/validator_watcher_copy_test.go +++ b/validator_watcher_copy_test.go @@ -12,18 +12,18 @@ import ( func (vt *ValidatorTest) UpdateEmailFileViaCopyingOver( t *testing.T, emails []string) { - orig_file := vt.auth_email_file + origFile := vt.authEmailFile var err error - vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") + vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file for copy: " + err.Error()) } vt.WriteEmails(t, emails) - err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) + err = os.Rename(vt.authEmailFile.Name(), origFile.Name()) if err != nil { t.Fatal("failed to copy over temp file: " + err.Error()) } - vt.auth_email_file = orig_file + vt.authEmailFile = origFile } func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { diff --git a/validator_watcher_test.go b/validator_watcher_test.go index dc16a7d..b022d68 100644 --- a/validator_watcher_test.go +++ b/validator_watcher_test.go @@ -10,8 +10,8 @@ import ( func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { var err error - vt.auth_email_file, err = os.OpenFile( - vt.auth_email_file.Name(), os.O_WRONLY|os.O_CREATE, 0600) + vt.authEmailFile, err = os.OpenFile( + vt.authEmailFile.Name(), os.O_WRONLY|os.O_CREATE, 0600) if err != nil { t.Fatal("failed to re-open temp file for updates") } @@ -20,24 +20,24 @@ func (vt *ValidatorTest) UpdateEmailFile(t *testing.T, emails []string) { func (vt *ValidatorTest) UpdateEmailFileViaRenameAndReplace( t *testing.T, emails []string) { - orig_file := vt.auth_email_file + origFile := vt.authEmailFile var err error - vt.auth_email_file, err = ioutil.TempFile("", "test_auth_emails_") + vt.authEmailFile, err = ioutil.TempFile("", "test_auth_emails_") if err != nil { t.Fatal("failed to create temp file for rename and replace: " + err.Error()) } vt.WriteEmails(t, emails) - moved_name := orig_file.Name() + "-moved" - err = os.Rename(orig_file.Name(), moved_name) - err = os.Rename(vt.auth_email_file.Name(), orig_file.Name()) + movedName := origFile.Name() + "-moved" + err = os.Rename(origFile.Name(), movedName) + err = os.Rename(vt.authEmailFile.Name(), origFile.Name()) if err != nil { t.Fatal("failed to rename and replace temp file: " + err.Error()) } - vt.auth_email_file = orig_file - os.Remove(moved_name) + vt.authEmailFile = origFile + os.Remove(movedName) } func TestValidatorOverwriteEmailListDirectly(t *testing.T) { diff --git a/watcher.go b/watcher.go index 80297cc..6cda7d9 100644 --- a/watcher.go +++ b/watcher.go @@ -13,11 +13,11 @@ import ( func WaitForReplacement(filename string, op fsnotify.Op, watcher *fsnotify.Watcher) { - const sleep_interval = 50 * time.Millisecond + const sleepInterval = 50 * time.Millisecond // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. if op&fsnotify.Chmod != 0 { - time.Sleep(sleep_interval) + time.Sleep(sleepInterval) } for { if _, err := os.Stat(filename); err == nil { @@ -26,7 +26,7 @@ func WaitForReplacement(filename string, op fsnotify.Op, return } } - time.Sleep(sleep_interval) + time.Sleep(sleepInterval) } } @@ -56,7 +56,7 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { } log.Printf("reloading after event: %s", event) action() - case err := <-watcher.Errors: + case err = <-watcher.Errors: log.Printf("error watching %s: %s", filename, err) } }