diff --git a/validator.go b/validator.go index 8c3e985..3e0373a 100644 --- a/validator.go +++ b/validator.go @@ -15,13 +15,13 @@ type UserMap struct { m unsafe.Pointer } -func NewUserMap(usersFile string, onUpdate func()) *UserMap { +func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { um := &UserMap{usersFile: usersFile} m := make(map[string]bool) atomic.StorePointer(&um.m, unsafe.Pointer(&m)) if usersFile != "" { log.Printf("using authenticated emails file %s", usersFile) - started := WatchForUpdates(usersFile, func() { + started := WatchForUpdates(usersFile, done, func() { um.LoadAuthenticatedEmailsFile() onUpdate() }) @@ -62,8 +62,8 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() { } func newValidatorImpl(domains []string, usersFile string, - onUpdate func()) func(string) bool { - validUsers := NewUserMap(usersFile, onUpdate) + done <-chan bool, onUpdate func()) func(string) bool { + validUsers := NewUserMap(usersFile, done, onUpdate) for i, domain := range domains { domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) @@ -85,5 +85,5 @@ func newValidatorImpl(domains []string, usersFile string, } func NewValidator(domains []string, usersFile string) func(string) bool { - return newValidatorImpl(domains, usersFile, func() {}) + return newValidatorImpl(domains, usersFile, nil, func() {}) } diff --git a/validator_test.go b/validator_test.go index 0912b36..1047258 100644 --- a/validator_test.go +++ b/validator_test.go @@ -9,6 +9,7 @@ import ( type ValidatorTest struct { auth_email_file *os.File + done chan bool } func NewValidatorTest(t *testing.T) *ValidatorTest { @@ -18,13 +19,21 @@ func NewValidatorTest(t *testing.T) *ValidatorTest { if err != nil { t.Fatal("failed to create temp file: " + err.Error()) } + vt.done = make(chan bool) return vt } func (vt *ValidatorTest) TearDown() { + vt.done <- true os.Remove(vt.auth_email_file.Name()) } +func (vt *ValidatorTest) NewValidator(domains []string, + updated chan<- bool) func(string) bool { + return newValidatorImpl(domains, vt.auth_email_file.Name(), + vt.done, func() { updated <- true }) +} + // This will close vt.auth_email_file. func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { defer vt.auth_email_file.Close() @@ -41,7 +50,7 @@ func TestValidatorEmpty(t *testing.T) { vt.WriteEmails(t, []string(nil)) domains := []string(nil) - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if validator("foo.bar@example.com") { t.Error("nothing should validate when the email and " + @@ -55,7 +64,7 @@ func TestValidatorSingleEmail(t *testing.T) { vt.WriteEmails(t, []string{"foo.bar@example.com"}) domains := []string(nil) - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") @@ -72,7 +81,7 @@ func TestValidatorSingleDomain(t *testing.T) { vt.WriteEmails(t, []string(nil)) domains := []string{"example.com"} - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") @@ -91,7 +100,7 @@ func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) { "plugh@example.com", }) domains := []string{"example0.com", "example1.com"} - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example0.com") { t.Error("email from first domain should validate") @@ -117,7 +126,7 @@ func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) domains := []string{"Frobozz.Com"} - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("loaded email addresses are not lower-cased") diff --git a/validator_watcher_copy_test.go b/validator_watcher_copy_test.go index 35bed3a..fa04cb9 100644 --- a/validator_watcher_copy_test.go +++ b/validator_watcher_copy_test.go @@ -34,8 +34,7 @@ func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { vt.WriteEmails(t, []string{"xyzzy@example.com"}) domains := []string(nil) updated := make(chan bool) - validator := newValidatorImpl(domains, vt.auth_email_file.Name(), - func() { updated <- true }) + validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("email in list should validate") diff --git a/validator_watcher_test.go b/validator_watcher_test.go index 691644f..f220ea2 100644 --- a/validator_watcher_test.go +++ b/validator_watcher_test.go @@ -51,8 +51,7 @@ func TestValidatorOverwriteEmailListDirectly(t *testing.T) { }) domains := []string(nil) updated := make(chan bool) - validator := newValidatorImpl(domains, vt.auth_email_file.Name(), - func() { updated <- true }) + validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("first email in list should validate") @@ -89,8 +88,7 @@ func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) { vt.WriteEmails(t, []string{"xyzzy@example.com"}) domains := []string(nil) updated := make(chan bool) - validator := newValidatorImpl(domains, vt.auth_email_file.Name(), - func() { updated <- true }) + validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("email in list should validate") diff --git a/watcher.go b/watcher.go index 32c91f2..4573ddc 100644 --- a/watcher.go +++ b/watcher.go @@ -30,7 +30,7 @@ func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) { } } -func WatchForUpdates(filename string, action func()) bool { +func WatchForUpdates(filename string, done <-chan bool, action func()) bool { filename = filepath.Clean(filename) watcher, err := fsnotify.NewWatcher() if err != nil { @@ -40,6 +40,10 @@ func WatchForUpdates(filename string, action func()) bool { defer watcher.Close() for { select { + case _ = <-done: + log.Printf("Shutting down watcher for: %s", + filename) + return case event := <-watcher.Events: // On Arch Linux, it appears Chmod events precede Remove events, // which causes a race between action() and the coming Remove event.