// oauth2_proxy/providers/google.go
//
// 200 lines
// 4.9 KiB
// Go
// Raw Permalink Normal View History

package providers
import (
"bytes"
"encoding/base64"
// 2015-05-21 03:23:48 +00:00
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
"time"
)
// GoogleProvider provides the OAuth2 login flow against Google's endpoints.
// It embeds *ProviderData, which supplies the client credentials
// (ClientID/ClientSecret), endpoint URLs (LoginUrl, RedeemUrl, ValidateUrl)
// and Scope used by the methods below.
type GoogleProvider struct {
*ProviderData
// RedeemRefreshUrl is not read anywhere in this file: redeemRefreshToken
// posts to RedeemUrl instead.
// NOTE(review): confirm whether another file uses this field or it is dead.
RedeemRefreshUrl *url.URL
}
func NewGoogleProvider(p *ProviderData) *GoogleProvider {
p.ProviderName = "Google"
if p.LoginUrl.String() == "" {
p.LoginUrl = &url.URL{Scheme: "https",
Host: "accounts.google.com",
Path: "/o/oauth2/auth",
// to get a refresh token. see https://developers.google.com/identity/protocols/OAuth2WebServer#offline
RawQuery: "access_type=offline",
}
}
if p.RedeemUrl.String() == "" {
p.RedeemUrl = &url.URL{Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v3/token"}
}
2015-05-08 21:13:35 +00:00
if p.ValidateUrl.String() == "" {
p.ValidateUrl = &url.URL{Scheme: "https",
Host: "www.googleapis.com",
Path: "/oauth2/v1/tokeninfo"}
}
if p.Scope == "" {
p.Scope = "profile email"
}
return &GoogleProvider{ProviderData: p}
}
func emailFromIdToken(idToken string) (string, error) {
2015-05-21 03:23:48 +00:00
// id_token is a base64 encode ID token payload
// https://developers.google.com/accounts/docs/OAuth2Login#obtainuserinfo
jwt := strings.Split(idToken, ".")
b, err := jwtDecodeSegment(jwt[1])
if err != nil {
return "", err
}
2015-05-21 03:23:48 +00:00
var email struct {
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
}
2015-05-21 03:23:48 +00:00
err = json.Unmarshal(b, &email)
if err != nil {
return "", err
}
2015-05-21 03:23:48 +00:00
if email.Email == "" {
return "", errors.New("missing email")
}
if !email.EmailVerified {
return "", fmt.Errorf("email %s not listed as verified", email.Email)
}
2015-05-21 03:23:48 +00:00
return email.Email, nil
}
// jwtDecodeSegment base64url-decodes a single JWT segment.
//
// JWT segments are written with their trailing "=" padding stripped, but
// base64.URLEncoding expects padded input. The segment is therefore extended
// with "=" up to the next multiple of four characters before decoding. An
// input that is already a multiple of four is decoded unchanged.
func jwtDecodeSegment(seg string) ([]byte, error) {
	missing := (4 - len(seg)%4) % 4
	return base64.URLEncoding.DecodeString(seg + strings.Repeat("=", missing))
}
2015-05-13 01:48:13 +00:00
// Redeem exchanges an OAuth2 authorization code for tokens and builds a
// SessionState from the reply.
//
// It POSTs a form-encoded request to p.RedeemUrl. The form carries the
// redirect URI, the client credentials, the code and
// grant_type=authorization_code.
//
// Errors are returned for:
//   - an empty code
//   - transport failures
//   - a non-200 status (the message includes the status and body)
//   - a reply whose JSON or id_token cannot be parsed
//
// On success the session holds the access and refresh tokens, an expiry
// truncated to whole seconds, and the email taken from the id_token.
func (p *GoogleProvider) Redeem(redirectUrl, code string) (s *SessionState, err error) {
	if code == "" {
		return nil, errors.New("missing code")
	}

	form := url.Values{}
	form.Add("redirect_uri", redirectUrl)
	form.Add("client_id", p.ClientID)
	form.Add("client_secret", p.ClientSecret)
	form.Add("code", code)
	form.Add("grant_type", "authorization_code")

	endpoint := p.RedeemUrl.String()
	req, err := http.NewRequest("POST", endpoint, bytes.NewBufferString(form.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != 200 {
		return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
	}

	var reply struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		IdToken      string `json:"id_token"`
	}
	if err = json.Unmarshal(body, &reply); err != nil {
		return nil, err
	}

	email, err := emailFromIdToken(reply.IdToken)
	if err != nil {
		return nil, err
	}

	lifetime := time.Duration(reply.ExpiresIn) * time.Second
	return &SessionState{
		AccessToken:  reply.AccessToken,
		ExpiresOn:    time.Now().Add(lifetime).Truncate(time.Second),
		RefreshToken: reply.RefreshToken,
		Email:        email,
	}, nil
}
// RefreshSessionIfNeeded renews an expired session's access token using its
// refresh token.
//
// It returns (false, nil) without contacting Google when s is nil, when s has
// not expired yet, or when s has no refresh token. If the refresh request
// fails, the error is returned and s is left untouched.
//
// On success it updates s.AccessToken and s.ExpiresOn in place (expiry
// truncated to whole seconds), logs the refresh, and returns (true, nil).
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
	if s == nil {
		return false, nil
	}
	stillValid := s.ExpiresOn.After(time.Now())
	if stillValid || s.RefreshToken == "" {
		return false, nil
	}

	previousExpiry := s.ExpiresOn
	token, lifetime, err := p.redeemRefreshToken(s.RefreshToken)
	if err != nil {
		return false, err
	}

	s.AccessToken = token
	s.ExpiresOn = time.Now().Add(lifetime).Truncate(time.Second)
	// NOTE(review): s is formatted with %s; this assumes SessionState's
	// String method (not visible here) does not print raw tokens. Confirm.
	log.Printf("refreshed access token %s (expired on %s)", s, previousExpiry)
	return true, nil
}
// redeemRefreshToken trades a refresh token for a fresh access token.
//
// It POSTs a form-encoded grant_type=refresh_token request, with the client
// credentials, to p.RedeemUrl. It returns the new access token and its
// lifetime, taken from expires_in and interpreted as seconds.
//
// Any refresh_token in the reply is ignored. A non-200 status is reported as
// an error that includes the status code and response body.
func (p *GoogleProvider) redeemRefreshToken(refreshToken string) (token string, expires time.Duration, err error) {
	// https://developers.google.com/identity/protocols/OAuth2WebServer#refresh
	form := url.Values{}
	form.Add("client_id", p.ClientID)
	form.Add("client_secret", p.ClientSecret)
	form.Add("refresh_token", refreshToken)
	form.Add("grant_type", "refresh_token")

	endpoint := p.RedeemUrl.String()
	req, err := http.NewRequest("POST", endpoint, bytes.NewBufferString(form.Encode()))
	if err != nil {
		return "", 0, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", 0, err
	}
	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return "", 0, err
	}
	if resp.StatusCode != 200 {
		return "", 0, fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, body)
	}

	var reply struct {
		AccessToken string `json:"access_token"`
		ExpiresIn   int64  `json:"expires_in"`
	}
	if err = json.Unmarshal(body, &reply); err != nil {
		return "", 0, err
	}
	return reply.AccessToken, time.Duration(reply.ExpiresIn) * time.Second, nil
}