From 7ecb9fd2d6b293f8aec9173c5a1c0236d19a09c9 Mon Sep 17 00:00:00 2001 From: Lukasz Leszczuk Date: Fri, 16 Aug 2019 23:31:19 +0200 Subject: [PATCH] Initial work on porting https://github.com/bitly/oauth2_proxy/pull/347/. ToDo: port tests --- .gitignore | 1 + Makefile | 33 ++-- main.go | 7 + oauthproxy.go | 40 ++++- options.go | 24 ++- pkg/apis/sessions/session_state.go | 29 ++++ providers/azure.go | 248 ++++++++++++++++++++++++++++- providers/provider_default.go | 25 ++- providers/providers.go | 4 + 9 files changed, 383 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index a5f59b4..0f9427c 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ release *.exe .env .bundle +.idea/ # Go.gitignore # Compiled Object files, Static and Dynamic libs (Shared Objects) diff --git a/Makefile b/Makefile index f41f320..5f3612f 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,6 @@ include .env BINARY := oauth2_proxy +REPOSITORY := quay.io/pusher VERSION := $(shell git describe --always --dirty --tags 2>/dev/null || echo "undefined") .NOTPARALLEL: @@ -27,31 +28,31 @@ $(BINARY): .PHONY: docker docker: - docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest . + docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:latest . .PHONY: docker-all docker-all: docker - docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest-amd64 . - docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION} . - docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION}-amd64 . - docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:latest-arm64 . - docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:${VERSION}-arm64 . - docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:latest-armv6 . - docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:${VERSION}-armv6 . + docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:latest-amd64 . + docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:${VERSION} . + docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:${VERSION}-amd64 . + docker build -f Dockerfile.arm64 -t ${REPOSITORY}/oauth2_proxy:latest-arm64 . + docker build -f Dockerfile.arm64 -t ${REPOSITORY}/oauth2_proxy:${VERSION}-arm64 . + docker build -f Dockerfile.armv6 -t ${REPOSITORY}/oauth2_proxy:latest-armv6 . + docker build -f Dockerfile.armv6 -t ${REPOSITORY}/oauth2_proxy:${VERSION}-armv6 . .PHONY: docker-push docker-push: - docker push quay.io/pusher/oauth2_proxy:latest + docker push ${REPOSITORY}/oauth2_proxy:latest .PHONY: docker-push-all docker-push-all: docker-push - docker push quay.io/pusher/oauth2_proxy:latest-amd64 - docker push quay.io/pusher/oauth2_proxy:${VERSION} - docker push quay.io/pusher/oauth2_proxy:${VERSION}-amd64 - docker push quay.io/pusher/oauth2_proxy:latest-arm64 - docker push quay.io/pusher/oauth2_proxy:${VERSION}-arm64 - docker push quay.io/pusher/oauth2_proxy:latest-armv6 - docker push quay.io/pusher/oauth2_proxy:${VERSION}-armv6 + docker push ${REPOSITORY}/oauth2_proxy:latest-amd64 + docker push ${REPOSITORY}/oauth2_proxy:${VERSION} + docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-amd64 + docker push ${REPOSITORY}/oauth2_proxy:latest-arm64 + docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-arm64 + docker push ${REPOSITORY}/oauth2_proxy:latest-armv6 + docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-armv6 .PHONY: test test: lint diff --git a/main.go b/main.go index a9f1e4a..fccea35 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,8 @@ func main() { skipAuthRegex := StringArray{} jwtIssuers := StringArray{} googleGroups := StringArray{} + permittedGroups := StringArray{} + exemptedUsers := StringArray{} redisSentinelConnectionURLs := StringArray{} config := flagSet.String("config", "", "path to config file") @@ -39,6 +41,11 @@ func main() { flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") + flagSet.Bool("pass-groups", false, "pass user group information in the X-Forwarded-Groups header to upstream (Azure only)") + flagSet.String("filter-groups", "", "exclude groups that do not contain this value in its 'displayName' (Azure only)") + flagSet.Var(&permittedGroups, "permit-groups", "restrict logins to members of this group (may be given multiple times; Azure).") + flagSet.String("groups-delimiter", "|", "delimiter between group names if more than one found. By default it is '|' symbol") + flagSet.Var(&exemptedUsers, "permit-users", "let these users in if azure call to check group membership fails (may be given multiple times; Azure).") flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") diff --git a/oauthproxy.go b/oauthproxy.go index 2418e73..f30cfd3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -87,6 +87,9 @@ type OAuthProxy struct { serveMux http.Handler SetXAuthRequest bool PassBasicAuth bool + PassGroups bool + GroupsDelimiter string + FilterGroups string SkipProviderButton bool PassUserHeaders bool BasicAuthPassword string @@ -280,6 +283,9 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { compiledRegex: opts.CompiledRegex, SetXAuthRequest: opts.SetXAuthRequest, PassBasicAuth: opts.PassBasicAuth, + PassGroups: opts.PassGroups, + GroupsDelimiter: opts.GroupsDelimiter, + FilterGroups: opts.FilterGroups, PassUserHeaders: opts.PassUserHeaders, BasicAuthPassword: opts.BasicAuthPassword, PassAccessToken: opts.PassAccessToken, @@ -327,7 +333,16 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState, } if s.Email == "" { - s.Email, err = p.provider.GetEmailAddress(s) + userDetails, err := p.provider.GetUserDetails(s) + if err != nil { + return s, err + } + s.Email = userDetails["email"] + if uid, found := userDetails["uid"]; found { + s.ID = uid + } else { + s.ID = "" + } } if s.User == "" { @@ -654,12 +669,27 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } + session.IDToken = req.Form.Get("id_token") + if p.PassGroups && session.IDToken != "" { + groups, err := p.provider.GetGroups(session, p.FilterGroups) + if err != nil { + p.ErrorPage(rw, 500, "Internal Error", "Internal Error") + return + } + + groupNames := []string{} + for groupName := range groups { + groupNames = append(groupNames, groupName) + } + session.Groups = strings.Join(groupNames, p.GroupsDelimiter) + } + if !p.IsValidRedirect(redirect) { redirect = "/" } // set cookie, or deny - if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { + if p.Validator(session.Email) && p.provider.ValidateGroupWithSession(session) { logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) err := p.SaveSession(rw, req, session) if err != nil { @@ -823,6 +853,9 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req } else { req.Header.Del("X-Forwarded-Email") } + if p.PassGroups && session.Groups != "" { + req.Header["X-Forwarded-Groups"] = []string{session.Groups} + } } if p.PassUserHeaders { @@ -849,6 +882,9 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req rw.Header().Del("X-Auth-Request-Access-Token") } } + if p.PassGroups && session.Groups != "" { + rw.Header().Set("X-Auth-Request-Groups", session.Groups) + } } if p.PassAccessToken { diff --git a/options.go b/options.go index 706f6d5..d7b98bc 100644 --- a/options.go +++ b/options.go @@ -14,7 +14,7 @@ import ( "strings" "time" - oidc "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc" "github.com/dgrijalva/jwt-go" "github.com/mbland/hmacauth" "github.com/pusher/oauth2_proxy/pkg/apis/options" @@ -69,6 +69,11 @@ type Options struct { SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens" env:"OAUTH2_PROXY_SKIP_JWT_BEARER_TOKENS"` ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers" env:"OAUTH2_PROXY_EXTRA_JWT_ISSUERS"` PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth" env:"OAUTH2_PROXY_PASS_BASIC_AUTH"` + PassGroups bool `flag:"pass-groups" cfg:"pass_groups"` + FilterGroups string `flag:"filter-groups" cfg:"filter_groups"` + PermitGroups []string `flag:"permit-groups" cfg:"permit_groups"` + GroupsDelimiter string `flag:"groups-delimiter" cfg:"groups_delimiter"` + PermitUsers []string `flag:"permit-users" cfg:"permit_users"` BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password" env:"OAUTH2_PROXY_BASIC_AUTH_PASSWORD"` PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token" env:"OAUTH2_PROXY_PASS_ACCESS_TOKEN"` PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header" env:"OAUTH2_PROXY_PASS_HOST_HEADER"` @@ -159,6 +164,11 @@ func NewOptions() *Options { SkipAuthPreflight: false, PassBasicAuth: true, PassUserHeaders: true, + PassGroups: false, + FilterGroups: "", + GroupsDelimiter: "|", + PermitGroups: []string{}, + PermitUsers: []string{}, PassAccessToken: false, PassHostHeader: true, SetAuthorization: false, @@ -380,6 +390,7 @@ func (o *Options) Validate() error { } func parseProviderInfo(o *Options, msgs []string) []string { + var splittedGroups []string p := &providers.ProviderData{ Scope: o.Scope, ClientID: o.ClientID, @@ -391,11 +402,20 @@ func parseProviderInfo(o *Options, msgs []string) []string { p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) - + if len(o.PermitGroups) > 0 { + splittedGroups = strings.Split(o.PermitGroups[0], o.GroupsDelimiter) + } o.provider = providers.New(o.Provider, p) switch p := o.provider.(type) { case *providers.AzureProvider: p.Configure(o.AzureTenant) + logger.Printf("PermitGroups %+v\n", splittedGroups) + if len(splittedGroups) > 0 { + p.SetGroupRestriction(splittedGroups) + } + if len(o.PermitUsers) > 0 { + p.SetGroupsExemption(o.PermitUsers) + } case *providers.GitHubProvider: p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) case *providers.GoogleProvider: diff --git a/pkg/apis/sessions/session_state.go b/pkg/apis/sessions/session_state.go index 84c0dc9..1f9818f 100644 --- a/pkg/apis/sessions/session_state.go +++ b/pkg/apis/sessions/session_state.go @@ -19,6 +19,8 @@ type SessionState struct { RefreshToken string `json:",omitempty"` Email string `json:",omitempty"` User string `json:",omitempty"` + ID string `json:",omitempty"` + Groups string `json:",omitempty"` } // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value @@ -62,6 +64,9 @@ func (s *SessionState) String() string { if s.RefreshToken != "" { o += " refresh_token:true" } + if s.Groups != "" { + o += fmt.Sprintf(" group:%s", s.Groups) + } return o + "}" } @@ -105,6 +110,18 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error) return "", err } } + if ss.Groups != "" { + ss.Groups, err = c.Encrypt(ss.Groups) + if err != nil { + return "", err + } + } + if ss.ID != "" { + ss.ID, err = c.Encrypt(ss.ID) + if err != nil { + return "", err + } + } } // Embed SessionState and ExpiresOn pointer into SessionStateJSON ssj := &SessionStateJSON{SessionState: &ss} @@ -235,6 +252,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) { return nil, err } } + if ss.Groups != "" { + ss.Groups, err = c.Decrypt(ss.Groups) + if err != nil { + return nil, err + } + } + if ss.ID != "" { + ss.ID, err = c.Decrypt(ss.ID) + if err != nil { + return nil, err + } + } } if ss.User == "" { ss.User = ss.Email diff --git a/providers/azure.go b/providers/azure.go index 653090b..30ad6e9 100644 --- a/providers/azure.go +++ b/providers/azure.go @@ -3,8 +3,10 @@ package providers import ( "errors" "fmt" + "github.com/dgrijalva/jwt-go" "net/http" "net/url" + "strings" "github.com/bitly/go-simplejson" "github.com/pusher/oauth2_proxy/pkg/apis/sessions" @@ -15,7 +17,9 @@ import ( // AzureProvider represents an Azure based Identity Provider type AzureProvider struct { *ProviderData - Tenant string + Tenant string + PermittedGroups map[string]string + ExemptedUsers map[string]string } // NewAzureProvider initiates a new AzureProvider @@ -24,22 +28,26 @@ func NewAzureProvider(p *ProviderData) *AzureProvider { if p.ProfileURL == nil || p.ProfileURL.String() == "" { p.ProfileURL = &url.URL{ - Scheme: "https", - Host: "graph.windows.net", - Path: "/me", - RawQuery: "api-version=1.6", + Scheme: "https", + Host: "graph.microsoft.com", + Path: "/v1.0/me", } } if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { p.ProtectedResource = &url.URL{ Scheme: "https", - Host: "graph.windows.net", + Host: "graph.microsoft.com", } } if p.Scope == "" { p.Scope = "openid" } + if p.ApprovalPrompt == "force" { + p.ApprovalPrompt = "consent" + } + logger.Printf("Approval prompt: '%s'", p.ApprovalPrompt) + return &AzureProvider{ProviderData: p} } @@ -72,22 +80,44 @@ func getAzureHeader(accessToken string) http.Header { } func getEmailFromJSON(json *simplejson.Json) (string, error) { + // First try to return `userPrincipalName` + // if not defined, try to return `mail` + // if that also failed, try to get first record from `otherMails` + // TODO: Return everything in list and then try requests one by one + var email string var err error + email, err = json.Get("userPrincipalName").String() + if err == nil { + return email, err + } + email, err = json.Get("mail").String() if err != nil || email == "" { otherMails, otherMailsErr := json.Get("otherMails").Array() if len(otherMails) > 0 { email = otherMails[0].(string) + err = otherMailsErr } - err = otherMailsErr } return email, err } +func getUserIDFromJSON(json *simplejson.Json) (string, error) { + // Try to get user ID + // if not defined, return empty string + + uid, err := json.Get("id").String() + if err != nil { + return "", err + } + + return uid, err +} + // GetEmailAddress returns the Account email address func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { var email string @@ -128,3 +158,207 @@ func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error return email, err } + +func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) { + userDetails := map[string]string{} + var err error + + if s.AccessToken == "" { + return userDetails, errors.New("missing access token") + } + req, err := http.NewRequest("GET", p.ProfileURL.String(), nil) + if err != nil { + return userDetails, err + } + req.Header = getAzureHeader(s.AccessToken) + + json, err := requests.Request(req) + + if err != nil { + return userDetails, err + } + + logger.Printf(" JSON: %v", json) + for key, value := range json.Interface().(map[string]interface{}) { + logger.Printf("\t %20v : %v", key, value) + } + email, err := getEmailFromJSON(json) + userDetails["email"] = email + + if err != nil { + logger.Printf("[GetEmailAddress] failed making request: %s", err) + return userDetails, err + } + + uid, err := getUserIDFromJSON(json) + userDetails["uid"] = uid + if err != nil { + logger.Printf("[GetEmailAddress] failed to get User ID: %s", err) + } + + if email == "" { + logger.Printf("failed to get email address") + return userDetails, errors.New("Client email not found") + } + logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email) + + return userDetails, nil +} + +// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set) +func (p *AzureProvider) GetGroups(s *sessions.SessionState, f string) (map[string]string, error) { + // Azure App Registration requires setting groupMembershipClaims to include group membership in identity token + // This option is available through ARM template only. + // For details refer to: https://docs.microsoft.com/pl-pl/azure/active-directory/develop/reference-app-manifest + if s.IDToken == "" { + return map[string]string{}, errors.New("missing id token") + } + + type GroupClaims struct { + Groups []string `json:"groups"` + jwt.StandardClaims + } + + claims := &GroupClaims{} + jwt.ParseWithClaims(s.IDToken, claims, func(token *jwt.Token) (interface{}, error) { + return []byte("empty"), nil + }) + + groupsMap := make(map[string]string) + for _, s := range claims.Groups { + groupsMap[s] = s + } + return groupsMap, nil +} + +// ValidateExemptions checks if we can allow user login dispite group membership returned failure +func (p *AzureProvider) ValidateExemptions(s *sessions.SessionState) (bool, string) { + logger.Printf("ValidateExemptions: validating for %v : %v", s.Email, s.ID) + for eAccount, eGroup := range p.ExemptedUsers { + if eAccount == s.Email || eAccount == s.Email+":"+s.ID { + logger.Printf("ValidateExemptions: \t found '%v' user in exemption list. Returning '%v' group membership", eAccount, eGroup) + return true, eGroup + } + } + return false, "" +} + +func (p *AzureProvider) GetLoginURL(redirectURI, state string) string { + var a url.URL + a = *p.LoginURL + params, _ := url.ParseQuery(a.RawQuery) + params.Set("client_id", p.ClientID) + params.Set("response_type", "id_token code") + params.Set("redirect_uri", redirectURI) + params.Set("response_mode", "form_post") + params.Add("scope", p.Scope) + params.Add("state", state) + params.Set("prompt", p.ApprovalPrompt) + params.Set("nonce", "FIXME") + if p.ProtectedResource != nil && p.ProtectedResource.String() != "" { + params.Add("resource", p.ProtectedResource.String()) + } + a.RawQuery = params.Encode() + + return a.String() +} + +func (p *AzureProvider) SetGroupRestriction(groups []string) { + // Get list of groups (optionally with Group IDs) that ONLY allowed for user + // That means even if user has wider group membership, only membership in those groups will be forwarded + + p.PermittedGroups = make(map[string]string) + if len(groups) == 0 { + return + } + logger.Printf("Set group restrictions. Allowed groups are:") + logger.Printf("\t *GROUP NAME* : *GROUP ID*") + for _, pGroup := range groups { + splittedGroup := strings.Split(pGroup, ":") + var groupName string + var groupID string + + if len(splittedGroup) == 1 { + groupName, groupID = splittedGroup[0], "" + p.PermittedGroups[splittedGroup[0]] = "" + } else if len(splittedGroup) > 2 { + logger.Fatalf("failed to parse '%v'. Too many ':' separators", pGroup) + } else { + groupName, groupID = splittedGroup[0], splittedGroup[1] + p.PermittedGroups[splittedGroup[0]] = splittedGroup[1] + } + logger.Printf("\t - %-30s %s", groupName, groupID) + } + logger.Printf("") +} + +func (p *AzureProvider) SetGroupsExemption(exemptions []string) { + // Get list of users (optionally with User IDs) that could still be allowed to login + // when group membership calls fail (e.g. insufficient permissions) + + p.ExemptedUsers = make(map[string]string) + if len(exemptions) == 0 { + return + } + + var userRecord string + var groupName string + logger.Printf("Configure user exemption list:") + logger.Printf("\t *USER NAME*:*USER ID* : *DEFAULT GROUP*") + for _, pRecord := range exemptions { + splittedRecord := strings.Split(pRecord, ":") + + if len(splittedRecord) == 1 { + userRecord, groupName = splittedRecord[0], "" + } else if len(splittedRecord) == 2 { + userRecord, groupName = splittedRecord[0], splittedRecord[1] + } else if len(splittedRecord) > 3 { + logger.Fatalf("failed to parse '%v'. Too many ':' separators", pRecord) + } else { + userRecord = splittedRecord[0] + ":" + splittedRecord[1] + groupName = splittedRecord[2] + } + p.ExemptedUsers[userRecord] = groupName + logger.Printf("\t - %-65s %s", userRecord, groupName) + } + logger.Printf("") +} + +func (p *AzureProvider) ValidateGroupWithSession(s *sessions.SessionState) bool { + if len(p.PermittedGroups) != 0 { + for groupName, groupId := range p.PermittedGroups { + logger.Printf("ValidateGroup: %v", groupName) + if strings.Contains(s.Groups, groupId) { + return true + } + } + logger.Printf("Returning False from ValidateGroup") + return false + } + return true +} + +func (p *AzureProvider) GroupPermitted(gName *string, gID *string) bool { + // Validate provided group + // if "PermitGroups" are defined, for each user group membership, include only those groups that + // marked in list + // + // NOTE: if group in "PermitGroups" does not have group_id defined, this parameter is ignored + if len(p.PermittedGroups) != 0 { + for pGroupName, pGroupID := range p.PermittedGroups { + if pGroupName == *gName { + logger.Printf("ValidateGroup: %v : %v", pGroupName, pGroupID) + if pGroupID == "" || gID == nil { + logger.Printf("ValidateGroup: %v : %v : no Group ID defined for permitted group. Approving", pGroupName, pGroupID) + return true + } else if pGroupID == *gID { + logger.Printf("ValidateGroup: %v : %v : Group ID matches defined in permitted group. Approving", pGroupName, pGroupID) + return true + } + logger.Printf("ValidateGroup: %v : %v != %v Group IDs didn't match", pGroupName, pGroupID, *gID) + } + } + return false + } + return true +} diff --git a/providers/provider_default.go b/providers/provider_default.go index d87b939..17e32c4 100644 --- a/providers/provider_default.go +++ b/providers/provider_default.go @@ -70,7 +70,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat var v url.Values v, err = url.ParseQuery(string(body)) if err != nil { - return + return nil, err } if a := v.Get("access_token"); a != "" { s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} @@ -110,17 +110,40 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) return "", errors.New("not implemented") } +func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) { + userDetails := map[string]string{} + email, err := p.GetEmailAddress(s) + if err != nil { + return nil, err + } + userDetails["email"] = email + return userDetails, nil +} + // GetUserName returns the Account username func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { return "", errors.New("not implemented") } +func (p *ProviderData) GetGroups(s *sessions.SessionState, f string) (map[string]string, error) { + return map[string]string{}, errors.New("not implemented") +} + // ValidateGroup validates that the provided email exists in the configured provider // email group(s). func (p *ProviderData) ValidateGroup(email string) bool { return true } +// ValidateExemptions checks if we can allow user login dispite group membership returned failure +func (p *ProviderData) ValidateExemptions(*sessions.SessionState) (bool, string) { + return false, "" +} + +func (p *ProviderData) ValidateGroupWithSession(s *sessions.SessionState) bool { + return p.ValidateGroup(s.Email) +} + // ValidateSessionState validates the AccessToken func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { return validateToken(p, s.AccessToken, nil) diff --git a/providers/providers.go b/providers/providers.go index 276fab6..3faef61 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -9,9 +9,13 @@ import ( type Provider interface { Data() *ProviderData GetEmailAddress(*sessions.SessionState) (string, error) + GetUserDetails(*sessions.SessionState) (map[string]string, error) GetUserName(*sessions.SessionState) (string, error) + GetGroups(*sessions.SessionState, string) (map[string]string, error) Redeem(string, string) (*sessions.SessionState, error) ValidateGroup(string) bool + ValidateGroupWithSession(*sessions.SessionState) bool + ValidateExemptions(*sessions.SessionState) (bool, string) ValidateSessionState(*sessions.SessionState) bool GetLoginURL(redirectURI, finalRedirect string) string RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)