package main import ( "encoding/csv" "fmt" "os" "strings" "sync/atomic" "unsafe" "github.com/pusher/oauth2_proxy/pkg/logger" ) // UserMap holds information from the authenticated emails file type UserMap struct { usersFile string m unsafe.Pointer } // NewUserMap parses the authenticated emails file into a new UserMap func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { um := &UserMap{usersFile: usersFile} m := make(map[string]bool) atomic.StorePointer(&um.m, unsafe.Pointer(&m)) if usersFile != "" { logger.Printf("using authenticated emails file %s", usersFile) WatchForUpdates(usersFile, done, func() { um.LoadAuthenticatedEmailsFile() onUpdate() }) um.LoadAuthenticatedEmailsFile() } return um } // IsValid checks if an email is allowed func (um *UserMap) IsValid(email string) (result bool) { m := *(*map[string]bool)(atomic.LoadPointer(&um.m)) _, result = m[email] return } // LoadAuthenticatedEmailsFile loads the authenticated emails file from disk // and parses the contents as CSV func (um *UserMap) LoadAuthenticatedEmailsFile() { r, err := os.Open(um.usersFile) if err != nil { logger.Fatalf("failed opening authenticated-emails-file=%q, %s", um.usersFile, err) } defer r.Close() csvReader := csv.NewReader(r) csvReader.Comma = ',' csvReader.Comment = '#' csvReader.TrimLeadingSpace = true records, err := csvReader.ReadAll() if err != nil { logger.Printf("error reading authenticated-emails-file=%q, %s", um.usersFile, err) return } updated := make(map[string]bool) for _, r := range records { address := strings.ToLower(strings.TrimSpace(r[0])) updated[address] = true } atomic.StorePointer(&um.m, unsafe.Pointer(&updated)) } func newValidatorImpl(domains []string, usersFile string, done <-chan bool, onUpdate func()) func(string) bool { validUsers := NewUserMap(usersFile, done, onUpdate) var allowAll bool for i, domain := range domains { if domain == "*" { allowAll = true continue } domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) } validator := func(email string) (valid bool) { if email == "" { return } email = strings.ToLower(email) for _, domain := range domains { valid = valid || strings.HasSuffix(email, domain) } if !valid { valid = validUsers.IsValid(email) } if allowAll { valid = true } return valid } return validator } // NewValidator constructs a function to validate email addresses func NewValidator(domains []string, usersFile string) func(string) bool { return newValidatorImpl(domains, usersFile, nil, func() {}) }