add adfs provider and tests
This commit is contained in:
parent
8635391543
commit
7872309e28
@ -88,6 +88,7 @@
|
|||||||
- [#159](https://github.com/pusher/oauth2_proxy/pull/159) Add option to skip the OIDC provider verified email check: `--insecure-oidc-allow-unverified-email`
|
- [#159](https://github.com/pusher/oauth2_proxy/pull/159) Add option to skip the OIDC provider verified email check: `--insecure-oidc-allow-unverified-email`
|
||||||
- [#210](https://github.com/pusher/oauth2_proxy/pull/210) Update base image from Alpine 3.9 to 3.10 (@steakunderscore)
|
- [#210](https://github.com/pusher/oauth2_proxy/pull/210) Update base image from Alpine 3.9 to 3.10 (@steakunderscore)
|
||||||
- [#211](https://github.com/pusher/oauth2_proxy/pull/211) Switch from dep to go modules (@steakunderscore)
|
- [#211](https://github.com/pusher/oauth2_proxy/pull/211) Switch from dep to go modules (@steakunderscore)
|
||||||
|
- [#221](https://github.com/pusher/oauth2_proxy/pull/221) Add ADFS as new OAuth2 provider (@MaxFedotov)
|
||||||
|
|
||||||
# v3.2.0
|
# v3.2.0
|
||||||
|
|
||||||
|
201
providers/adfs.go
Normal file
201
providers/adfs.go
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||||
|
"github.com/pusher/oauth2_proxy/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ADFSProvider represents an ADFS based Identity Provider
|
||||||
|
type ADFSProvider struct {
|
||||||
|
*ProviderData
|
||||||
|
}
|
||||||
|
|
||||||
|
type adfsClaims struct {
|
||||||
|
Upn string `json:"upn"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewADFSProvider initiates a new ADFSProvider
|
||||||
|
func NewADFSProvider(p *ProviderData) *ADFSProvider {
|
||||||
|
p.ProviderName = "ADFS"
|
||||||
|
|
||||||
|
if p.Scope == "" {
|
||||||
|
p.Scope = "openid"
|
||||||
|
}
|
||||||
|
return &ADFSProvider{ProviderData: p}
|
||||||
|
}
|
||||||
|
|
||||||
|
func adfsClaimsFromIDToken(idToken string) (*adfsClaims, error) {
|
||||||
|
jwt := strings.Split(idToken, ".")
|
||||||
|
jwtData := strings.TrimSuffix(jwt[1], "=")
|
||||||
|
b, err := base64.RawURLEncoding.DecodeString(jwtData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &adfsClaims{}
|
||||||
|
err = json.Unmarshal(b, c)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if c.Email == "" {
|
||||||
|
c.Email = c.Upn
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoginURL overrides GetLoginURL to add ADFS parameters
|
||||||
|
func (p *ADFSProvider) GetLoginURL(redirectURI, state string) string {
|
||||||
|
var a url.URL
|
||||||
|
a = *p.LoginURL
|
||||||
|
params, _ := url.ParseQuery(a.RawQuery)
|
||||||
|
params.Set("redirect_uri", redirectURI)
|
||||||
|
params.Set("client_id", p.ClientID)
|
||||||
|
params.Set("response_type", "code")
|
||||||
|
params.Add("state", state)
|
||||||
|
params.Add("resource", p.ProtectedResource.String())
|
||||||
|
a.RawQuery = params.Encode()
|
||||||
|
return a.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redeem exchanges the OAuth2 authentication token for an Access\ID tokens
|
||||||
|
func (p *ADFSProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
|
||||||
|
if code == "" {
|
||||||
|
err = errors.New("missing code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("grant_type", "authorization_code")
|
||||||
|
params.Add("code", code)
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
params.Add("redirect_uri", redirectURL)
|
||||||
|
params.Add("resource", p.ProtectedResource.String())
|
||||||
|
if p.ClientSecret != "" {
|
||||||
|
params.Add("client_secret", p.ClientSecret)
|
||||||
|
}
|
||||||
|
var req *http.Request
|
||||||
|
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var body []byte
|
||||||
|
body, err = ioutil.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var jsonResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(body, &jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c, err := adfsClaimsFromIDToken(jsonResponse.IDToken)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s = &sessions.SessionState{
|
||||||
|
AccessToken: jsonResponse.AccessToken,
|
||||||
|
IDToken: jsonResponse.IDToken,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
|
||||||
|
RefreshToken: jsonResponse.RefreshToken,
|
||||||
|
Email: c.Email,
|
||||||
|
User: c.Upn,
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshSessionIfNeeded checks if the session has expired and uses the
|
||||||
|
// RefreshToken to fetch a new Access\ID tokens if required
|
||||||
|
func (p *ADFSProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
|
||||||
|
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
origExpiration := s.ExpiresOn
|
||||||
|
s.AccessToken = newToken
|
||||||
|
s.IDToken = newIDToken
|
||||||
|
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
|
||||||
|
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ADFSProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) {
|
||||||
|
params := url.Values{}
|
||||||
|
params.Add("grant_type", "refresh_token")
|
||||||
|
params.Add("resource", p.ProtectedResource.String())
|
||||||
|
params.Add("client_id", p.ClientID)
|
||||||
|
params.Add("refresh_token", refreshToken)
|
||||||
|
if p.ClientSecret != "" {
|
||||||
|
params.Add("client_secret", p.ClientSecret)
|
||||||
|
}
|
||||||
|
var req *http.Request
|
||||||
|
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var body []byte
|
||||||
|
body, err = ioutil.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var data struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(body, &data)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token = data.AccessToken
|
||||||
|
idToken = data.IDToken
|
||||||
|
expires = time.Duration(data.ExpiresIn) * time.Second
|
||||||
|
return
|
||||||
|
}
|
91
providers/adfs_test.go
Normal file
91
providers/adfs_test.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type redeemResponseADFS struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func newADFSRedeemServer(body []byte) (*url.URL, *httptest.Server) {
|
||||||
|
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
rw.Write(body)
|
||||||
|
}))
|
||||||
|
u, _ := url.Parse(s.URL)
|
||||||
|
return u, s
|
||||||
|
}
|
||||||
|
|
||||||
|
func newADFSProvider() *ADFSProvider {
|
||||||
|
return NewADFSProvider(
|
||||||
|
&ProviderData{
|
||||||
|
ProviderName: "",
|
||||||
|
LoginURL: &url.URL{},
|
||||||
|
RedeemURL: &url.URL{},
|
||||||
|
ProtectedResource: &url.URL{},
|
||||||
|
Scope: ""})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestADFSProviderDefaults(t *testing.T) {
|
||||||
|
p := newADFSProvider()
|
||||||
|
assert.NotEqual(t, nil, p)
|
||||||
|
assert.Equal(t, "ADFS", p.Data().ProviderName)
|
||||||
|
assert.Equal(t, "", p.Data().LoginURL.String())
|
||||||
|
assert.Equal(t, "", p.Data().RedeemURL.String())
|
||||||
|
assert.Equal(t, "", p.Data().ProtectedResource.String())
|
||||||
|
assert.Equal(t, "openid", p.Data().Scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestADFSProviderGetEmailAddressAndUpn(t *testing.T) {
|
||||||
|
p := newADFSProvider()
|
||||||
|
body, err := json.Marshal(redeemResponseADFS{
|
||||||
|
AccessToken: "test12345",
|
||||||
|
ExpiresIn: 10,
|
||||||
|
RefreshToken: "refreshtest12345",
|
||||||
|
IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com", "email": "m_fedotov@gmail.com"}`)),
|
||||||
|
})
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
var server *httptest.Server
|
||||||
|
p.RedeemURL, server = newADFSRedeemServer(body)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
session, err := p.Redeem("http://redirect/", "code1234")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, session, nil)
|
||||||
|
assert.Equal(t, "m_fedotov@gmail.com", session.Email)
|
||||||
|
assert.Equal(t, "test12345", session.AccessToken)
|
||||||
|
assert.Equal(t, "refreshtest12345", session.RefreshToken)
|
||||||
|
assert.Equal(t, "m_fedotov@gmail.com", session.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestADFSProviderGetUpnOnly(t *testing.T) {
|
||||||
|
p := newADFSProvider()
|
||||||
|
body, err := json.Marshal(redeemResponseADFS{
|
||||||
|
AccessToken: "test12345",
|
||||||
|
ExpiresIn: 10,
|
||||||
|
RefreshToken: "refreshtest12345",
|
||||||
|
IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com"}`)),
|
||||||
|
})
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
var server *httptest.Server
|
||||||
|
p.RedeemURL, server = newADFSRedeemServer(body)
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
session, err := p.Redeem("http://redirect/", "code1234")
|
||||||
|
assert.Equal(t, nil, err)
|
||||||
|
assert.NotEqual(t, session, nil)
|
||||||
|
assert.Equal(t, "m_fedotov@gmail.com", session.Email)
|
||||||
|
assert.Equal(t, "test12345", session.AccessToken)
|
||||||
|
assert.Equal(t, "refreshtest12345", session.RefreshToken)
|
||||||
|
assert.Equal(t, "m_fedotov@gmail.com", session.User)
|
||||||
|
}
|
@ -36,6 +36,8 @@ func New(provider string, p *ProviderData) Provider {
|
|||||||
return NewOIDCProvider(p)
|
return NewOIDCProvider(p)
|
||||||
case "login.gov":
|
case "login.gov":
|
||||||
return NewLoginGovProvider(p)
|
return NewLoginGovProvider(p)
|
||||||
|
case "adfs":
|
||||||
|
return NewADFSProvider(p)
|
||||||
default:
|
default:
|
||||||
return NewGoogleProvider(p)
|
return NewGoogleProvider(p)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user