diff --git a/README.md b/README.md index da71e1b..3c0f98f 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,8 @@ You will need to register an OAuth application with a Provider (Google, Github o Valid providers are : * [Google](#google-auth-provider) *default* + +* [Azure](#azure-auth-provider) * [GitHub](#github-auth-provider) * [LinkedIn](#linkedin-auth-provider) * [MyUSA](#myusa-auth-provider) @@ -76,6 +78,15 @@ and the user will be checked against all the provided groups. Note: The user is checked against the group members list on initial authentication and every time the token is refreshed ( about once an hour ). +### Azure Auth Provider + +1. [Add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/) to your Azure Active Directory tenant. +2. On the App properties page provide the correct Sign-On URL ie `https//internal.yourcompany.com/oauth2/callback` +3. If applicable take note of your `TenantID` and provide it via the `--azure-tenant=` commandline option. Default the `common` tenant is used. + +The Azure AD auth provider uses `openid` as it default scope. It uses `https://graph.windows.net` as a default protected resource. It call to `https://graph.windows.net/me` to get the email address of the user that logs in. + + ### GitHub Auth Provider 1. Create a new project: https://github.com/settings/developers @@ -102,6 +113,12 @@ For LinkedIn, the registration steps are: The [MyUSA](https://alpha.my.usa.gov) authentication service ([GitHub](https://github.com/18F/myusa)) +### Microsoft Azure AD Provider + +For adding an application to the Microsoft Azure AD follow [these steps to add an application](https://azure.microsoft.com/en-us/documentation/articles/active-directory-integrating-applications/). + +Take note of your `TenantId` if applicable for your situation. The `TenantId` can be used to override the default `common` authorization server with a tenant specific server. + ## Email Authentication To authorize by email domain use `--email-domain=yourcompany.com`. To authorize individual email addresses use `--authenticated-emails-file=/path/to/file` with one email per line. To authorize all email addresse use `--email-domain=*`. @@ -120,6 +137,7 @@ An example [oauth2_proxy.cfg](contrib/oauth2_proxy.cfg.example) config file is i Usage of oauth2_proxy: -approval-prompt="force": Oauth approval_prompt -authenticated-emails-file="": authenticate against emails via file (one per line) + -azure-tenant="common": go to a tenant-specific or common (tenant-independent) endpoint. -basic-auth-password="": the password to set when passing the HTTP Basic Auth header -client-id="": the OAuth Client ID: ie: "123456.apps.googleusercontent.com" -client-secret="": the OAuth Client Secret @@ -151,6 +169,7 @@ Usage of oauth2_proxy: -proxy-prefix="/oauth2": the url root path that this proxy should be nested under (e.g. //sign_in) -redeem-url="": Token redemption endpoint -redirect-url="": the OAuth Redirect URL. ie: "https://internalapp.yourcompany.com/oauth2/callback" + -resource="": the resource that is being protected. ie: "https://graph.windows.net". Currently only used in the Azure provider. -request-logging=true: Log requests to stdout -scope="": Oauth scope specification -signature-key="": GAP-Signature request signature key (algorithm:secretkey) diff --git a/main.go b/main.go index a8d3f1b..dd9a100 100644 --- a/main.go +++ b/main.go @@ -38,6 +38,7 @@ func main() { flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") flagSet.Var(&emailDomains, "email-domain", "authenticate emails with the specified domain (may be given multiple times). Use * to authenticate any email") + flagSet.String("azure-tenant", "common", "go to a tenant-specific or common (tenant-independent) endpoint.") flagSet.String("github-org", "", "restrict logins to members of this organisation") flagSet.String("github-team", "", "restrict logins to members of this team") flagSet.Var(&googleGroups, "google-group", "restrict logins to members of this google group (may be given multiple times).") @@ -65,6 +66,7 @@ func main() { flagSet.String("login-url", "", "Authentication endpoint") flagSet.String("redeem-url", "", "Token redemption endpoint") flagSet.String("profile-url", "", "Profile access endpoint") + flagSet.String("resource", "", "The resource that is protected (Azure AD only)") flagSet.String("validate-url", "", "Access token validation endpoint") flagSet.String("scope", "", "OAuth scope specification") flagSet.String("approval-prompt", "force", "OAuth approval_prompt") diff --git a/options.go b/options.go index b64396c..5d4c86f 100644 --- a/options.go +++ b/options.go @@ -25,6 +25,7 @@ type Options struct { TLSKeyFile string `flag:"tls-key" cfg:"tls_key_file"` AuthenticatedEmailsFile string `flag:"authenticated-emails-file" cfg:"authenticated_emails_file"` + AzureTenant string `flag:"azure-tenant" cfg:"azure_tenant"` EmailDomains []string `flag:"email-domain" cfg:"email_domains"` GitHubOrg string `flag:"github-org" cfg:"github_org"` GitHubTeam string `flag:"github-team" cfg:"github_team"` @@ -52,13 +53,14 @@ type Options struct { // These options allow for other providers besides Google, with // potential overrides. - Provider string `flag:"provider" cfg:"provider"` - LoginURL string `flag:"login-url" cfg:"login_url"` - RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` - ProfileURL string `flag:"profile-url" cfg:"profile_url"` - ValidateURL string `flag:"validate-url" cfg:"validate_url"` - Scope string `flag:"scope" cfg:"scope"` - ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` + Provider string `flag:"provider" cfg:"provider"` + LoginURL string `flag:"login-url" cfg:"login_url"` + RedeemURL string `flag:"redeem-url" cfg:"redeem_url"` + ProfileURL string `flag:"profile-url" cfg:"profile_url"` + ProtectedResource string `flag:"resource" cfg:"resource"` + ValidateURL string `flag:"validate-url" cfg:"validate_url"` + Scope string `flag:"scope" cfg:"scope"` + ApprovalPrompt string `flag:"approval-prompt" cfg:"approval_prompt"` RequestLogging bool `flag:"request-logging" cfg:"request_logging"` @@ -205,9 +207,12 @@ func parseProviderInfo(o *Options, msgs []string) []string { p.RedeemURL, msgs = parseURL(o.RedeemURL, "redeem", msgs) p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) + p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) o.provider = providers.New(o.Provider, p) switch p := o.provider.(type) { + case *providers.AzureProvider: + p.Configure(o.AzureTenant) case *providers.GitHubProvider: p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) case *providers.GoogleProvider: diff --git a/providers/azure.go b/providers/azure.go new file mode 100644 index 0000000..2e8c57d --- /dev/null +++ b/providers/azure.go @@ -0,0 +1,86 @@ +package providers + +import ( + "errors" + "fmt" + "github.com/bitly/oauth2_proxy/api" + "log" + "net/http" + "net/url" +) + +type AzureProvider struct { + *ProviderData + Tenant string +} + +func NewAzureProvider(p *ProviderData) *AzureProvider { + p.ProviderName = "Azure" + + if p.ProfileURL == nil || p.ProfileURL.String() == "" { + p.ProfileURL = &url.URL{ + Scheme: "https", + Host: "graph.windows.net", + Path: "/me", + RawQuery: "api-version=1.6", + } + } + if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { + p.ProtectedResource = &url.URL{ + Scheme: "https", + Host: "graph.windows.net", + } + } + if p.Scope == "" { + p.Scope = "openid" + } + + return &AzureProvider{ProviderData: p} +} + +func (p *AzureProvider) Configure(tenant string) { + p.Tenant = tenant + if tenant == "" { + p.Tenant = "common" + } + + if p.LoginURL == nil || p.LoginURL.String() == "" { + p.LoginURL = &url.URL{ + Scheme: "https", + Host: "login.microsoftonline.com", + Path: "/" + p.Tenant + "/oauth2/authorize"} + } + if p.RedeemURL == nil || p.RedeemURL.String() == "" { + p.RedeemURL = &url.URL{ + Scheme: "https", + Host: "login.microsoftonline.com", + Path: "/" + p.Tenant + "/oauth2/token", + } + } +} + +func getAzureHeader(access_token string) http.Header { + header := make(http.Header) + header.Set("Authorization", fmt.Sprintf("Bearer %s", access_token)) + return header +} + +func (p *AzureProvider) GetEmailAddress(s *SessionState) (string, error) { + if s.AccessToken == "" { + return "", errors.New("missing access token") + } + req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) + if err != nil { + return "", err + } + req.Header = getAzureHeader(s.AccessToken) + + json, err := api.Request(req) + + if err != nil { + log.Printf("failed making request %s", err) + return "", err + } + + return json.Get("mail").String() +} diff --git a/providers/azure_test.go b/providers/azure_test.go new file mode 100644 index 0000000..1aa823a --- /dev/null +++ b/providers/azure_test.go @@ -0,0 +1,135 @@ +package providers + +import ( + "github.com/bmizerany/assert" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func testAzureProvider(hostname string) *AzureProvider { + p := NewAzureProvider( + &ProviderData{ + ProviderName: "", + LoginURL: &url.URL{}, + RedeemURL: &url.URL{}, + ProfileURL: &url.URL{}, + ValidateURL: &url.URL{}, + ProtectedResource: &url.URL{}, + Scope: ""}) + if hostname != "" { + updateURL(p.Data().LoginURL, hostname) + updateURL(p.Data().RedeemURL, hostname) + updateURL(p.Data().ProfileURL, hostname) + updateURL(p.Data().ValidateURL, hostname) + updateURL(p.Data().ProtectedResource, hostname) + } + return p +} + +func TestAzureProviderDefaults(t *testing.T) { + p := testAzureProvider("") + assert.NotEqual(t, nil, p) + p.Configure("") + assert.Equal(t, "Azure", p.Data().ProviderName) + assert.Equal(t, "common", p.Tenant) + assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://login.microsoftonline.com/common/oauth2/token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", + p.Data().ProfileURL.String()) + assert.Equal(t, "https://graph.windows.net", + p.Data().ProtectedResource.String()) + assert.Equal(t, "", + p.Data().ValidateURL.String()) + assert.Equal(t, "openid", p.Data().Scope) +} + +func TestAzureProviderOverrides(t *testing.T) { + p := NewAzureProvider( + &ProviderData{ + LoginURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/auth"}, + RedeemURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/token"}, + ProfileURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/profile"}, + ValidateURL: &url.URL{ + Scheme: "https", + Host: "example.com", + Path: "/oauth/tokeninfo"}, + ProtectedResource: &url.URL{ + Scheme: "https", + Host: "example.com"}, + Scope: "profile"}) + assert.NotEqual(t, nil, p) + assert.Equal(t, "Azure", p.Data().ProviderName) + assert.Equal(t, "https://example.com/oauth/auth", + p.Data().LoginURL.String()) + assert.Equal(t, "https://example.com/oauth/token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://example.com/oauth/profile", + p.Data().ProfileURL.String()) + assert.Equal(t, "https://example.com/oauth/tokeninfo", + p.Data().ValidateURL.String()) + assert.Equal(t, "https://example.com", + p.Data().ProtectedResource.String()) + assert.Equal(t, "profile", p.Data().Scope) +} + +func TestAzureSetTenant(t *testing.T) { + p := testAzureProvider("") + p.Configure("example") + assert.Equal(t, "Azure", p.Data().ProviderName) + assert.Equal(t, "example", p.Tenant) + assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/authorize", + p.Data().LoginURL.String()) + assert.Equal(t, "https://login.microsoftonline.com/example/oauth2/token", + p.Data().RedeemURL.String()) + assert.Equal(t, "https://graph.windows.net/me?api-version=1.6", + p.Data().ProfileURL.String()) + assert.Equal(t, "https://graph.windows.net", + p.Data().ProtectedResource.String()) + assert.Equal(t, "", + p.Data().ValidateURL.String()) + assert.Equal(t, "openid", p.Data().Scope) +} + +func testAzureBackend(payload string) *httptest.Server { + path := "/me" + query := "api-version=1.6" + + return httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + url := r.URL + if url.Path != path || url.RawQuery != query { + w.WriteHeader(404) + } else if r.Header.Get("Authorization") != "Bearer imaginary_access_token" { + w.WriteHeader(403) + } else { + w.WriteHeader(200) + w.Write([]byte(payload)) + } + })) +} + +func TestAzureProviderGetEmailAddress(t *testing.T) { + b := testAzureBackend(`{ "mail": "user@windows.net" }`) + defer b.Close() + + b_url, _ := url.Parse(b.URL) + p := testAzureProvider(b_url.Host) + + session := &SessionState{AccessToken: "imaginary_access_token"} + email, err := p.GetEmailAddress(session) + assert.Equal(t, nil, err) + assert.Equal(t, "user@windows.net", email) +} diff --git a/providers/provider_data.go b/providers/provider_data.go index a13ed8e..92e27dd 100644 --- a/providers/provider_data.go +++ b/providers/provider_data.go @@ -5,15 +5,16 @@ import ( ) type ProviderData struct { - ProviderName string - ClientID string - ClientSecret string - LoginURL *url.URL - RedeemURL *url.URL - ProfileURL *url.URL - ValidateURL *url.URL - Scope string - ApprovalPrompt string + ProviderName string + ClientID string + ClientSecret string + LoginURL *url.URL + RedeemURL *url.URL + ProfileURL *url.URL + ProtectedResource *url.URL + ValidateURL *url.URL + Scope string + ApprovalPrompt string } func (p *ProviderData) Data() *ProviderData { return p } diff --git a/providers/provider_default.go b/providers/provider_default.go index 77b3dfd..82b73ec 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -25,6 +25,10 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *SessionState, err er params.Add("client_secret", p.ClientSecret) params.Add("code", code) params.Add("grant_type", "authorization_code") + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + params.Add("resource", p.ProtectedResource.String()) + } + var req *http.Request req, err = http.NewRequest("POST", p.RedeemURL.String(), bytes.NewBufferString(params.Encode())) if err != nil { diff --git a/providers/providers.go b/providers/providers.go index 59e5f9a..db0fe13 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -24,6 +24,8 @@ func New(provider string, p *ProviderData) Provider { return NewLinkedInProvider(p) case "github": return NewGitHubProvider(p) + case "azure": + return NewAzureProvider(p) default: return NewGoogleProvider(p) } diff --git a/watcher.go b/watcher.go index c34058b..bedb9f8 100644 --- a/watcher.go +++ b/watcher.go @@ -41,9 +41,8 @@ func WatchForUpdates(filename string, done <-chan bool, action func()) { for { select { case _ = <-done: - log.Printf("Shutting down watcher for: %s", - filename) - return + log.Printf("Shutting down watcher for: %s", filename) + break case event := <-watcher.Events: // On Arch Linux, it appears Chmod events precede Remove events, // which causes a race between action() and the coming Remove event.