oauth2_proxy/providers/adfs.go
2019-07-23 00:22:19 +03:00

202 lines
5.1 KiB
Go

package providers
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/logger"
)
// ADFSProvider represents an ADFS based Identity Provider.
//
// It embeds *ProviderData, so the fields this file reads from it (LoginURL,
// RedeemURL, ProtectedResource, ClientID, ClientSecret, Scope, ProviderName)
// are accessed directly on the provider value.
type ADFSProvider struct {
*ProviderData
}
// adfsClaims is the subset of ID token payload claims that this provider
// decodes from JSON.
type adfsClaims struct {
// Upn is filled from the "upn" claim.
Upn string `json:"upn"`
// Email is filled from the "email" claim; adfsClaimsFromIDToken replaces an
// empty value with Upn.
Email string `json:"email"`
}
// NewADFSProvider wraps the given ProviderData in an ADFSProvider.
//
// The passed ProviderData is modified in place: its ProviderName is set to
// "ADFS" and, when no scope was configured, Scope defaults to "openid".
func NewADFSProvider(p *ProviderData) *ADFSProvider {
	p.ProviderName = "ADFS"
	if len(p.Scope) == 0 {
		p.Scope = "openid"
	}

	provider := new(ADFSProvider)
	provider.ProviderData = p
	return provider
}
// adfsClaimsFromIDToken decodes the payload (the second dot-separated
// segment) of a JWT ID token into adfsClaims. Only the payload is read; the
// header and signature segments are not inspected here.
//
// An error is returned when the token has no payload segment, when the
// payload is not valid base64url, or when the decoded payload is not valid
// JSON. When the "email" claim is missing or empty, Email falls back to Upn.
func adfsClaimsFromIDToken(idToken string) (*adfsClaims, error) {
	segments := strings.Split(idToken, ".")
	if len(segments) < 2 {
		// Guard the index below: an empty or dot-less token used to panic.
		return nil, errors.New("malformed id_token: missing payload segment")
	}

	// JWT segments are normally unpadded base64url, but tolerate padded
	// input: padding may be one or two '=' characters, so strip all of them
	// before handing the data to the padding-free decoder.
	payload := strings.TrimRight(segments[1], "=")
	raw, err := base64.RawURLEncoding.DecodeString(payload)
	if err != nil {
		return nil, err
	}

	claims := &adfsClaims{}
	if err := json.Unmarshal(raw, claims); err != nil {
		return nil, err
	}
	if claims.Email == "" {
		claims.Email = claims.Upn
	}
	return claims, nil
}
// GetLoginURL overrides GetLoginURL to add ADFS parameters.
//
// Starting from the configured LoginURL and its existing query string, it
// sets redirect_uri, client_id and response_type=code, appends state and the
// protected resource, and returns the resulting URL as a string.
func (p *ADFSProvider) GetLoginURL(redirectURI, state string) string {
	// Work on a by-value copy so that assigning RawQuery below does not touch
	// the configured LoginURL.
	loginURL := *p.LoginURL

	// The parse error is intentionally dropped: this method has no error
	// return, and ParseQuery still hands back the parameters it could decode.
	query, _ := url.ParseQuery(loginURL.RawQuery)
	query.Set("redirect_uri", redirectURI)
	query.Set("client_id", p.ClientID)
	query.Set("response_type", "code")
	query.Add("state", state)
	query.Add("resource", p.ProtectedResource.String())

	loginURL.RawQuery = query.Encode()
	return loginURL.String()
}
// Redeem exchanges the OAuth2 authorization code for access, refresh and ID
// tokens by POSTing a form to RedeemURL, then builds a session from the
// response.
//
// redirectURL is sent as redirect_uri and code as the authorization code;
// client_secret is only included when one is configured.
//
// An error is returned when code is empty, when the request cannot be built
// or sent, when the body cannot be read, when the status is not 200 (the
// response body is included in the message), when the body is not valid
// JSON, or when the id_token is missing or cannot be decoded. On error the
// returned session is nil.
func (p *ADFSProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
	if code == "" {
		return nil, errors.New("missing code")
	}

	params := url.Values{}
	params.Add("grant_type", "authorization_code")
	params.Add("code", code)
	params.Add("client_id", p.ClientID)
	params.Add("redirect_uri", redirectURL)
	params.Add("resource", p.ProtectedResource.String())
	if p.ClientSecret != "" {
		params.Add("client_secret", p.ClientSecret)
	}

	req, err := http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
	}

	var jsonResponse struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		IDToken      string `json:"id_token"`
	}
	if err = json.Unmarshal(body, &jsonResponse); err != nil {
		return nil, err
	}

	// The claims decoder indexes the second dot-separated segment of the ID
	// token, so reject a response whose id_token is absent or has no "."
	// instead of letting that index operation panic.
	if !strings.Contains(jsonResponse.IDToken, ".") {
		return nil, errors.New("missing or malformed id_token in redeem response")
	}
	c, err := adfsClaimsFromIDToken(jsonResponse.IDToken)
	if err != nil {
		return nil, err
	}

	// Sample the clock once so CreatedAt and ExpiresOn share the same base.
	now := time.Now()
	return &sessions.SessionState{
		AccessToken:  jsonResponse.AccessToken,
		IDToken:      jsonResponse.IDToken,
		CreatedAt:    now,
		ExpiresOn:    now.Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
		RefreshToken: jsonResponse.RefreshToken,
		Email:        c.Email,
		User:         c.Upn,
	}, nil
}
// RefreshSessionIfNeeded checks if the session has expired and, if so, uses
// its RefreshToken to fetch new access and ID tokens.
//
// It reports (false, nil) without doing anything when s is nil, when the
// session has not expired yet, or when it carries no refresh token. When the
// refresh request fails it returns (false, err) and leaves s unmodified. On
// success s.AccessToken, s.IDToken and s.ExpiresOn are updated in place and
// (true, nil) is returned.
func (p *ADFSProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
	switch {
	case s == nil:
		return false, nil
	case s.ExpiresOn.After(time.Now()):
		return false, nil
	case s.RefreshToken == "":
		return false, nil
	}

	accessToken, idToken, lifetime, err := p.redeemRefreshToken(s.RefreshToken)
	if err != nil {
		return false, err
	}

	// Remember the old expiry for the log line before it is overwritten.
	previousExpiry := s.ExpiresOn
	s.AccessToken, s.IDToken = accessToken, idToken
	s.ExpiresOn = time.Now().Add(lifetime).Truncate(time.Second)
	logger.Printf("refreshed access token %s (expired on %s)", s, previousExpiry)
	return true, nil
}
// redeemRefreshToken POSTs a refresh_token grant to RedeemURL and returns the
// access token, the ID token and the token lifetime taken from the JSON
// response (expires_in is interpreted as a number of seconds).
//
// client_secret is only sent when one is configured. An error is returned
// when the request cannot be built or sent, when the body cannot be read,
// when the status is not 200 (the response body is included in the message),
// or when the body is not valid JSON; the other results are then zero values.
func (p *ADFSProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) {
	form := url.Values{
		"grant_type":    {"refresh_token"},
		"resource":      {p.ProtectedResource.String()},
		"client_id":     {p.ClientID},
		"refresh_token": {refreshToken},
	}
	if secret := p.ClientSecret; secret != "" {
		form.Add("client_secret", secret)
	}

	endpoint := p.RedeemURL.String()
	req, err := http.NewRequest("POST", endpoint, bytes.NewBufferString(form.Encode()))
	if err != nil {
		return "", "", 0, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", "", 0, err
	}

	// Close the body right after reading it, whether or not the read failed.
	payload, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	if err != nil {
		return "", "", 0, err
	}

	if resp.StatusCode != http.StatusOK {
		return "", "", 0, fmt.Errorf("got %d from %q %s", resp.StatusCode, endpoint, payload)
	}

	var refreshed struct {
		AccessToken string `json:"access_token"`
		ExpiresIn   int64  `json:"expires_in"`
		IDToken     string `json:"id_token"`
	}
	if err = json.Unmarshal(payload, &refreshed); err != nil {
		return "", "", 0, err
	}

	lifetime := time.Duration(refreshed.ExpiresIn) * time.Second
	return refreshed.AccessToken, refreshed.IDToken, lifetime, nil
}