add adfs provider and tests
This commit is contained in:
parent
8635391543
commit
7872309e28
@ -88,6 +88,7 @@
|
||||
- [#159](https://github.com/pusher/oauth2_proxy/pull/159) Add option to skip the OIDC provider verified email check: `--insecure-oidc-allow-unverified-email`
|
||||
- [#210](https://github.com/pusher/oauth2_proxy/pull/210) Update base image from Alpine 3.9 to 3.10 (@steakunderscore)
|
||||
- [#211](https://github.com/pusher/oauth2_proxy/pull/211) Switch from dep to go modules (@steakunderscore)
|
||||
- [#221](https://github.com/pusher/oauth2_proxy/pull/221) Add ADFS as new OAuth2 provider (@MaxFedotov)
|
||||
|
||||
# v3.2.0
|
||||
|
||||
|
201
providers/adfs.go
Normal file
201
providers/adfs.go
Normal file
@ -0,0 +1,201 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
"github.com/pusher/oauth2_proxy/pkg/logger"
|
||||
)
|
||||
|
||||
// ADFSProvider represents an ADFS based Identity Provider
|
||||
type ADFSProvider struct {
|
||||
*ProviderData
|
||||
}
|
||||
|
||||
type adfsClaims struct {
|
||||
Upn string `json:"upn"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// NewADFSProvider initiates a new ADFSProvider
|
||||
func NewADFSProvider(p *ProviderData) *ADFSProvider {
|
||||
p.ProviderName = "ADFS"
|
||||
|
||||
if p.Scope == "" {
|
||||
p.Scope = "openid"
|
||||
}
|
||||
return &ADFSProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
func adfsClaimsFromIDToken(idToken string) (*adfsClaims, error) {
|
||||
jwt := strings.Split(idToken, ".")
|
||||
jwtData := strings.TrimSuffix(jwt[1], "=")
|
||||
b, err := base64.RawURLEncoding.DecodeString(jwtData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &adfsClaims{}
|
||||
err = json.Unmarshal(b, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.Email == "" {
|
||||
c.Email = c.Upn
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetLoginURL overrides GetLoginURL to add ADFS parameters
|
||||
func (p *ADFSProvider) GetLoginURL(redirectURI, state string) string {
|
||||
var a url.URL
|
||||
a = *p.LoginURL
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("client_id", p.ClientID)
|
||||
params.Set("response_type", "code")
|
||||
params.Add("state", state)
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
a.RawQuery = params.Encode()
|
||||
return a.String()
|
||||
}
|
||||
|
||||
// Redeem exchanges the OAuth2 authentication token for an Access\ID tokens
|
||||
func (p *ADFSProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||
if code == "" {
|
||||
err = errors.New("missing code")
|
||||
return
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Add("grant_type", "authorization_code")
|
||||
params.Add("code", code)
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("redirect_uri", redirectURL)
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
if p.ClientSecret != "" {
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
}
|
||||
var req *http.Request
|
||||
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var jsonResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &jsonResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c, err := adfsClaimsFromIDToken(jsonResponse.IDToken)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s = &sessions.SessionState{
|
||||
AccessToken: jsonResponse.AccessToken,
|
||||
IDToken: jsonResponse.IDToken,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||
RefreshToken: jsonResponse.RefreshToken,
|
||||
Email: c.Email,
|
||||
User: c.Upn,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||
// RefreshToken to fetch a new Access\ID tokens if required
|
||||
func (p *ADFSProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
origExpiration := s.ExpiresOn
|
||||
s.AccessToken = newToken
|
||||
s.IDToken = newIDToken
|
||||
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
|
||||
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *ADFSProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
||||
params := url.Values{}
|
||||
params.Add("grant_type", "refresh_token")
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
params.Add("client_id", p.ClientID)
|
||||
params.Add("refresh_token", refreshToken)
|
||||
if p.ClientSecret != "" {
|
||||
params.Add("client_secret", p.ClientSecret)
|
||||
}
|
||||
var req *http.Request
|
||||
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||
return
|
||||
}
|
||||
|
||||
var data struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
err = json.Unmarshal(body, &data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
token = data.AccessToken
|
||||
idToken = data.IDToken
|
||||
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||
return
|
||||
}
|
91
providers/adfs_test.go
Normal file
91
providers/adfs_test.go
Normal file
@ -0,0 +1,91 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type redeemResponseADFS struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
IDToken string `json:"id_token"`
|
||||
}
|
||||
|
||||
func newADFSRedeemServer(body []byte) (*url.URL, *httptest.Server) {
|
||||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(body)
|
||||
}))
|
||||
u, _ := url.Parse(s.URL)
|
||||
return u, s
|
||||
}
|
||||
|
||||
func newADFSProvider() *ADFSProvider {
|
||||
return NewADFSProvider(
|
||||
&ProviderData{
|
||||
ProviderName: "",
|
||||
LoginURL: &url.URL{},
|
||||
RedeemURL: &url.URL{},
|
||||
ProtectedResource: &url.URL{},
|
||||
Scope: ""})
|
||||
}
|
||||
|
||||
func TestADFSProviderDefaults(t *testing.T) {
|
||||
p := newADFSProvider()
|
||||
assert.NotEqual(t, nil, p)
|
||||
assert.Equal(t, "ADFS", p.Data().ProviderName)
|
||||
assert.Equal(t, "", p.Data().LoginURL.String())
|
||||
assert.Equal(t, "", p.Data().RedeemURL.String())
|
||||
assert.Equal(t, "", p.Data().ProtectedResource.String())
|
||||
assert.Equal(t, "openid", p.Data().Scope)
|
||||
}
|
||||
|
||||
func TestADFSProviderGetEmailAddressAndUpn(t *testing.T) {
|
||||
p := newADFSProvider()
|
||||
body, err := json.Marshal(redeemResponseADFS{
|
||||
AccessToken: "test12345",
|
||||
ExpiresIn: 10,
|
||||
RefreshToken: "refreshtest12345",
|
||||
IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com", "email": "m_fedotov@gmail.com"}`)),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
var server *httptest.Server
|
||||
p.RedeemURL, server = newADFSRedeemServer(body)
|
||||
defer server.Close()
|
||||
|
||||
session, err := p.Redeem("http://redirect/", "code1234")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "m_fedotov@gmail.com", session.Email)
|
||||
assert.Equal(t, "test12345", session.AccessToken)
|
||||
assert.Equal(t, "refreshtest12345", session.RefreshToken)
|
||||
assert.Equal(t, "m_fedotov@gmail.com", session.User)
|
||||
}
|
||||
|
||||
func TestADFSProviderGetUpnOnly(t *testing.T) {
|
||||
p := newADFSProvider()
|
||||
body, err := json.Marshal(redeemResponseADFS{
|
||||
AccessToken: "test12345",
|
||||
ExpiresIn: 10,
|
||||
RefreshToken: "refreshtest12345",
|
||||
IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com"}`)),
|
||||
})
|
||||
assert.Equal(t, nil, err)
|
||||
var server *httptest.Server
|
||||
p.RedeemURL, server = newADFSRedeemServer(body)
|
||||
defer server.Close()
|
||||
|
||||
session, err := p.Redeem("http://redirect/", "code1234")
|
||||
assert.Equal(t, nil, err)
|
||||
assert.NotEqual(t, session, nil)
|
||||
assert.Equal(t, "m_fedotov@gmail.com", session.Email)
|
||||
assert.Equal(t, "test12345", session.AccessToken)
|
||||
assert.Equal(t, "refreshtest12345", session.RefreshToken)
|
||||
assert.Equal(t, "m_fedotov@gmail.com", session.User)
|
||||
}
|
@ -36,6 +36,8 @@ func New(provider string, p *ProviderData) Provider {
|
||||
return NewOIDCProvider(p)
|
||||
case "login.gov":
|
||||
return NewLoginGovProvider(p)
|
||||
case "adfs":
|
||||
return NewADFSProvider(p)
|
||||
default:
|
||||
return NewGoogleProvider(p)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user