diff --git a/README.md b/README.md index 4cb2bc0..3dd94ea 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,8 @@ For Google, the registration steps are: * Fill in the necessary fields and Save (this is _required_) 5. Take note of the **Client ID** and **Client Secret** +It's recommended to refresh sessions on a short interval (1h) with `cookie-refresh` setting which validates that the account is still authorized. + ### GitHub Auth Provider 1. Create a new project: https://github.com/settings/developers @@ -100,7 +102,7 @@ Usage of oauth2_proxy: -cookie-expire=168h0m0s: expire timeframe for cookie -cookie-httponly=true: set HttpOnly cookie flag -cookie-key="_oauth2_proxy": the name of the cookie that the oauth_proxy creates - -cookie-refresh=0: refresh the cookie when less than this much time remains before expiration; 0 to disable + -cookie-refresh=0: refresh the cookie after this duration; 0 to disable -cookie-secret="": the seed string for secure cookies -cookie-secure=true: set secure (HTTPS) cookie flag -custom-templates-dir="": path to custom html templates diff --git a/main.go b/main.go index fb97e87..acf18fe 100644 --- a/main.go +++ b/main.go @@ -50,7 +50,7 @@ func main() { flagSet.String("cookie-secret", "", "the seed string for secure cookies") flagSet.String("cookie-domain", "", "an optional cookie domain to force cookies to (ie: .yourcompany.com)*") flagSet.Duration("cookie-expire", time.Duration(168)*time.Hour, "expire timeframe for cookie") - flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie when less than this much time remains before expiration; 0 to disable") + flagSet.Duration("cookie-refresh", time.Duration(0), "refresh the cookie after this duration; 0 to disable") flagSet.Bool("cookie-secure", true, "set secure (HTTPS) cookie flag") flagSet.Bool("cookie-httponly", true, "set HttpOnly cookie flag") diff --git a/oauthproxy.go b/oauthproxy.go index 2591384..62084cb 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -246,7 +246,13 @@ func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (e } else if ok && p.CookieRefresh != time.Duration(0) { refresh := timestamp.Add(p.CookieRefresh) if refresh.Before(time.Now()) { - ok = p.Validator(email) && p.provider.ValidateToken(access_token) + log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh) + ok = p.Validator(email) + log.Printf("re-validating %s valid:%v", email, ok) + if ok { + ok = p.provider.ValidateToken(access_token) + log.Printf("re-validating access token. valid:%v", ok) + } if ok { p.SetCookie(rw, req, value) } @@ -432,6 +438,7 @@ func (p *OauthProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { http.Redirect(rw, req, redirect, 302) return } else { + log.Printf("validating: %s is unauthorized") p.ErrorPage(rw, 403, "Permission Denied", "Invalid Account") return } diff --git a/options_test.go b/options_test.go index fc1233b..8d8fdf8 100644 --- a/options_test.go +++ b/options_test.go @@ -96,9 +96,9 @@ func TestDefaultProviderApiSettings(t *testing.T) { o := testOptions() assert.Equal(t, nil, o.Validate()) p := o.provider.Data() - assert.Equal(t, "https://accounts.google.com/o/oauth2/auth", + assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", p.LoginUrl.String()) - assert.Equal(t, "https://accounts.google.com/o/oauth2/token", + assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", p.RedeemUrl.String()) assert.Equal(t, "", p.ProfileUrl.String()) assert.Equal(t, "profile email", p.Scope) diff --git a/providers/google.go b/providers/google.go index 265eaa6..40a6228 100644 --- a/providers/google.go +++ b/providers/google.go @@ -1,15 +1,20 @@ package providers import ( + "bytes" "encoding/base64" "encoding/json" "errors" + "fmt" + "io/ioutil" + "net/http" "net/url" "strings" ) type GoogleProvider struct { *ProviderData + RedeemRefreshUrl *url.URL } func NewGoogleProvider(p *ProviderData) *GoogleProvider { @@ -17,12 +22,15 @@ func NewGoogleProvider(p *ProviderData) *GoogleProvider { if p.LoginUrl.String() == "" { p.LoginUrl = &url.URL{Scheme: "https", Host: "accounts.google.com", - Path: "/o/oauth2/auth"} + Path: "/o/oauth2/auth", + // to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline + RawQuery: "access_type=offline", + } } if p.RedeemUrl.String() == "" { p.RedeemUrl = &url.URL{Scheme: "https", - Host: "accounts.google.com", - Path: "/o/oauth2/token"} + Host: "www.googleapis.com", + Path: "/oauth2/v3/token"} } if p.ValidateUrl.String() == "" { p.ValidateUrl = &url.URL{Scheme: "https", @@ -76,3 +84,90 @@ func jwtDecodeSegment(seg string) ([]byte, error) { func (p *GoogleProvider) ValidateToken(access_token string) bool { return validateToken(p, access_token, nil) } + +func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) { + if code == "" { + err = errors.New("missing code") + return + } + + params := url.Values{} + params.Add("redirect_uri", redirectUrl) + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("code", code) + params.Add("grant_type", "authorization_code") + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + return + } + + var jsonResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + } + err = json.Unmarshal(body, &jsonResponse) + if err != nil { + return + } + + token, err = p.redeemRefreshToken(jsonResponse.RefreshToken) + return +} + +func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, err error) { + // https://developers.google.com/identity/protocols/OAuth2WebServer#refresh + params := url.Values{} + params.Add("client_id", p.ClientID) + params.Add("client_secret", p.ClientSecret) + params.Add("refresh_token", refreshToken) + params.Add("grant_type", "refresh_token") + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemUrl.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemUrl.String(), body) + return + } + + var jsonResponse struct { + AccessToken string `json:"access_token"` + } + err = json.Unmarshal(body, &jsonResponse) + if err != nil { + return + } + return jsonResponse.AccessToken, nil +} diff --git a/providers/google_test.go b/providers/google_test.go index 7456d3b..7551640 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -23,9 +23,9 @@ func TestGoogleProviderDefaults(t *testing.T) { p := newGoogleProvider() assert.NotEqual(t, nil, p) assert.Equal(t, "Google", p.Data().ProviderName) - assert.Equal(t, "https://accounts.google.com/o/oauth2/auth", + assert.Equal(t, "https://accounts.google.com/o/oauth2/auth?access_type=offline", p.Data().LoginUrl.String()) - assert.Equal(t, "https://accounts.google.com/o/oauth2/token", + assert.Equal(t, "https://www.googleapis.com/oauth2/v3/token", p.Data().RedeemUrl.String()) assert.Equal(t, "https://www.googleapis.com/oauth2/v1/tokeninfo", p.Data().ValidateUrl.String()) diff --git a/providers/provider_default.go b/providers/provider_default.go index 7a884a6..edf7338 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -56,15 +56,19 @@ func (p *ProviderData) Redeem(redirectUrl, code string) (body []byte, token stri return body, v.Get("access_token"), err } +// GetLoginURL with typical oauth parameters func (p *ProviderData) GetLoginURL(redirectURI, finalRedirect string) string { - params := url.Values{} - params.Add("redirect_uri", redirectURI) - params.Add("approval_prompt", "force") + var a url.URL + a = *p.LoginUrl + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", redirectURI) + params.Set("approval_prompt", "force") params.Add("scope", p.Scope) - params.Add("client_id", p.ClientID) - params.Add("response_type", "code") + params.Set("client_id", p.ClientID) + params.Set("response_type", "code") if strings.HasPrefix(finalRedirect, "/") { params.Add("state", finalRedirect) } - return fmt.Sprintf("%s?%s", p.LoginUrl, params.Encode()) + a.RawQuery = params.Encode() + return a.String() } diff --git a/validator.go b/validator.go index 9c476ba..396e605 100644 --- a/validator.go +++ b/validator.go @@ -83,7 +83,6 @@ func newValidatorImpl(domains []string, usersFile string, if allowAll { valid = true } - log.Printf("validating: is %s valid? %v", email, valid) return valid } return validator