Initial work on porting https://github.com/bitly/oauth2_proxy/pull/347/. ToDo: port tests

This commit is contained in:
Lukasz Leszczuk 2019-08-16 23:31:19 +02:00
parent 0aba5ec768
commit 7ecb9fd2d6
9 changed files with 383 additions and 28 deletions

1
.gitignore vendored
View File

@ -6,6 +6,7 @@ release
*.exe *.exe
.env .env
.bundle .bundle
.idea/
# Go.gitignore # Go.gitignore
# Compiled Object files, Static and Dynamic libs (Shared Objects) # Compiled Object files, Static and Dynamic libs (Shared Objects)

View File

@ -1,5 +1,6 @@
include .env include .env
BINARY := oauth2_proxy BINARY := oauth2_proxy
REPOSITORY := quay.io/pusher
VERSION := $(shell git describe --always --dirty --tags 2>/dev/null || echo "undefined") VERSION := $(shell git describe --always --dirty --tags 2>/dev/null || echo "undefined")
.NOTPARALLEL: .NOTPARALLEL:
@ -27,31 +28,31 @@ $(BINARY):
.PHONY: docker .PHONY: docker
docker: docker:
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest . docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:latest .
.PHONY: docker-all .PHONY: docker-all
docker-all: docker docker-all: docker
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest-amd64 . docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:latest-amd64 .
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION} . docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:${VERSION} .
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION}-amd64 . docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:${VERSION}-amd64 .
docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:latest-arm64 . docker build -f Dockerfile.arm64 -t ${REPOSITORY}/oauth2_proxy:latest-arm64 .
docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:${VERSION}-arm64 . docker build -f Dockerfile.arm64 -t ${REPOSITORY}/oauth2_proxy:${VERSION}-arm64 .
docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:latest-armv6 . docker build -f Dockerfile.armv6 -t ${REPOSITORY}/oauth2_proxy:latest-armv6 .
docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:${VERSION}-armv6 . docker build -f Dockerfile.armv6 -t ${REPOSITORY}/oauth2_proxy:${VERSION}-armv6 .
.PHONY: docker-push .PHONY: docker-push
docker-push: docker-push:
docker push quay.io/pusher/oauth2_proxy:latest docker push ${REPOSITORY}/oauth2_proxy:latest
.PHONY: docker-push-all .PHONY: docker-push-all
docker-push-all: docker-push docker-push-all: docker-push
docker push quay.io/pusher/oauth2_proxy:latest-amd64 docker push ${REPOSITORY}/oauth2_proxy:latest-amd64
docker push quay.io/pusher/oauth2_proxy:${VERSION} docker push ${REPOSITORY}/oauth2_proxy:${VERSION}
docker push quay.io/pusher/oauth2_proxy:${VERSION}-amd64 docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-amd64
docker push quay.io/pusher/oauth2_proxy:latest-arm64 docker push ${REPOSITORY}/oauth2_proxy:latest-arm64
docker push quay.io/pusher/oauth2_proxy:${VERSION}-arm64 docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-arm64
docker push quay.io/pusher/oauth2_proxy:latest-armv6 docker push ${REPOSITORY}/oauth2_proxy:latest-armv6
docker push quay.io/pusher/oauth2_proxy:${VERSION}-armv6 docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-armv6
.PHONY: test .PHONY: test
test: lint test: lint

View File

@ -25,6 +25,8 @@ func main() {
skipAuthRegex := StringArray{} skipAuthRegex := StringArray{}
jwtIssuers := StringArray{} jwtIssuers := StringArray{}
googleGroups := StringArray{} googleGroups := StringArray{}
permittedGroups := StringArray{}
exemptedUsers := StringArray{}
redisSentinelConnectionURLs := StringArray{} redisSentinelConnectionURLs := StringArray{}
config := flagSet.String("config", "", "path to config file") config := flagSet.String("config", "", "path to config file")
@ -39,6 +41,11 @@ func main() {
flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path") flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path")
flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream")
flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream") flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream")
flagSet.Bool("pass-groups", false, "pass user group information in the X-Forwarded-Groups header to upstream (Azure only)")
flagSet.String("filter-groups", "", "exclude groups that do not contain this value in its 'displayName' (Azure only)")
flagSet.Var(&permittedGroups, "permit-groups", "restrict logins to members of this group (may be given multiple times; Azure).")
flagSet.String("groups-delimiter", "|", "delimiter between group names if more than one found. By default it is '|' symbol")
flagSet.Var(&exemptedUsers, "permit-users", "let these users in if azure call to check group membership fails (may be given multiple times; Azure).")
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")

