Initial work on porting https://github.com/bitly/oauth2_proxy/pull/347/. ToDo: port tests
This commit is contained in:
parent
0aba5ec768
commit
7ecb9fd2d6
1
.gitignore
vendored
1
.gitignore
vendored
@ -6,6 +6,7 @@ release
|
||||
*.exe
|
||||
.env
|
||||
.bundle
|
||||
.idea/
|
||||
|
||||
# Go.gitignore
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
|
33
Makefile
33
Makefile
@ -1,5 +1,6 @@
|
||||
include .env
|
||||
BINARY := oauth2_proxy
|
||||
REPOSITORY := quay.io/pusher
|
||||
VERSION := $(shell git describe --always --dirty --tags 2>/dev/null || echo "undefined")
|
||||
.NOTPARALLEL:
|
||||
|
||||
@ -27,31 +28,31 @@ $(BINARY):
|
||||
|
||||
.PHONY: docker
|
||||
docker:
|
||||
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest .
|
||||
docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:latest .
|
||||
|
||||
.PHONY: docker-all
|
||||
docker-all: docker
|
||||
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:latest-amd64 .
|
||||
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION} .
|
||||
docker build -f Dockerfile -t quay.io/pusher/oauth2_proxy:${VERSION}-amd64 .
|
||||
docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:latest-arm64 .
|
||||
docker build -f Dockerfile.arm64 -t quay.io/pusher/oauth2_proxy:${VERSION}-arm64 .
|
||||
docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:latest-armv6 .
|
||||
docker build -f Dockerfile.armv6 -t quay.io/pusher/oauth2_proxy:${VERSION}-armv6 .
|
||||
docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:latest-amd64 .
|
||||
docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:${VERSION} .
|
||||
docker build -f Dockerfile -t ${REPOSITORY}/oauth2_proxy:${VERSION}-amd64 .
|
||||
docker build -f Dockerfile.arm64 -t ${REPOSITORY}/oauth2_proxy:latest-arm64 .
|
||||
docker build -f Dockerfile.arm64 -t ${REPOSITORY}/oauth2_proxy:${VERSION}-arm64 .
|
||||
docker build -f Dockerfile.armv6 -t ${REPOSITORY}/oauth2_proxy:latest-armv6 .
|
||||
docker build -f Dockerfile.armv6 -t ${REPOSITORY}/oauth2_proxy:${VERSION}-armv6 .
|
||||
|
||||
.PHONY: docker-push
|
||||
docker-push:
|
||||
docker push quay.io/pusher/oauth2_proxy:latest
|
||||
docker push ${REPOSITORY}/oauth2_proxy:latest
|
||||
|
||||
.PHONY: docker-push-all
|
||||
docker-push-all: docker-push
|
||||
docker push quay.io/pusher/oauth2_proxy:latest-amd64
|
||||
docker push quay.io/pusher/oauth2_proxy:${VERSION}
|
||||
docker push quay.io/pusher/oauth2_proxy:${VERSION}-amd64
|
||||
docker push quay.io/pusher/oauth2_proxy:latest-arm64
|
||||
docker push quay.io/pusher/oauth2_proxy:${VERSION}-arm64
|
||||
docker push quay.io/pusher/oauth2_proxy:latest-armv6
|
||||
docker push quay.io/pusher/oauth2_proxy:${VERSION}-armv6
|
||||
docker push ${REPOSITORY}/oauth2_proxy:latest-amd64
|
||||
docker push ${REPOSITORY}/oauth2_proxy:${VERSION}
|
||||
docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-amd64
|
||||
docker push ${REPOSITORY}/oauth2_proxy:latest-arm64
|
||||
docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-arm64
|
||||
docker push ${REPOSITORY}/oauth2_proxy:latest-armv6
|
||||
docker push ${REPOSITORY}/oauth2_proxy:${VERSION}-armv6
|
||||
|
||||
.PHONY: test
|
||||
test: lint
|
||||
|
7
main.go
7
main.go
@ -25,6 +25,8 @@ func main() {
|
||||
skipAuthRegex := StringArray{}
|
||||
jwtIssuers := StringArray{}
|
||||
googleGroups := StringArray{}
|
||||
permittedGroups := StringArray{}
|
||||
exemptedUsers := StringArray{}
|
||||
redisSentinelConnectionURLs := StringArray{}
|
||||
|
||||
config := flagSet.String("config", "", "path to config file")
|
||||
@ -39,6 +41,11 @@ func main() {
|
||||
flagSet.Var(&upstreams, "upstream", "the http url(s) of the upstream endpoint or file:// paths for static files. Routing is based on the path")
|
||||
flagSet.Bool("pass-basic-auth", true, "pass HTTP Basic Auth, X-Forwarded-User and X-Forwarded-Email information to upstream")
|
||||
flagSet.Bool("pass-user-headers", true, "pass X-Forwarded-User and X-Forwarded-Email information to upstream")
|
||||
flagSet.Bool("pass-groups", false, "pass user group information in the X-Forwarded-Groups header to upstream (Azure only)")
|
||||
flagSet.String("filter-groups", "", "exclude groups that do not contain this value in its 'displayName' (Azure only)")
|
||||
flagSet.Var(&permittedGroups, "permit-groups", "restrict logins to members of this group (may be given multiple times; Azure).")
|
||||
flagSet.String("groups-delimiter", "|", "delimiter between group names if more than one found. By default it is '|' symbol")
|
||||
flagSet.Var(&exemptedUsers, "permit-users", "let these users in if azure call to check group membership fails (may be given multiple times; Azure).")
|
||||
flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header")
|
||||
flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header")
|
||||
flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream")
|
||||
|
@ -87,6 +87,9 @@ type OAuthProxy struct {
|
||||
serveMux http.Handler
|
||||
SetXAuthRequest bool
|
||||
PassBasicAuth bool
|
||||
PassGroups bool
|
||||
GroupsDelimiter string
|
||||
FilterGroups string
|
||||
SkipProviderButton bool
|
||||
PassUserHeaders bool
|
||||
BasicAuthPassword string
|
||||
@ -280,6 +283,9 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy {
|
||||
compiledRegex: opts.CompiledRegex,
|
||||
SetXAuthRequest: opts.SetXAuthRequest,
|
||||
PassBasicAuth: opts.PassBasicAuth,
|
||||
PassGroups: opts.PassGroups,
|
||||
GroupsDelimiter: opts.GroupsDelimiter,
|
||||
FilterGroups: opts.FilterGroups,
|
||||
PassUserHeaders: opts.PassUserHeaders,
|
||||
BasicAuthPassword: opts.BasicAuthPassword,
|
||||
PassAccessToken: opts.PassAccessToken,
|
||||
@ -327,7 +333,16 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *sessionsapi.SessionState,
|
||||
}
|
||||
|
||||
if s.Email == "" {
|
||||
s.Email, err = p.provider.GetEmailAddress(s)
|
||||
userDetails, err := p.provider.GetUserDetails(s)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
s.Email = userDetails["email"]
|
||||
if uid, found := userDetails["uid"]; found {
|
||||
s.ID = uid
|
||||
} else {
|
||||
s.ID = ""
|
||||
}
|
||||
}
|
||||
|
||||
if s.User == "" {
|
||||
@ -654,12 +669,27 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
session.IDToken = req.Form.Get("id_token")
|
||||
if p.PassGroups && session.IDToken != "" {
|
||||
groups, err := p.provider.GetGroups(session, p.FilterGroups)
|
||||
if err != nil {
|
||||
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
|
||||
return
|
||||
}
|
||||
|
||||
groupNames := []string{}
|
||||
for groupName := range groups {
|
||||
groupNames = append(groupNames, groupName)
|
||||
}
|
||||
session.Groups = strings.Join(groupNames, p.GroupsDelimiter)
|
||||
}
|
||||
|
||||
if !p.IsValidRedirect(redirect) {
|
||||
redirect = "/"
|
||||
}
|
||||
|
||||
// set cookie, or deny
|
||||
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
|
||||
if p.Validator(session.Email) && p.provider.ValidateGroupWithSession(session) {
|
||||
logger.PrintAuthf(session.Email, req, logger.AuthSuccess, "Authenticated via OAuth2: %s", session)
|
||||
err := p.SaveSession(rw, req, session)
|
||||
if err != nil {
|
||||
@ -823,6 +853,9 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
|
||||
} else {
|
||||
req.Header.Del("X-Forwarded-Email")
|
||||
}
|
||||
if p.PassGroups && session.Groups != "" {
|
||||
req.Header["X-Forwarded-Groups"] = []string{session.Groups}
|
||||
}
|
||||
}
|
||||
|
||||
if p.PassUserHeaders {
|
||||
@ -849,6 +882,9 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
|
||||
rw.Header().Del("X-Auth-Request-Access-Token")
|
||||
}
|
||||
}
|
||||
if p.PassGroups && session.Groups != "" {
|
||||
rw.Header().Set("X-Auth-Request-Groups", session.Groups)
|
||||
}
|
||||
}
|
||||
|
||||
if p.PassAccessToken {
|
||||
|
24
options.go
24
options.go
@ -14,7 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/mbland/hmacauth"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/options"
|
||||
@ -69,6 +69,11 @@ type Options struct {
|
||||
SkipJwtBearerTokens bool `flag:"skip-jwt-bearer-tokens" cfg:"skip_jwt_bearer_tokens" env:"OAUTH2_PROXY_SKIP_JWT_BEARER_TOKENS"`
|
||||
ExtraJwtIssuers []string `flag:"extra-jwt-issuers" cfg:"extra_jwt_issuers" env:"OAUTH2_PROXY_EXTRA_JWT_ISSUERS"`
|
||||
PassBasicAuth bool `flag:"pass-basic-auth" cfg:"pass_basic_auth" env:"OAUTH2_PROXY_PASS_BASIC_AUTH"`
|
||||
PassGroups bool `flag:"pass-groups" cfg:"pass_groups"`
|
||||
FilterGroups string `flag:"filter-groups" cfg:"filter_groups"`
|
||||
PermitGroups []string `flag:"permit-groups" cfg:"permit_groups"`
|
||||
GroupsDelimiter string `flag:"groups-delimiter" cfg:"groups_delimiter"`
|
||||
PermitUsers []string `flag:"permit-users" cfg:"permit_users"`
|
||||
BasicAuthPassword string `flag:"basic-auth-password" cfg:"basic_auth_password" env:"OAUTH2_PROXY_BASIC_AUTH_PASSWORD"`
|
||||
PassAccessToken bool `flag:"pass-access-token" cfg:"pass_access_token" env:"OAUTH2_PROXY_PASS_ACCESS_TOKEN"`
|
||||
PassHostHeader bool `flag:"pass-host-header" cfg:"pass_host_header" env:"OAUTH2_PROXY_PASS_HOST_HEADER"`
|
||||
@ -159,6 +164,11 @@ func NewOptions() *Options {
|
||||
SkipAuthPreflight: false,
|
||||
PassBasicAuth: true,
|
||||
PassUserHeaders: true,
|
||||
PassGroups: false,
|
||||
FilterGroups: "",
|
||||
GroupsDelimiter: "|",
|
||||
PermitGroups: []string{},
|
||||
PermitUsers: []string{},
|
||||
PassAccessToken: false,
|
||||
PassHostHeader: true,
|
||||
SetAuthorization: false,
|
||||
@ -380,6 +390,7 @@ func (o *Options) Validate() error {
|
||||
}
|
||||
|
||||
func parseProviderInfo(o *Options, msgs []string) []string {
|
||||
var splittedGroups []string
|
||||
p := &providers.ProviderData{
|
||||
Scope: o.Scope,
|
||||
ClientID: o.ClientID,
|
||||
@ -391,11 +402,20 @@ func parseProviderInfo(o *Options, msgs []string) []string {
|
||||
p.ProfileURL, msgs = parseURL(o.ProfileURL, "profile", msgs)
|
||||
p.ValidateURL, msgs = parseURL(o.ValidateURL, "validate", msgs)
|
||||
p.ProtectedResource, msgs = parseURL(o.ProtectedResource, "resource", msgs)
|
||||
|
||||
if len(o.PermitGroups) > 0 {
|
||||
splittedGroups = strings.Split(o.PermitGroups[0], o.GroupsDelimiter)
|
||||
}
|
||||
o.provider = providers.New(o.Provider, p)
|
||||
switch p := o.provider.(type) {
|
||||
case *providers.AzureProvider:
|
||||
p.Configure(o.AzureTenant)
|
||||
logger.Printf("PermitGroups %+v\n", splittedGroups)
|
||||
if len(splittedGroups) > 0 {
|
||||
p.SetGroupRestriction(splittedGroups)
|
||||
}
|
||||
if len(o.PermitUsers) > 0 {
|
||||
p.SetGroupsExemption(o.PermitUsers)
|
||||
}
|
||||
case *providers.GitHubProvider:
|
||||
p.SetOrgTeam(o.GitHubOrg, o.GitHubTeam)
|
||||
case *providers.GoogleProvider:
|
||||
|
@ -19,6 +19,8 @@ type SessionState struct {
|
||||
RefreshToken string `json:",omitempty"`
|
||||
Email string `json:",omitempty"`
|
||||
User string `json:",omitempty"`
|
||||
ID string `json:",omitempty"`
|
||||
Groups string `json:",omitempty"`
|
||||
}
|
||||
|
||||
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
|
||||
@ -62,6 +64,9 @@ func (s *SessionState) String() string {
|
||||
if s.RefreshToken != "" {
|
||||
o += " refresh_token:true"
|
||||
}
|
||||
if s.Groups != "" {
|
||||
o += fmt.Sprintf(" group:%s", s.Groups)
|
||||
}
|
||||
return o + "}"
|
||||
}
|
||||
|
||||
@ -105,6 +110,18 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if ss.Groups != "" {
|
||||
ss.Groups, err = c.Encrypt(ss.Groups)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
if ss.ID != "" {
|
||||
ss.ID, err = c.Encrypt(ss.ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
}
|
||||
// Embed SessionState and ExpiresOn pointer into SessionStateJSON
|
||||
ssj := &SessionStateJSON{SessionState: &ss}
|
||||
@ -235,6 +252,18 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if ss.Groups != "" {
|
||||
ss.Groups, err = c.Decrypt(ss.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if ss.ID != "" {
|
||||
ss.ID, err = c.Decrypt(ss.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if ss.User == "" {
|
||||
ss.User = ss.Email
|
||||
|
@ -3,8 +3,10 @@ package providers
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/bitly/go-simplejson"
|
||||
"github.com/pusher/oauth2_proxy/pkg/apis/sessions"
|
||||
@ -15,7 +17,9 @@ import (
|
||||
// AzureProvider represents an Azure based Identity Provider
|
||||
type AzureProvider struct {
|
||||
*ProviderData
|
||||
Tenant string
|
||||
Tenant string
|
||||
PermittedGroups map[string]string
|
||||
ExemptedUsers map[string]string
|
||||
}
|
||||
|
||||
// NewAzureProvider initiates a new AzureProvider
|
||||
@ -24,22 +28,26 @@ func NewAzureProvider(p *ProviderData) *AzureProvider {
|
||||
|
||||
if p.ProfileURL == nil || p.ProfileURL.String() == "" {
|
||||
p.ProfileURL = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "graph.windows.net",
|
||||
Path: "/me",
|
||||
RawQuery: "api-version=1.6",
|
||||
Scheme: "https",
|
||||
Host: "graph.microsoft.com",
|
||||
Path: "/v1.0/me",
|
||||
}
|
||||
}
|
||||
if p.ProtectedResource == nil || p.ProtectedResource.String() == "" {
|
||||
p.ProtectedResource = &url.URL{
|
||||
Scheme: "https",
|
||||
Host: "graph.windows.net",
|
||||
Host: "graph.microsoft.com",
|
||||
}
|
||||
}
|
||||
if p.Scope == "" {
|
||||
p.Scope = "openid"
|
||||
}
|
||||
|
||||
if p.ApprovalPrompt == "force" {
|
||||
p.ApprovalPrompt = "consent"
|
||||
}
|
||||
logger.Printf("Approval prompt: '%s'", p.ApprovalPrompt)
|
||||
|
||||
return &AzureProvider{ProviderData: p}
|
||||
}
|
||||
|
||||
@ -72,22 +80,44 @@ func getAzureHeader(accessToken string) http.Header {
|
||||
}
|
||||
|
||||
func getEmailFromJSON(json *simplejson.Json) (string, error) {
|
||||
// First try to return `userPrincipalName`
|
||||
// if not defined, try to return `mail`
|
||||
// if that also failed, try to get first record from `otherMails`
|
||||
// TODO: Return everything in list and then try requests one by one
|
||||
|
||||
var email string
|
||||
var err error
|
||||
|
||||
email, err = json.Get("userPrincipalName").String()
|
||||
if err == nil {
|
||||
return email, err
|
||||
}
|
||||
|
||||
email, err = json.Get("mail").String()
|
||||
|
||||
if err != nil || email == "" {
|
||||
otherMails, otherMailsErr := json.Get("otherMails").Array()
|
||||
if len(otherMails) > 0 {
|
||||
email = otherMails[0].(string)
|
||||
err = otherMailsErr
|
||||
}
|
||||
err = otherMailsErr
|
||||
}
|
||||
|
||||
return email, err
|
||||
}
|
||||
|
||||
func getUserIDFromJSON(json *simplejson.Json) (string, error) {
|
||||
// Try to get user ID
|
||||
// if not defined, return empty string
|
||||
|
||||
uid, err := json.Get("id").String()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return uid, err
|
||||
}
|
||||
|
||||
// GetEmailAddress returns the Account email address
|
||||
func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error) {
|
||||
var email string
|
||||
@ -128,3 +158,207 @@ func (p *AzureProvider) GetEmailAddress(s *sessions.SessionState) (string, error
|
||||
|
||||
return email, err
|
||||
}
|
||||
|
||||
func (p *AzureProvider) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
||||
userDetails := map[string]string{}
|
||||
var err error
|
||||
|
||||
if s.AccessToken == "" {
|
||||
return userDetails, errors.New("missing access token")
|
||||
}
|
||||
req, err := http.NewRequest("GET", p.ProfileURL.String(), nil)
|
||||
if err != nil {
|
||||
return userDetails, err
|
||||
}
|
||||
req.Header = getAzureHeader(s.AccessToken)
|
||||
|
||||
json, err := requests.Request(req)
|
||||
|
||||
if err != nil {
|
||||
return userDetails, err
|
||||
}
|
||||
|
||||
logger.Printf(" JSON: %v", json)
|
||||
for key, value := range json.Interface().(map[string]interface{}) {
|
||||
logger.Printf("\t %20v : %v", key, value)
|
||||
}
|
||||
email, err := getEmailFromJSON(json)
|
||||
userDetails["email"] = email
|
||||
|
||||
if err != nil {
|
||||
logger.Printf("[GetEmailAddress] failed making request: %s", err)
|
||||
return userDetails, err
|
||||
}
|
||||
|
||||
uid, err := getUserIDFromJSON(json)
|
||||
userDetails["uid"] = uid
|
||||
if err != nil {
|
||||
logger.Printf("[GetEmailAddress] failed to get User ID: %s", err)
|
||||
}
|
||||
|
||||
if email == "" {
|
||||
logger.Printf("failed to get email address")
|
||||
return userDetails, errors.New("Client email not found")
|
||||
}
|
||||
logger.Printf("[GetEmailAddress] Chosen email address: '%s'", email)
|
||||
|
||||
return userDetails, nil
|
||||
}
|
||||
|
||||
// Get list of groups user belong to. Filter the desired names of groups (in case of huge group set)
|
||||
func (p *AzureProvider) GetGroups(s *sessions.SessionState, f string) (map[string]string, error) {
|
||||
// Azure App Registration requires setting groupMembershipClaims to include group membership in identity token
|
||||
// This option is available through ARM template only.
|
||||
// For details refer to: https://docs.microsoft.com/pl-pl/azure/active-directory/develop/reference-app-manifest
|
||||
if s.IDToken == "" {
|
||||
return map[string]string{}, errors.New("missing id token")
|
||||
}
|
||||
|
||||
type GroupClaims struct {
|
||||
Groups []string `json:"groups"`
|
||||
jwt.StandardClaims
|
||||
}
|
||||
|
||||
claims := &GroupClaims{}
|
||||
jwt.ParseWithClaims(s.IDToken, claims, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte("empty"), nil
|
||||
})
|
||||
|
||||
groupsMap := make(map[string]string)
|
||||
for _, s := range claims.Groups {
|
||||
groupsMap[s] = s
|
||||
}
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
// ValidateExemptions checks if we can allow user login dispite group membership returned failure
|
||||
func (p *AzureProvider) ValidateExemptions(s *sessions.SessionState) (bool, string) {
|
||||
logger.Printf("ValidateExemptions: validating for %v : %v", s.Email, s.ID)
|
||||
for eAccount, eGroup := range p.ExemptedUsers {
|
||||
if eAccount == s.Email || eAccount == s.Email+":"+s.ID {
|
||||
logger.Printf("ValidateExemptions: \t found '%v' user in exemption list. Returning '%v' group membership", eAccount, eGroup)
|
||||
return true, eGroup
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (p *AzureProvider) GetLoginURL(redirectURI, state string) string {
|
||||
var a url.URL
|
||||
a = *p.LoginURL
|
||||
params, _ := url.ParseQuery(a.RawQuery)
|
||||
params.Set("client_id", p.ClientID)
|
||||
params.Set("response_type", "id_token code")
|
||||
params.Set("redirect_uri", redirectURI)
|
||||
params.Set("response_mode", "form_post")
|
||||
params.Add("scope", p.Scope)
|
||||
params.Add("state", state)
|
||||
params.Set("prompt", p.ApprovalPrompt)
|
||||
params.Set("nonce", "FIXME")
|
||||
if p.ProtectedResource != nil && p.ProtectedResource.String() != "" {
|
||||
params.Add("resource", p.ProtectedResource.String())
|
||||
}
|
||||
a.RawQuery = params.Encode()
|
||||
|
||||
return a.String()
|
||||
}
|
||||
|
||||
func (p *AzureProvider) SetGroupRestriction(groups []string) {
|
||||
// Get list of groups (optionally with Group IDs) that ONLY allowed for user
|
||||
// That means even if user has wider group membership, only membership in those groups will be forwarded
|
||||
|
||||
p.PermittedGroups = make(map[string]string)
|
||||
if len(groups) == 0 {
|
||||
return
|
||||
}
|
||||
logger.Printf("Set group restrictions. Allowed groups are:")
|
||||
logger.Printf("\t *GROUP NAME* : *GROUP ID*")
|
||||
for _, pGroup := range groups {
|
||||
splittedGroup := strings.Split(pGroup, ":")
|
||||
var groupName string
|
||||
var groupID string
|
||||
|
||||
if len(splittedGroup) == 1 {
|
||||
groupName, groupID = splittedGroup[0], ""
|
||||
p.PermittedGroups[splittedGroup[0]] = ""
|
||||
} else if len(splittedGroup) > 2 {
|
||||
logger.Fatalf("failed to parse '%v'. Too many ':' separators", pGroup)
|
||||
} else {
|
||||
groupName, groupID = splittedGroup[0], splittedGroup[1]
|
||||
p.PermittedGroups[splittedGroup[0]] = splittedGroup[1]
|
||||
}
|
||||
logger.Printf("\t - %-30s %s", groupName, groupID)
|
||||
}
|
||||
logger.Printf("")
|
||||
}
|
||||
|
||||
func (p *AzureProvider) SetGroupsExemption(exemptions []string) {
|
||||
// Get list of users (optionally with User IDs) that could still be allowed to login
|
||||
// when group membership calls fail (e.g. insufficient permissions)
|
||||
|
||||
p.ExemptedUsers = make(map[string]string)
|
||||
if len(exemptions) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var userRecord string
|
||||
var groupName string
|
||||
logger.Printf("Configure user exemption list:")
|
||||
logger.Printf("\t *USER NAME*:*USER ID* : *DEFAULT GROUP*")
|
||||
for _, pRecord := range exemptions {
|
||||
splittedRecord := strings.Split(pRecord, ":")
|
||||
|
||||
if len(splittedRecord) == 1 {
|
||||
userRecord, groupName = splittedRecord[0], ""
|
||||
} else if len(splittedRecord) == 2 {
|
||||
userRecord, groupName = splittedRecord[0], splittedRecord[1]
|
||||
} else if len(splittedRecord) > 3 {
|
||||
logger.Fatalf("failed to parse '%v'. Too many ':' separators", pRecord)
|
||||
} else {
|
||||
userRecord = splittedRecord[0] + ":" + splittedRecord[1]
|
||||
groupName = splittedRecord[2]
|
||||
}
|
||||
p.ExemptedUsers[userRecord] = groupName
|
||||
logger.Printf("\t - %-65s %s", userRecord, groupName)
|
||||
}
|
||||
logger.Printf("")
|
||||
}
|
||||
|
||||
func (p *AzureProvider) ValidateGroupWithSession(s *sessions.SessionState) bool {
|
||||
if len(p.PermittedGroups) != 0 {
|
||||
for groupName, groupId := range p.PermittedGroups {
|
||||
logger.Printf("ValidateGroup: %v", groupName)
|
||||
if strings.Contains(s.Groups, groupId) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
logger.Printf("Returning False from ValidateGroup")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AzureProvider) GroupPermitted(gName *string, gID *string) bool {
|
||||
// Validate provided group
|
||||
// if "PermitGroups" are defined, for each user group membership, include only those groups that
|
||||
// marked in list
|
||||
//
|
||||
// NOTE: if group in "PermitGroups" does not have group_id defined, this parameter is ignored
|
||||
if len(p.PermittedGroups) != 0 {
|
||||
for pGroupName, pGroupID := range p.PermittedGroups {
|
||||
if pGroupName == *gName {
|
||||
logger.Printf("ValidateGroup: %v : %v", pGroupName, pGroupID)
|
||||
if pGroupID == "" || gID == nil {
|
||||
logger.Printf("ValidateGroup: %v : %v : no Group ID defined for permitted group. Approving", pGroupName, pGroupID)
|
||||
return true
|
||||
} else if pGroupID == *gID {
|
||||
logger.Printf("ValidateGroup: %v : %v : Group ID matches defined in permitted group. Approving", pGroupName, pGroupID)
|
||||
return true
|
||||
}
|
||||
logger.Printf("ValidateGroup: %v : %v != %v Group IDs didn't match", pGroupName, pGroupID, *gID)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func (p *ProviderData) Redeem(redirectURL, code string) (s *sessions.SessionStat
|
||||
var v url.Values
|
||||
v, err = url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
if a := v.Get("access_token"); a != "" {
|
||||
s = &sessions.SessionState{AccessToken: a, CreatedAt: time.Now()}
|
||||
@ -110,17 +110,40 @@ func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error)
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (map[string]string, error) {
|
||||
userDetails := map[string]string{}
|
||||
email, err := p.GetEmailAddress(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userDetails["email"] = email
|
||||
return userDetails, nil
|
||||
}
|
||||
|
||||
// GetUserName returns the Account username
|
||||
func (p *ProviderData) GetUserName(s *sessions.SessionState) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (p *ProviderData) GetGroups(s *sessions.SessionState, f string) (map[string]string, error) {
|
||||
return map[string]string{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
// ValidateGroup validates that the provided email exists in the configured provider
|
||||
// email group(s).
|
||||
func (p *ProviderData) ValidateGroup(email string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ValidateExemptions checks if we can allow user login dispite group membership returned failure
|
||||
func (p *ProviderData) ValidateExemptions(*sessions.SessionState) (bool, string) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
func (p *ProviderData) ValidateGroupWithSession(s *sessions.SessionState) bool {
|
||||
return p.ValidateGroup(s.Email)
|
||||
}
|
||||
|
||||
// ValidateSessionState validates the AccessToken
|
||||
func (p *ProviderData) ValidateSessionState(s *sessions.SessionState) bool {
|
||||
return validateToken(p, s.AccessToken, nil)
|
||||
|
@ -9,9 +9,13 @@ import (
|
||||
type Provider interface {
|
||||
Data() *ProviderData
|
||||
GetEmailAddress(*sessions.SessionState) (string, error)
|
||||
GetUserDetails(*sessions.SessionState) (map[string]string, error)
|
||||
GetUserName(*sessions.SessionState) (string, error)
|
||||
GetGroups(*sessions.SessionState, string) (map[string]string, error)
|
||||
Redeem(string, string) (*sessions.SessionState, error)
|
||||
ValidateGroup(string) bool
|
||||
ValidateGroupWithSession(*sessions.SessionState) bool
|
||||
ValidateExemptions(*sessions.SessionState) (bool, string)
|
||||
ValidateSessionState(*sessions.SessionState) bool
|
||||
GetLoginURL(redirectURI, finalRedirect string) string
|
||||
RefreshSessionIfNeeded(*sessions.SessionState) (bool, error)
|
||||
|
Loading…
Reference in New Issue
Block a user