add adfs provider and tests

This commit is contained in:
Maksim Fedotov 2019-07-23 00:22:19 +03:00
parent 8635391543
commit 7872309e28
4 changed files with 295 additions and 0 deletions

View File

@ -88,6 +88,7 @@
- [#159](https://github.com/pusher/oauth2_proxy/pull/159) Add option to skip the OIDC provider verified email check: `--insecure-oidc-allow-unverified-email`
- [#210](https://github.com/pusher/oauth2_proxy/pull/210) Update base image from Alpine 3.9 to 3.10 (@steakunderscore)
- [#211](https://github.com/pusher/oauth2_proxy/pull/211) Switch from dep to go modules (@steakunderscore)
- [#221](https://github.com/pusher/oauth2_proxy/pull/221) Add ADFS as new OAuth2 provider (@MaxFedotov)
# v3.2.0

201
providers/adfs.go Normal file
View File

@ -0,0 +1,201 @@
package providers
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
"github.com/pusher/oauth2_proxy/pkg/logger"
)
// ADFSProvider represents an ADFS based Identity Provider
type ADFSProvider struct {
*ProviderData
}
type adfsClaims struct {
Upn string `json:"upn"`
Email string `json:"email"`
}
// NewADFSProvider initiates a new ADFSProvider
func NewADFSProvider(p *ProviderData) *ADFSProvider {
p.ProviderName = "ADFS"
if p.Scope == "" {
p.Scope = "openid"
}
return &ADFSProvider{ProviderData: p}
}
func adfsClaimsFromIDToken(idToken string) (*adfsClaims, error) {
jwt := strings.Split(idToken, ".")
jwtData := strings.TrimSuffix(jwt[1], "=")
b, err := base64.RawURLEncoding.DecodeString(jwtData)
if err != nil {
return nil, err
}
c := &adfsClaims{}
err = json.Unmarshal(b, c)
if err != nil {
return nil, err
}
if c.Email == "" {
c.Email = c.Upn
}
return c, nil
}
// GetLoginURL overrides GetLoginURL to add ADFS parameters
func (p *ADFSProvider) GetLoginURL(redirectURI, state string) string {
var a url.URL
a = *p.LoginURL
params, _ := url.ParseQuery(a.RawQuery)
params.Set("redirect_uri", redirectURI)
params.Set("client_id", p.ClientID)
params.Set("response_type", "code")
params.Add("state", state)
params.Add("resource", p.ProtectedResource.String())
a.RawQuery = params.Encode()
return a.String()
}
// Redeem exchanges the OAuth2 authentication token for an Access\ID tokens
func (p *ADFSProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) {
if code == "" {
err = errors.New("missing code")
return
}
params := url.Values{}
params.Add("grant_type", "authorization_code")
params.Add("code", code)
params.Add("client_id", p.ClientID)
params.Add("redirect_uri", redirectURL)
params.Add("resource", p.ProtectedResource.String())
if p.ClientSecret != "" {
params.Add("client_secret", p.ClientSecret)
}
var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var jsonResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &jsonResponse)
if err != nil {
return
}
c, err := adfsClaimsFromIDToken(jsonResponse.IDToken)
if err != nil {
return
}
s = &sessions.SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
RefreshToken: jsonResponse.RefreshToken,
Email: c.Email,
User: c.Upn,
}
return
}
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new Access\ID tokens if required
func (p *ADFSProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
return false, nil
}
newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken)
if err != nil {
return false, err
}
origExpiration := s.ExpiresOn
s.AccessToken = newToken
s.IDToken = newIDToken
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
return true, nil
}
func (p *ADFSProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) {
params := url.Values{}
params.Add("grant_type", "refresh_token")
params.Add("resource", p.ProtectedResource.String())
params.Add("client_id", p.ClientID)
params.Add("refresh_token", refreshToken)
if p.ClientSecret != "" {
params.Add("client_secret", p.ClientSecret)
}
var req *http.Request
req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode()))
if err != nil {
return
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return
}
var body []byte
body, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
return
}
if resp.StatusCode != 200 {
err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body)
return
}
var data struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
err = json.Unmarshal(body, &data)
if err != nil {
return
}
token = data.AccessToken
idToken = data.IDToken
expires = time.Duration(data.ExpiresIn) * time.Second
return
}

91
providers/adfs_test.go Normal file
View File

@ -0,0 +1,91 @@
package providers
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
)
type redeemResponseADFS struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
IDToken string `json:"id_token"`
}
func newADFSRedeemServer(body []byte) (*url.URL, *httptest.Server) {
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Write(body)
}))
u, _ := url.Parse(s.URL)
return u, s
}
func newADFSProvider() *ADFSProvider {
return NewADFSProvider(
&ProviderData{
ProviderName: "",
LoginURL: &url.URL{},
RedeemURL: &url.URL{},
ProtectedResource: &url.URL{},
Scope: ""})
}
func TestADFSProviderDefaults(t *testing.T) {
p := newADFSProvider()
assert.NotEqual(t, nil, p)
assert.Equal(t, "ADFS", p.Data().ProviderName)
assert.Equal(t, "", p.Data().LoginURL.String())
assert.Equal(t, "", p.Data().RedeemURL.String())
assert.Equal(t, "", p.Data().ProtectedResource.String())
assert.Equal(t, "openid", p.Data().Scope)
}
func TestADFSProviderGetEmailAddressAndUpn(t *testing.T) {
p := newADFSProvider()
body, err := json.Marshal(redeemResponseADFS{
AccessToken: "test12345",
ExpiresIn: 10,
RefreshToken: "refreshtest12345",
IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com", "email": "m_fedotov@gmail.com"}`)),
})
assert.Equal(t, nil, err)
var server *httptest.Server
p.RedeemURL, server = newADFSRedeemServer(body)
defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "m_fedotov@gmail.com", session.Email)
assert.Equal(t, "test12345", session.AccessToken)
assert.Equal(t, "refreshtest12345", session.RefreshToken)
assert.Equal(t, "m_fedotov@gmail.com", session.User)
}
func TestADFSProviderGetUpnOnly(t *testing.T) {
p := newADFSProvider()
body, err := json.Marshal(redeemResponseADFS{
AccessToken: "test12345",
ExpiresIn: 10,
RefreshToken: "refreshtest12345",
IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com"}`)),
})
assert.Equal(t, nil, err)
var server *httptest.Server
p.RedeemURL, server = newADFSRedeemServer(body)
defer server.Close()
session, err := p.Redeem("http://redirect/", "code1234")
assert.Equal(t, nil, err)
assert.NotEqual(t, session, nil)
assert.Equal(t, "m_fedotov@gmail.com", session.Email)
assert.Equal(t, "test12345", session.AccessToken)
assert.Equal(t, "refreshtest12345", session.RefreshToken)
assert.Equal(t, "m_fedotov@gmail.com", session.User)
}

View File

@ -36,6 +36,8 @@ func New(provider string, p *ProviderData) Provider {
return NewOIDCProvider(p)
case "login.gov":
return NewLoginGovProvider(p)
case "adfs":
return NewADFSProvider(p)
default:
return NewGoogleProvider(p)
}