View File

@ -87,6 +87,9 @@ type OAuthProxy struct {
serveMux http.Handler serveMux http.Handler
SetXAuthRequest bool SetXAuthRequest bool
PassBasicAuth bool PassBasicAuth bool
PassGroups bool
GroupsDelimiter string
FilterGroups string
SkipProviderButton bool SkipProviderButton bool
PassUserHeaders bool PassUserHeaders bool
BasicAuthPassword string BasicAuthPassword string
@ -280,6 +283,9 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
compiledRegex: opts.CompiledRegex, compiledRegex: opts.CompiledRegex,
SetXAuthRequest: opts.SetXAuthRequest, SetXAuthRequest: opts.SetXAuthRequest,
PassBasicAuth: opts.PassBasicAuth, PassBasicAuth: opts.PassBasicAuth,
PassGroups: opts.PassGroups,
GroupsDelimiter: opts.GroupsDelimiter,
FilterGroups: opts.FilterGroups,
PassUserHeaders: opts.PassUserHeaders, PassUserHeaders: opts.PassUserHeaders,
BasicAuthPassword: opts.BasicAuthPassword, BasicAuthPassword: opts.BasicAuthPassword,
PassAccessToken: opts.PassAccessToken, PassAccessToken: opts.PassAccessToken,
@ -327,7 +333,16 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState,
} }
if s.Email == "" { if s.Email == "" {
s.Email, err = p.provider.GetEmailAddress(s) userDetails, err := p.provider.GetUserDetails(s)
if err != nil {
return s, err
}
s.Email = userDetails["email"]
if uid, found := userDetails["uid"]; found {
s.ID = uid
} else {
s.ID = ""
}
} }
if s.User == "" { if s.User == "" {
@ -654,12 +669,27 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
session.IDToken = req.Form.Get("id_token")
if p.PassGroups && session.IDToken != "" {
groups, err := p.provider.GetGroups(session, p.FilterGroups)
if err != nil {
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return
}
groupNames := []string{}
for groupName := range groups {
groupNames = append(groupNames, groupName)
}
session.Groups = strings.Join(groupNames, p.GroupsDelimiter)
}
if !p.IsValidRedirect(redirect) { if !p.IsValidRedirect(redirect) {
redirect = "/" redirect = "/"
} }
// set cookie, or deny // set cookie, or deny
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) { if p.Validator(session.Email) && p.provider.ValidateGroupWithSession(session) {
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session) logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
err := p.SaveSession(rw, req, session) err := p.SaveSession(rw, req, session)
if err != nil { if err != nil {
@ -823,6 +853,9 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
} else { } else {
req.Header.Del("X-Forwarded-Email") req.Header.Del("X-Forwarded-Email")
} }
if p.PassGroups && session.Groups != "" {
req.Header["X-Forwarded-Groups"] = []string{session.Groups}
}
} }
if p.PassUserHeaders { if p.PassUserHeaders {
@ -849,6 +882,9 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
rw.Header().Del("X-Auth-Request-Access-Token") rw.Header().Del("X-Auth-Request-Access-Token")
} }
} }
if p.PassGroups && session.Groups != "" {
rw.Header().Set("X-Auth-Request-Groups", session.Groups)
}
} }
if p.PassAccessToken { if p.PassAccessToken {

View File

@ -14,7 +14,7 @@ import (
"strings" "strings"
"time" "time"
oidc "github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
"github.com/mbland/hmacauth" "github.com/mbland/hmacauth"
"github.com/pusher/oauth2_proxy/pkg/apis/options" "github.com/pusher/oauth2_proxy/pkg/apis/options"
@ -69,6 +69,11 @@ type Options struct {
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens" env:"OAUTH2_PROXY_SKIP_JWT_BEARER_TOKENS"` SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens" env:"OAUTH2_PROXY_SKIP_JWT_BEARER_TOKENS"`
ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers" env:"OAUTH2_PROXY_EXTRA_JWT_ISSUERS"` ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers" env:"OAUTH2_PROXY_EXTRA_JWT_ISSUERS"`
PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth" env:"OAUTH2_PROXY_PASS_BASIC_AUTH"` PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth" env:"OAUTH2_PROXY_PASS_BASIC_AUTH"`
PassGroups bool `flag:"pass-groups" cfg:"pass_groups"`
FilterGroups string `flag:"filter-groups" cfg:"filter_groups"`
PermitGroups []string `flag:"permit-groups" cfg:"permit_groups"`
GroupsDelimiter string `flag:"groups-delimiter" cfg:"groups_delimiter"`
PermitUsers []string `flag:"permit-users" cfg:"permit_users"`
BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password" env:"OAUTH2_PROXY_BASIC_AUTH_PASSWORD"` BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password" env:"OAUTH2_PROXY_BASIC_AUTH_PASSWORD"`
PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token" env:"OAUTH2_PROXY_PASS_ACCESS_TOKEN"` PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token" env:"OAUTH2_PROXY_PASS_ACCESS_TOKEN"`
PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header" env:"OAUTH2_PROXY_PASS_HOST_HEADER"` PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header" env:"OAUTH2_PROXY_PASS_HOST_HEADER"`
@ -159,6 +164,11 @@ func NewOptions() *Options {
SkipAuthPreflight: false, SkipAuthPreflight: false,
PassBasicAuth: true, PassBasicAuth: true,
PassUserHeaders: true, PassUserHeaders: true,
PassGroups: false,
FilterGroups: "",
GroupsDelimiter: "|",
PermitGroups: []string{},
PermitUsers: []string{},
PassAccessToken: false, PassAccessToken: false,
PassHostHeader: true, PassHostHeader: true,
SetAuthorization: false, SetAuthorization: false,
@ -380,6 +390,7 @@ func (o *Options) Validate() error {
} }
func parseProviderInfo(o *Options, msgs []string) []string { func parseProviderInfo(o *Options, msgs []string) []string {
var splittedGroups []string
p := &providers.ProviderData{ p := &providers.ProviderData{
Scope: o.Scope, Scope: o.Scope,
ClientID: o.ClientID, ClientID: o.ClientID,
@ -391,11 +402,20 @@ func parseProviderInfo(o *Options, msgs []string) []string {
p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs) p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs)
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs) p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs) p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
if len(o.PermitGroups) > 0 {
splittedGroups = strings.Split(o.PermitGroups[0], o.GroupsDelimiter)
}
o.provider = providers.New(o.Provider, p) o.provider = providers.New(o.Provider, p)
switch p := o.provider.(type) { switch p := o.provider.(type) {
case *providers.AzureProvider: case *providers.AzureProvider:
p.Configure(o.AzureTenant) p.Configure(o.AzureTenant)
logger.Printf("PermitGroups %+v\n", splittedGroups)
if len(splittedGroups) > 0 {
p.SetGroupRestriction(splittedGroups)
}
if len(o.PermitUsers) > 0 {
p.SetGroupsExemption(o.PermitUsers)
}
case *providers.GitHubProvider: case *providers.GitHubProvider:
p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam) p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam)
case *providers.GoogleProvider: case *providers.GoogleProvider:

View File

@ -19,6 +19,8 @@ type SessionState struct {
RefreshToken string `json:",omitempty"` RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"` Email string `json:",omitempty"`
User string `json:",omitempty"` User string `json:",omitempty"`
ID string `json:",omitempty"`
Groups string `json:",omitempty"`
} }
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
@ -62,6 +64,9 @@ func (s *SessionState) String() string {
if s.RefreshToken != "" { if s.RefreshToken != "" {
o += " refresh_token:true" o += " refresh_token:true"
} }
if s.Groups != "" {
o += fmt.Sprintf(" group:%s", s.Groups)
}
return o + "}" return o + "}"
} }
@ -105,6 +110,18 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
return "", err return "", err
} }
} }
if ss.Groups != "" {
ss.Groups, err = c.Encrypt(ss.Groups)
if err != nil {
return "", err
}
}
if ss.ID != "" {
ss.ID, err = c.Encrypt(ss.ID)
if err != nil {
return "", err
}
}
} }
// Embed SessionState and ExpiresOn pointer into SessionStateJSON // Embed SessionState and ExpiresOn pointer into SessionStateJSON
ssj := &SessionStateJSON{SessionState: &ss} ssj := &SessionStateJSON{SessionState: &ss}
@ -235,6 +252,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
return nil, err return nil, err
} }
} }
if ss.Groups != "" {
ss.Groups, err = c.Decrypt(ss.Groups)
if err != nil {
return nil, err
}
}
if ss.ID != "" {
ss.ID, err = c.Decrypt(ss.ID)
if err != nil {
return nil, err
}
}
} }
if ss.User == "" { if ss.User == "" {
ss.User = ss.Email ss.User = ss.Email

View File

@ -3,8 +3,10 @@ package providers
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/dgrijalva/jwt-go"
"net/http" "net/http"
"net/url" "net/url"
"strings"
"github.com/bitly/go-simplejson" "github.com/bitly/go-simplejson"
"github.com/pusher/oauth2_proxy/pkg/apis/sessions" "github.com/pusher/oauth2_proxy/pkg/apis/sessions"
@ -16,6 +18,8 @@ import (
type AzureProvider struct { type AzureProvider struct {
*ProviderData *ProviderData
Tenant string Tenant string
PermittedGroups map[string]string
ExemptedUsers map[string]string
} }
// NewAzureProvider initiates a new AzureProvider // NewAzureProvider initiates a new AzureProvider
@ -25,21 +29,25 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
if p.ProfileURL == nil || p.ProfileURL.String() == "" { if p.ProfileURL == nil || p.ProfileURL.String() == "" {
p.ProfileURL = &url.URL{ p.ProfileURL = &url.URL{
Scheme: "https", Scheme: "https",
Host: "graph.windows.net", Host: "graph.microsoft.com",
Path: "/me", Path: "/v1.0/me",
RawQuery: "api-version=1.6",
} }
} }
if p.ProtectedResource == nil || p.ProtectedResource.String() == "" { if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
p.ProtectedResource = &url.URL{ p.ProtectedResource = &url.URL{
Scheme: "https", Scheme: "https",
Host: "graph.windows.net", Host: "graph.microsoft.com",
} }
} }
if p.Scope == "" { if p.Scope == "" {
p.Scope = "openid" p.Scope = "openid"
} }
if p.ApprovalPrompt == "force" {
p.ApprovalPrompt = "consent"
}
logger.Printf("Approval prompt: '%s'", p.ApprovalPrompt)
return &AzureProvider{ProviderData: p} return &AzureProvider{ProviderData: p}
} }
@ -72,22 +80,44 @@ func getAzureHeader(accessToken string) http.Header {
} }
func getEmailFromJSON(json *simplejson.Json) (string, error) { func getEmailFromJSON(json *simplejson.Json) (string, error) {
// First try to return `userPrincipalName`
// if not defined, try to return `mail`
// if that also failed, try to get first record from `otherMails`
// TODO: Return everything in list and then try requests one by one
var email string var email string
var err error var err error
email, err = json.Get("userPrincipalName").String()
if err == nil {
return email, err
}
email, err = json.Get("mail").String() email, err = json.Get("mail").String()
if err != nil || email == "" { if err != nil || email == "" {
otherMails, otherMailsErr := json.Get("otherMails").Array() otherMails, otherMailsErr := json.Get("otherMails").Array()
if len(otherMails) > 0 { if len(otherMails) > 0 {
email = otherMails[0].(string) email = otherMails[0].(string)
}
err = otherMailsErr err = otherMailsErr
} }
}
return email, err return email, err
} }
func getUserIDFromJSON(json *simplejson.Json) (string, error) {
// Try to get user ID
// if not defined, return empty string
uid, err := json.Get("id").String()
if err != nil {
return "", err
}
return uid, err
}
// GetEmailAddress returns the Account email address // GetEmailAddress returns the Account email address
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) { func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
var email string var email string
@ -128,3 +158,207 @@ func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error
return email, err return email, err
} }
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
var err error
if s.AccessToken == "" {
return userDetails, errors.New("missing access token")
}
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
if err != nil {
return userDetails, err
}
req.Header = getAzureHeader(s.AccessToken)
json, err := requests.Request(req)
if err != nil {
return userDetails, err
}
logger.Printf(" JSON: %v", json)
for key, value := range json.Interface().(map[string]interface{}) {
logger.Printf("\t %20v : %v", key, value)
}
email, err := getEmailFromJSON(json)
userDetails["email"] = email
if err != nil {
logger.Printf("[GetEmailAddress] failed making request: %s", err)
return userDetails, err
}
uid, err := getUserIDFromJSON(json)
userDetails["uid"] = uid
if err != nil {
logger.Printf("[GetEmailAddress] failed to get User ID: %s", err)
}
if email == "" {
logger.Printf("failed to get email address")
return userDetails, errors.New("Client email not found")
}
logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email)
return userDetails, nil
}
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)
func (p *AzureProvider) GetGroups(s *sessions.SessionState, f string) (map[string]string, error) {
// Azure App Registration requires setting groupMembershipClaims to include group membership in identity token
// This option is available through ARM template only.
// For details refer to: https://docs.microsoft.com/pl-pl/azure/active-directory/develop/reference-app-manifest
if s.IDToken == "" {
return map[string]string{}, errors.New("missing id token")
}
type GroupClaims struct {
Groups []string `json:"groups"`
jwt.StandardClaims
}
claims := &GroupClaims{}
jwt.ParseWithClaims(s.IDToken, claims, func(token *jwt.Token) (interface{}, error) {
return []byte("empty"), nil
})
groupsMap := make(map[string]string)
for _, s := range claims.Groups {
groupsMap[s] = s
}
return groupsMap, nil
}
// ValidateExemptions checks if we can allow user login dispite group membership returned failure
func (p *AzureProvider) ValidateExemptions(s *sessions.SessionState) (bool, string) {
logger.Printf("ValidateExemptions: validating for %v : %v", s.Email, s.ID)
for eAccount, eGroup := range p.ExemptedUsers {
if eAccount == s.Email || eAccount == s.Email+":"+s.ID {
logger.Printf("ValidateExemptions: \t found '%v' user in exemption list. Returning '%v' group membership", eAccount, eGroup)
return true, eGroup
}
}
return false, ""
}
func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
var a url.URL
a = *p.LoginURL
params, _ := url.ParseQuery(a.RawQuery)
params.Set("client_id", p.ClientID)
params.Set("response_type", "id_token code")
params.Set("redirect_uri", redirectURI)
params.Set("response_mode", "form_post")
params.Add("scope", p.Scope)
params.Add("state", state)
params.Set("prompt", p.ApprovalPrompt)
params.Set("nonce", "FIXME")
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
params.Add("resource", p.ProtectedResource.String())
}
a.RawQuery = params.Encode()
return a.String()
}
func (p *AzureProvider) SetGroupRestriction(groups []string) {
// Get list of groups (optionally with Group IDs) that ONLY allowed for user
// That means even if user has wider group membership, only membership in those groups will be forwarded
p.PermittedGroups = make(map[string]string)
if len(groups) == 0 {
return
}
logger.Printf("Set group restrictions. Allowed groups are:")
logger.Printf("\t *GROUP NAME* : *GROUP ID*")
for _, pGroup := range groups {
splittedGroup := strings.Split(pGroup, ":")
var groupName string
var groupID string
if len(splittedGroup) == 1 {
groupName, groupID = splittedGroup[0], ""
p.PermittedGroups[splittedGroup[0]] = ""
} else if len(splittedGroup) > 2 {
logger.Fatalf("failed to parse '%v'. Too many ':' separators", pGroup)
} else {
groupName, groupID = splittedGroup[0], splittedGroup[1]
p.PermittedGroups[splittedGroup[0]] = splittedGroup[1]
}
logger.Printf("\t - %-30s %s", groupName, groupID)
}
logger.Printf("")
}
func (p *AzureProvider) SetGroupsExemption(exemptions []string) {
// Get list of users (optionally with User IDs) that could still be allowed to login
// when group membership calls fail (e.g. insufficient permissions)
p.ExemptedUsers = make(map[string]string)
if len(exemptions) == 0 {
return
}
var userRecord string
var groupName string
logger.Printf("Configure user exemption list:")
logger.Printf("\t *USER NAME*:*USER ID* : *DEFAULT GROUP*")
for _, pRecord := range exemptions {
splittedRecord := strings.Split(pRecord, ":")
if len(splittedRecord) == 1 {
userRecord, groupName = splittedRecord[0], ""
} else if len(splittedRecord) == 2 {
userRecord, groupName = splittedRecord[0], splittedRecord[1]
} else if len(splittedRecord) > 3 {
logger.Fatalf("failed to parse '%v'. Too many ':' separators", pRecord)
} else {
userRecord = splittedRecord[0] + ":" + splittedRecord[1]
groupName = splittedRecord[2]
}
p.ExemptedUsers[userRecord] = groupName
logger.Printf("\t - %-65s %s", userRecord, groupName)
}
logger.Printf("")
}
func (p *AzureProvider) ValidateGroupWithSession(s *sessions.SessionState) bool {
if len(p.PermittedGroups) != 0 {
for groupName, groupId := range p.PermittedGroups {
logger.Printf("ValidateGroup: %v", groupName)
if strings.Contains(s.Groups, groupId) {
return true
}
}
logger.Printf("Returning False from ValidateGroup")
return false
}
return true
}
func (p *AzureProvider) GroupPermitted(gName *string, gID *string) bool {
// Validate provided group
// if "PermitGroups" are defined, for each user group membership, include only those groups that
// marked in list
//
// NOTE: if group in "PermitGroups" does not have group_id defined, this parameter is ignored
if len(p.PermittedGroups) != 0 {
for pGroupName, pGroupID := range p.PermittedGroups {
if pGroupName == *gName {
logger.Printf("ValidateGroup: %v : %v", pGroupName, pGroupID)
if pGroupID == "" || gID == nil {
logger.Printf("ValidateGroup: %v : %v : no Group ID defined for permitted group. Approving", pGroupName, pGroupID)
return true
} else if pGroupID == *gID {
logger.Printf("ValidateGroup: %v : %v : Group ID matches defined in permitted group. Approving", pGroupName, pGroupID)
return true
}
logger.Printf("ValidateGroup: %v : %v != %v Group IDs didn't match", pGroupName, pGroupID, *gID)
}
}
return false
}
return true
}

