202 lines
5.1 KiB
Go
202 lines
5.1 KiB
Go
package providers
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
|
"github.com/pusher/oauth2_proxy/pkg/logger"
|
|
)
|
|
|
|
// ADFSProvider represents an ADFS based Identity Provider
|
|
type ADFSProvider struct {
|
|
*ProviderData
|
|
}
|
|
|
|
// adfsClaims holds the claims this provider reads out of the JSON payload of
// an ID token.
type adfsClaims struct {
	// Upn is decoded from the "upn" claim.
	Upn string `json:"upn"`
	// Email is decoded from the "email" claim.
	Email string `json:"email"`
}
|
|
|
|
// NewADFSProvider initiates a new ADFSProvider
|
|
func NewADFSProvider(p *ProviderData) *ADFSProvider {
|
|
p.ProviderName = "ADFS"
|
|
|
|
if p.Scope == "" {
|
|
p.Scope = "openid"
|
|
}
|
|
return &ADFSProvider{ProviderData: p}
|
|
}
|
|
|
|
func adfsClaimsFromIDToken(idToken string) (*adfsClaims, error) {
|
|
jwt := strings.Split(idToken, ".")
|
|
jwtData := strings.TrimSuffix(jwt[1], "=")
|
|
b, err := base64.RawURLEncoding.DecodeString(jwtData)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c := &adfsClaims{}
|
|
err = json.Unmarshal(b, c)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if c.Email == "" {
|
|
c.Email = c.Upn
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// GetLoginURL overrides GetLoginURL to add ADFS parameters
|
|
func (p *ADFSProvider) GetLoginURL(redirectURI, state string) string {
|
|
var a url.URL
|
|
a = *p.LoginURL
|
|
params, _ := url.ParseQuery(a.RawQuery)
|
|
params.Set("redirect_uri", redirectURI)
|
|
params.Set("client_id", p.ClientID)
|
|
params.Set("response_type", "code")
|
|
params.Add("state", state)
|
|
params.Add("resource", p.ProtectedResource.String())
|
|
a.RawQuery = params.Encode()
|
|
return a.String()
|
|
}
|
|
|
|
// Redeem exchanges the OAuth2 authentication token for an Access\ID tokens
|
|
func (p *ADFSProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
|
if code == "" {
|
|
err = errors.New("missing code")
|
|
return
|
|
}
|
|
params := url.Values{}
|
|
params.Add("grant_type", "authorization_code")
|
|
params.Add("code", code)
|
|
params.Add("client_id", p.ClientID)
|
|
params.Add("redirect_uri", redirectURL)
|
|
params.Add("resource", p.ProtectedResource.String())
|
|
if p.ClientSecret != "" {
|
|
params.Add("client_secret", p.ClientSecret)
|
|
}
|
|
var req *http.Request
|
|
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
|
if err != nil {
|
|
return
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var body []byte
|
|
body, err = ioutil.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
return
|
|
}
|
|
|
|
var jsonResponse struct {
|
|
AccessToken string `json:"access_token"`
|
|
RefreshToken string `json:"refresh_token"`
|
|
ExpiresIn int64 `json:"expires_in"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
err = json.Unmarshal(body, &jsonResponse)
|
|
if err != nil {
|
|
return
|
|
}
|
|
c, err := adfsClaimsFromIDToken(jsonResponse.IDToken)
|
|
if err != nil {
|
|
return
|
|
}
|
|
s = &sessions.SessionState{
|
|
AccessToken: jsonResponse.AccessToken,
|
|
IDToken: jsonResponse.IDToken,
|
|
CreatedAt: time.Now(),
|
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
|
RefreshToken: jsonResponse.RefreshToken,
|
|
Email: c.Email,
|
|
User: c.Upn,
|
|
}
|
|
return
|
|
}
|
|
|
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
|
// RefreshToken to fetch a new Access\ID tokens if required
|
|
func (p *ADFSProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
|
return false, nil
|
|
}
|
|
|
|
newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
origExpiration := s.ExpiresOn
|
|
s.AccessToken = newToken
|
|
s.IDToken = newIDToken
|
|
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
|
|
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
|
|
return true, nil
|
|
}
|
|
|
|
func (p *ADFSProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
|
params := url.Values{}
|
|
params.Add("grant_type", "refresh_token")
|
|
params.Add("resource", p.ProtectedResource.String())
|
|
params.Add("client_id", p.ClientID)
|
|
params.Add("refresh_token", refreshToken)
|
|
if p.ClientSecret != "" {
|
|
params.Add("client_secret", p.ClientSecret)
|
|
}
|
|
var req *http.Request
|
|
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
|
if err != nil {
|
|
return
|
|
}
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return
|
|
}
|
|
var body []byte
|
|
body, err = ioutil.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
|
return
|
|
}
|
|
|
|
var data struct {
|
|
AccessToken string `json:"access_token"`
|
|
ExpiresIn int64 `json:"expires_in"`
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
err = json.Unmarshal(body, &data)
|
|
if err != nil {
|
|
return
|
|
}
|
|
token = data.AccessToken
|
|
idToken = data.IDToken
|
|
expires = time.Duration(data.ExpiresIn) * time.Second
|
|
return
|
|
}
|