Cleanup in comments, removal of GetEmailAddress, use groups as []string

This commit is contained in:
Lukasz Leszczuk 2019-09-09 19:25:56 +02:00
parent e7b8acae0f
commit 3ff441026c
9 changed files with 33 additions and 36 deletions

View File

@ -673,11 +673,9 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
groupNames := []string{}
for groupName := range groups { for groupName := range groups {
groupNames = append(groupNames, groupName) session.Groups = append(session.Groups, groupName)
} }
session.Groups = strings.Join(groupNames, p.GroupsDelimiter)
} }
if !p.IsValidRedirect(redirect) { if !p.IsValidRedirect(redirect) {
@ -849,8 +847,8 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
} else { } else {
req.Header.Del("X-Forwarded-Email") req.Header.Del("X-Forwarded-Email")
} }
if p.PassGroups && session.Groups != "" { if p.PassGroups && len(session.Groups) != 0 {
req.Header["X-Forwarded-Groups"] = []string{session.Groups} req.Header["X-Forwarded-Groups"] = []string{strings.Join(session.Groups, p.GroupsDelimiter)}
} }
} }
@ -878,8 +876,8 @@ func (p *OAuthProxy) addHeadersForProxying(rw http.ResponseWriter, req *http.Req
rw.Header().Del("X-Auth-Request-Access-Token") rw.Header().Del("X-Auth-Request-Access-Token")
} }
} }
if p.PassGroups && session.Groups != "" { if p.PassGroups && len(session.Groups) != 0 {
rw.Header().Set("X-Auth-Request-Groups", session.Groups) rw.Header().Set("X-Auth-Request-Groups", strings.Join(session.Groups, p.GroupsDelimiter))
} }
} }

View File