View File

@ -70,7 +70,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
var v url.Values var v url.Values
v, err = url.ParseQuery(string(body)) v, err = url.ParseQuery(string(body))
if err != nil { if err != nil {
return return nil, err
} }
if a := v.Get("access_token"); a != "" { if a := v.Get("access_token"); a != "" {
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()} s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
@ -110,17 +110,40 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error)
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
userDetails := map[string]string{}
email, err := p.GetEmailAddress(s)
if err != nil {
return nil, err
}
userDetails["email"] = email
return userDetails, nil
}
// GetUserName returns the Account username // GetUserName returns the Account username
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) { func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented") return "", errors.New("not implemented")
} }
func (p *ProviderData) GetGroups(s *sessions.SessionState, f string) (map[string]string, error) {
return map[string]string{}, errors.New("not implemented")
}
// ValidateGroup validates that the provided email exists in the configured provider // ValidateGroup validates that the provided email exists in the configured provider
// email group(s). // email group(s).
func (p *ProviderData) ValidateGroup(email string) bool { func (p *ProviderData) ValidateGroup(email string) bool {
return true return true
} }
// ValidateExemptions checks if we can allow user login dispite group membership returned failure
func (p *ProviderData) ValidateExemptions(*sessions.SessionState) (bool, string) {
return false, ""
}
func (p *ProviderData) ValidateGroupWithSession(s *sessions.SessionState) bool {
return p.ValidateGroup(s.Email)
}
// ValidateSessionState validates the AccessToken // ValidateSessionState validates the AccessToken
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool { func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
return validateToken(p, s.AccessToken, nil) return validateToken(p, s.AccessToken, nil)

View File

@ -9,9 +9,13 @@ import (
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*sessions.SessionState) (string, error) GetEmailAddress(*sessions.SessionState) (string, error)
GetUserDetails(*sessions.SessionState) (map[string]string, error)
GetUserName(*sessions.SessionState) (string, error) GetUserName(*sessions.SessionState) (string, error)
GetGroups(*sessions.SessionState, string) (map[string]string, error)
Redeem(string, string) (*sessions.SessionState, error) Redeem(string, string) (*sessions.SessionState, error)
ValidateGroup(string) bool ValidateGroup(string) bool
ValidateGroupWithSession(*sessions.SessionState) bool
ValidateExemptions(*sessions.SessionState) (bool, string)
ValidateSessionState(*sessions.SessionState) bool ValidateSessionState(*sessions.SessionState) bool
GetLoginURL(redirectURI, finalRedirect string) string GetLoginURL(redirectURI, finalRedirect string) string
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error) RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)