diff --git a/validator.go b/validator.go index 8c3e985..3e0373a 100644 --- a/validator.go +++ b/validator.go @@ -15,13 +15,13 @@ type UserMap struct { m unsafe.Pointer } -func NewUserMap(usersFile string, onUpdate func()) *UserMap { +func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap { um := &UserMap{usersFile: usersFile} m := make(map[string]bool) atomic.StorePointer(&um.m, unsafe.Pointer(&m)) if usersFile != "" { log.Printf("using authenticated emails file %s", usersFile) - started := WatchForUpdates(usersFile, func() { + started := WatchForUpdates(usersFile, done, func() { um.LoadAuthenticatedEmailsFile() onUpdate() }) @@ -62,8 +62,8 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() { } func newValidatorImpl(domains []string, usersFile string, - onUpdate func()) func(string) bool { - validUsers := NewUserMap(usersFile, onUpdate) + done <-chan bool, onUpdate func()) func(string) bool { + validUsers := NewUserMap(usersFile, done, onUpdate) for i, domain := range domains { domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) @@ -85,5 +85,5 @@ func newValidatorImpl(domains []string, usersFile string, } func NewValidator(domains []string, usersFile string) func(string) bool { - return newValidatorImpl(domains, usersFile, func() {}) + return newValidatorImpl(domains, usersFile, nil, func() {}) } diff --git a/validator_test.go b/validator_test.go index 0912b36..0bf0a08 100644 --- a/validator_test.go +++ b/validator_test.go @@ -9,6 +9,8 @@ import ( type ValidatorTest struct { auth_email_file *os.File + done chan bool + update_seen bool } func NewValidatorTest(t *testing.T) *ValidatorTest { @@ -18,13 +20,26 @@ func NewValidatorTest(t *testing.T) *ValidatorTest { if err != nil { t.Fatal("failed to create temp file: " + err.Error()) } + vt.done = make(chan bool) return vt } func (vt *ValidatorTest) TearDown() { + vt.done <- true os.Remove(vt.auth_email_file.Name()) } +func (vt *ValidatorTest) NewValidator(domains []string, + updated chan<- bool) func(string) bool { + return newValidatorImpl(domains, vt.auth_email_file.Name(), + vt.done, func() { + if vt.update_seen == false { + updated <- true + vt.update_seen = true + } + }) +} + // This will close vt.auth_email_file. func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { defer vt.auth_email_file.Close() @@ -41,7 +56,7 @@ func TestValidatorEmpty(t *testing.T) { vt.WriteEmails(t, []string(nil)) domains := []string(nil) - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if validator("foo.bar@example.com") { t.Error("nothing should validate when the email and " + @@ -55,7 +70,7 @@ func TestValidatorSingleEmail(t *testing.T) { vt.WriteEmails(t, []string{"foo.bar@example.com"}) domains := []string(nil) - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") @@ -72,7 +87,7 @@ func TestValidatorSingleDomain(t *testing.T) { vt.WriteEmails(t, []string(nil)) domains := []string{"example.com"} - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("email should validate") @@ -91,7 +106,7 @@ func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) { "plugh@example.com", }) domains := []string{"example0.com", "example1.com"} - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example0.com") { t.Error("email from first domain should validate") @@ -117,7 +132,7 @@ func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) { vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) domains := []string{"Frobozz.Com"} - validator := NewValidator(domains, vt.auth_email_file.Name()) + validator := vt.NewValidator(domains, nil) if !validator("foo.bar@example.com") { t.Error("loaded email addresses are not lower-cased") diff --git a/validator_watcher_copy_test.go b/validator_watcher_copy_test.go index 35bed3a..fa04cb9 100644 --- a/validator_watcher_copy_test.go +++ b/validator_watcher_copy_test.go @@ -34,8 +34,7 @@ func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) { vt.WriteEmails(t, []string{"xyzzy@example.com"}) domains := []string(nil) updated := make(chan bool) - validator := newValidatorImpl(domains, vt.auth_email_file.Name(), - func() { updated <- true }) + validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("email in list should validate") diff --git a/validator_watcher_test.go b/validator_watcher_test.go index 691644f..f220ea2 100644 --- a/validator_watcher_test.go +++ b/validator_watcher_test.go @@ -51,8 +51,7 @@ func TestValidatorOverwriteEmailListDirectly(t *testing.T) { }) domains := []string(nil) updated := make(chan bool) - validator := newValidatorImpl(domains, vt.auth_email_file.Name(), - func() { updated <- true }) + validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("first email in list should validate") @@ -89,8 +88,7 @@ func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) { vt.WriteEmails(t, []string{"xyzzy@example.com"}) domains := []string(nil) updated := make(chan bool) - validator := newValidatorImpl(domains, vt.auth_email_file.Name(), - func() { updated <- true }) + validator := vt.NewValidator(domains, updated) if !validator("xyzzy@example.com") { t.Error("email in list should validate") diff --git a/watcher.go b/watcher.go index 32c91f2..d8e073f 100644 --- a/watcher.go +++ b/watcher.go @@ -12,17 +12,18 @@ import ( "gopkg.in/fsnotify.v1" ) -func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) { +func WaitForReplacement(filename string, op fsnotify.Op, + watcher *fsnotify.Watcher) { const sleep_interval = 50 * time.Millisecond // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. - if event.Op&fsnotify.Chmod != 0 { + if op&fsnotify.Chmod != 0 { time.Sleep(sleep_interval) } for { - if _, err := os.Stat(event.Name); err == nil { - if err := watcher.Add(event.Name); err == nil { - log.Printf("watching resumed for %s", event.Name) + if _, err := os.Stat(filename); err == nil { + if err := watcher.Add(filename); err == nil { + log.Printf("watching resumed for %s", filename) return } } @@ -30,7 +31,7 @@ func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) { } } -func WatchForUpdates(filename string, action func()) bool { +func WatchForUpdates(filename string, done <-chan bool, action func()) bool { filename = filepath.Clean(filename) watcher, err := fsnotify.NewWatcher() if err != nil { @@ -40,6 +41,10 @@ func WatchForUpdates(filename string, action func()) bool { defer watcher.Close() for { select { + case _ = <-done: + log.Printf("Shutting down watcher for: %s", + filename) + return case event := <-watcher.Events: // On Arch Linux, it appears Chmod events precede Remove events, // which causes a race between action() and the coming Remove event. @@ -48,7 +53,8 @@ func WatchForUpdates(filename string, action func()) bool { // can't be opened. if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { log.Printf("watching interrupted on event: %s", event) - WaitForReplacement(event, watcher) + watcher.Remove(filename) + WaitForReplacement(filename, event.Op, watcher) } log.Printf("reloading after event: %s", event) action()