@ -12,15 +12,16 @@ import (
// SessionState is used to store information about the currently authenticated user session // SessionState is used to store information about the currently authenticated user session
type SessionState struct { type SessionState struct {
AccessToken string `json:",omitempty"` AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"` IDToken string `json:",omitempty"`
CreatedAt time.Time `json:"-"` CreatedAt time.Time `json:"-"`
ExpiresOn time.Time `json:"-"` ExpiresOn time.Time `json:"-"`
RefreshToken string `json:",omitempty"` RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"` Email string `json:",omitempty"`
User string `json:",omitempty"` User string `json:",omitempty"`
ID string `json:",omitempty"` ID string `json:",omitempty"`
Groups string `json:",omitempty"` Groups []string `json:",omitempty"`
EncodedGroups string `json:",omitempty"`
} }
// SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value // SessionStateJSON is used to encode SessionState into JSON without exposing time.Time zero value
@ -64,8 +65,8 @@ func (s *SessionState) String() string {
if s.RefreshToken != "" { if s.RefreshToken != "" {
o += " refresh_token:true" o += " refresh_token:true"
} }
if s.Groups != "" { if len(s.Groups) != 0 {
o += fmt.Sprintf(" group:%s", s.Groups) o += fmt.Sprintf(" group:%s", strings.Join(s.Groups, ","))
} }
return o + "}" return o + "}"
} }
@ -110,11 +111,12 @@ func (s *SessionState) EncodeSessionState(c *encryption.Cipher) (string, error)
return "", err return "", err
} }
} }
if ss.Groups != "" { if len(ss.Groups) != 0 {
ss.Groups, err = c.Encrypt(ss.Groups) ss.EncodedGroups, err = c.Encrypt(strings.Join(ss.Groups, ","))
if err != nil { if err != nil {
return "", err return "", err
} }
ss.Groups = nil
} }
if ss.ID != "" { if ss.ID != "" {
ss.ID, err = c.Encrypt(ss.ID) ss.ID, err = c.Encrypt(ss.ID)
@ -252,11 +254,12 @@ func DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
return nil, err return nil, err
} }
} }
if ss.Groups != "" { if ss.EncodedGroups != "" {
ss.Groups, err = c.Decrypt(ss.Groups) groupsString, err := c.Decrypt(ss.EncodedGroups)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ss.Groups = strings.Split(groupsString, ",")
} }
if ss.ID != "" { if ss.ID != "" {
ss.ID, err = c.Decrypt(ss.ID) ss.ID, err = c.Decrypt(ss.ID)

View File

@ -255,9 +255,11 @@ func (p *AzureProvider) ValidateGroupWithSession(s *sessions.SessionState) bool
if len(p.PermittedGroups) == 0 { if len(p.PermittedGroups) == 0 {
return true return true
} }
for _, groupID := range p.PermittedGroups { for _, group := range s.Groups {
if strings.Contains(s.Groups, groupID) { for _, groupID := range p.PermittedGroups {
return true if strings.Contains(group, groupID) {
return true
}
} }
} }
return false return false

View File

@ -54,7 +54,7 @@ func getFacebookHeader(accessToken string) http.Header {
return header return header
} }
// GetEmailAddress returns the Account email address // GetUserDetails returns the Account email address
func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { func (p *FacebookProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return nil, errors.New("missing access token") return nil, errors.New("missing access token")

View File

@ -200,7 +200,7 @@ func (p *GitHubProvider) hasOrgAndTeam(accessToken string) (bool, error) {
return false, nil return false, nil
} }
// GetEmailAddress returns the Account email address // GetUserDetails returns the Account email address
func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { func (p *GitHubProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
var emails []struct { var emails []struct {
Email string `json:"email"` Email string `json:"email"`

View File

@ -219,7 +219,7 @@ func (p *GitLabProvider) ValidateSessionState(s *sessions.SessionState) bool {
return true return true
} }
// GetEmailAddress returns the Account email address // GetUserDetails returns the Account email address
func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { func (p *GitLabProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
// Retrieve user info // Retrieve user info
userInfo, err := p.getUserInfo(s) userInfo, err := p.getUserInfo(s)

View File

@ -50,7 +50,7 @@ func getLinkedInHeader(accessToken string) http.Header {
return header return header
} }
// GetEmailAddress returns the Account email address // GetUserDetails returns the Account email address
func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { func (p *LinkedInProvider) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
if s.AccessToken == "" { if s.AccessToken == "" {
return nil, errors.New("missing access token") return nil, errors.New("missing access token")

View File

@ -117,11 +117,6 @@ func (p *ProviderData) SessionFromCookie(v string, c *encryption.Cipher) (s *ses
return sessions.DecodeSessionState(v, c) return sessions.DecodeSessionState(v, c)
} }
// GetEmailAddress returns the Account email address
func (p *ProviderData) GetEmailAddress(s *sessions.SessionState) (string, error) {
return "", errors.New("not implemented")
}
func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) { func (p *ProviderData) GetUserDetails(s *sessions.SessionState) (*UserDetails, error) {
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
@ -141,7 +136,7 @@ func (p *ProviderData) ValidateGroup(email string) bool {
return true return true
} }
// ValidateExemptions checks if we can allow user login dispite group membership returned failure // ValidateExemptions checks if we can allow user login despite group membership returned failure
func (p *ProviderData) ValidateExemptions(*sessions.SessionState) (bool, string) { func (p *ProviderData) ValidateExemptions(*sessions.SessionState) (bool, string) {
return false, "" return false, ""
} }

View File

@ -8,7 +8,6 @@ import (
// Provider represents an upstream identity provider implementation // Provider represents an upstream identity provider implementation
type Provider interface { type Provider interface {
Data() *ProviderData Data() *ProviderData
GetEmailAddress(*sessions.SessionState) (string, error)
GetUserDetails(*sessions.SessionState) (*UserDetails, error) GetUserDetails(*sessions.SessionState) (*UserDetails, error)
GetUserName(*sessions.SessionState) (string, error) GetUserName(*sessions.SessionState) (string, error)
GetGroups(*sessions.SessionState, string) (map[string]string, error) GetGroups(*sessions.SessionState, string) (map[string]string, error)