diff --git a/CHANGELOG.md b/CHANGELOG.md index 74f09c2..95794b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,7 @@ - [#159](https://github.com/pusher/oauth2_proxy/pull/159) Add option to skip the OIDC provider verified email check: `--insecure-oidc-allow-unverified-email` - [#210](https://github.com/pusher/oauth2_proxy/pull/210) Update base image from Alpine 3.9 to 3.10 (@steakunderscore) - [#211](https://github.com/pusher/oauth2_proxy/pull/211) Switch from dep to go modules (@steakunderscore) +- [#221](https://github.com/pusher/oauth2_proxy/pull/221) Add ADFS as new OAuth2 provider (@MaxFedotov) # v3.2.0 diff --git a/providers/adfs.go b/providers/adfs.go new file mode 100644 index 0000000..85dbfd4 --- /dev/null +++ b/providers/adfs.go @@ -0,0 +1,201 @@ +package providers + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" + + "github.com/pusher/oauth2_proxy/pkg/apis/sessions" + "github.com/pusher/oauth2_proxy/pkg/logger" +) + +// ADFSProvider represents an ADFS based Identity Provider +type ADFSProvider struct { + *ProviderData +} + +type adfsClaims struct { + Upn string `json:"upn"` + Email string `json:"email"` +} + +// NewADFSProvider initiates a new ADFSProvider +func NewADFSProvider(p *ProviderData) *ADFSProvider { + p.ProviderName = "ADFS" + + if p.Scope == "" { + p.Scope = "openid" + } + return &ADFSProvider{ProviderData: p} +} + +func adfsClaimsFromIDToken(idToken string) (*adfsClaims, error) { + jwt := strings.Split(idToken, ".") + jwtData := strings.TrimSuffix(jwt[1], "=") + b, err := base64.RawURLEncoding.DecodeString(jwtData) + if err != nil { + return nil, err + } + + c := &adfsClaims{} + err = json.Unmarshal(b, c) + if err != nil { + return nil, err + } + if c.Email == "" { + c.Email = c.Upn + } + return c, nil +} + +// GetLoginURL overrides GetLoginURL to add ADFS parameters +func (p *ADFSProvider) GetLoginURL(redirectURI, state string) string { + var a url.URL + a = *p.LoginURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("redirect_uri", redirectURI) + params.Set("client_id", p.ClientID) + params.Set("response_type", "code") + params.Add("state", state) + params.Add("resource", p.ProtectedResource.String()) + a.RawQuery = params.Encode() + return a.String() +} + +// Redeem exchanges the OAuth2 authentication token for an Access\ID tokens +func (p *ADFSProvider) Redeem(redirectURL, code string) (s *sessions.SessionState, err error) { + if code == "" { + err = errors.New("missing code") + return + } + params := url.Values{} + params.Add("grant_type", "authorization_code") + params.Add("code", code) + params.Add("client_id", p.ClientID) + params.Add("redirect_uri", redirectURL) + params.Add("resource", p.ProtectedResource.String()) + if p.ClientSecret != "" { + params.Add("client_secret", p.ClientSecret) + } + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) + return + } + + var jsonResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token"` + } + err = json.Unmarshal(body, &jsonResponse) + if err != nil { + return + } + c, err := adfsClaimsFromIDToken(jsonResponse.IDToken) + if err != nil { + return + } + s = &sessions.SessionState{ + AccessToken: jsonResponse.AccessToken, + IDToken: jsonResponse.IDToken, + CreatedAt: time.Now(), + ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), + RefreshToken: jsonResponse.RefreshToken, + Email: c.Email, + User: c.Upn, + } + return +} + +// RefreshSessionIfNeeded checks if the session has expired and uses the +// RefreshToken to fetch a new Access\ID tokens if required +func (p *ADFSProvider) RefreshSessionIfNeeded(s *sessions.SessionState) (bool, error) { + if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" { + return false, nil + } + + newToken, newIDToken, duration, err := p.redeemRefreshToken(s.RefreshToken) + if err != nil { + return false, err + } + + origExpiration := s.ExpiresOn + s.AccessToken = newToken + s.IDToken = newIDToken + s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second) + logger.Printf("refreshed access token %s (expired on %s)", s, origExpiration) + return true, nil +} + +func (p *ADFSProvider) redeemRefreshToken(refreshToken string) (token string, idToken string, expires time.Duration, err error) { + params := url.Values{} + params.Add("grant_type", "refresh_token") + params.Add("resource", p.ProtectedResource.String()) + params.Add("client_id", p.ClientID) + params.Add("refresh_token", refreshToken) + if p.ClientSecret != "" { + params.Add("client_secret", p.ClientSecret) + } + var req *http.Request + req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return + } + var body []byte + body, err = ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + return + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("got %d from %q %s", resp.StatusCode, p.RedeemURL.String(), body) + return + } + + var data struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token"` + } + err = json.Unmarshal(body, &data) + if err != nil { + return + } + token = data.AccessToken + idToken = data.IDToken + expires = time.Duration(data.ExpiresIn) * time.Second + return +} diff --git a/providers/adfs_test.go b/providers/adfs_test.go new file mode 100644 index 0000000..05941df --- /dev/null +++ b/providers/adfs_test.go @@ -0,0 +1,91 @@ +package providers + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" +) + +type redeemResponseADFS struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + IDToken string `json:"id_token"` +} + +func newADFSRedeemServer(body []byte) (*url.URL, *httptest.Server) { + s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Write(body) + })) + u, _ := url.Parse(s.URL) + return u, s +} + +func newADFSProvider() *ADFSProvider { + return NewADFSProvider( + &ProviderData{ + ProviderName: "", + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProtectedResource: &url.URL{}, + Scope: ""}) +} + +func TestADFSProviderDefaults(t *testing.T) { + p := newADFSProvider() + assert.NotEqual(t, nil, p) + assert.Equal(t, "ADFS", p.Data().ProviderName) + assert.Equal(t, "", p.Data().LoginURL.String()) + assert.Equal(t, "", p.Data().RedeemURL.String()) + assert.Equal(t, "", p.Data().ProtectedResource.String()) + assert.Equal(t, "openid", p.Data().Scope) +} + +func TestADFSProviderGetEmailAddressAndUpn(t *testing.T) { + p := newADFSProvider() + body, err := json.Marshal(redeemResponseADFS{ + AccessToken: "test12345", + ExpiresIn: 10, + RefreshToken: "refreshtest12345", + IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com", "email": "m_fedotov@gmail.com"}`)), + }) + assert.Equal(t, nil, err) + var server *httptest.Server + p.RedeemURL, server = newADFSRedeemServer(body) + defer server.Close() + + session, err := p.Redeem("http://redirect/", "code1234") + assert.Equal(t, nil, err) + assert.NotEqual(t, session, nil) + assert.Equal(t, "m_fedotov@gmail.com", session.Email) + assert.Equal(t, "test12345", session.AccessToken) + assert.Equal(t, "refreshtest12345", session.RefreshToken) + assert.Equal(t, "m_fedotov@gmail.com", session.User) +} + +func TestADFSProviderGetUpnOnly(t *testing.T) { + p := newADFSProvider() + body, err := json.Marshal(redeemResponseADFS{ + AccessToken: "test12345", + ExpiresIn: 10, + RefreshToken: "refreshtest12345", + IDToken: "jwt header." + base64.URLEncoding.EncodeToString([]byte(`{"upn": "m_fedotov@gmail.com"}`)), + }) + assert.Equal(t, nil, err) + var server *httptest.Server + p.RedeemURL, server = newADFSRedeemServer(body) + defer server.Close() + + session, err := p.Redeem("http://redirect/", "code1234") + assert.Equal(t, nil, err) + assert.NotEqual(t, session, nil) + assert.Equal(t, "m_fedotov@gmail.com", session.Email) + assert.Equal(t, "test12345", session.AccessToken) + assert.Equal(t, "refreshtest12345", session.RefreshToken) + assert.Equal(t, "m_fedotov@gmail.com", session.User) +} diff --git a/providers/providers.go b/providers/providers.go index baf723d..bac37e3 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -36,6 +36,8 @@ func New(provider string, p *ProviderData) Provider { return NewOIDCProvider(p) case "login.gov": return NewLoginGovProvider(p) + case "adfs": + return NewADFSProvider(p) default: return NewGoogleProvider(p) }