Merge pull request #93 from 18F/watcher-done

Provide graceful shutdown of file watcher in tests
This commit is contained in:
Jehiah Czebotar 2015-05-18 17:16:57 -04:00
commit aca1fe81f4
5 changed files with 41 additions and 23 deletions

View File

@ -15,13 +15,13 @@ type UserMap struct {
m unsafe.Pointer m unsafe.Pointer
} }
func NewUserMap(usersFile string, onUpdate func()) *UserMap { func NewUserMap(usersFile string, done <-chan bool, onUpdate func()) *UserMap {
um := &UserMap{usersFile: usersFile} um := &UserMap{usersFile: usersFile}
m := make(map[string]bool) m := make(map[string]bool)
atomic.StorePointer(&um.m, unsafe.Pointer(&m)) atomic.StorePointer(&um.m, unsafe.Pointer(&m))
if usersFile != "" { if usersFile != "" {
log.Printf("using authenticated emails file %s", usersFile) log.Printf("using authenticated emails file %s", usersFile)
started := WatchForUpdates(usersFile, func() { started := WatchForUpdates(usersFile, done, func() {
um.LoadAuthenticatedEmailsFile() um.LoadAuthenticatedEmailsFile()
onUpdate() onUpdate()
}) })
@ -62,8 +62,8 @@ func (um *UserMap) LoadAuthenticatedEmailsFile() {
} }
func newValidatorImpl(domains []string, usersFile string, func newValidatorImpl(domains []string, usersFile string,
onUpdate func()) func(string) bool { done <-chan bool, onUpdate func()) func(string) bool {
validUsers := NewUserMap(usersFile, onUpdate) validUsers := NewUserMap(usersFile, done, onUpdate)
for i, domain := range domains { for i, domain := range domains {
domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain)) domains[i] = fmt.Sprintf("@%s", strings.ToLower(domain))
@ -85,5 +85,5 @@ func newValidatorImpl(domains []string, usersFile string,
} }
func NewValidator(domains []string, usersFile string) func(string) bool { func NewValidator(domains []string, usersFile string) func(string) bool {
return newValidatorImpl(domains, usersFile, func() {}) return newValidatorImpl(domains, usersFile, nil, func() {})
} }

View File

@ -9,6 +9,8 @@ import (
type ValidatorTest struct { type ValidatorTest struct {
auth_email_file *os.File auth_email_file *os.File
done chan bool
update_seen bool
} }
func NewValidatorTest(t *testing.T) *ValidatorTest { func NewValidatorTest(t *testing.T) *ValidatorTest {
@ -18,13 +20,26 @@ func NewValidatorTest(t *testing.T) *ValidatorTest {
if err != nil { if err != nil {
t.Fatal("failed to create temp file: " + err.Error()) t.Fatal("failed to create temp file: " + err.Error())
} }
vt.done = make(chan bool)
return vt return vt
} }
func (vt *ValidatorTest) TearDown() { func (vt *ValidatorTest) TearDown() {
vt.done <- true
os.Remove(vt.auth_email_file.Name()) os.Remove(vt.auth_email_file.Name())
} }
func (vt *ValidatorTest) NewValidator(domains []string,
updated chan<- bool) func(string) bool {
return newValidatorImpl(domains, vt.auth_email_file.Name(),
vt.done, func() {
if vt.update_seen == false {
updated <- true
vt.update_seen = true
}
})
}
// This will close vt.auth_email_file. // This will close vt.auth_email_file.
func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) { func (vt *ValidatorTest) WriteEmails(t *testing.T, emails []string) {
defer vt.auth_email_file.Close() defer vt.auth_email_file.Close()
@ -41,7 +56,7 @@ func TestValidatorEmpty(t *testing.T) {
vt.WriteEmails(t, []string(nil)) vt.WriteEmails(t, []string(nil))
domains := []string(nil) domains := []string(nil)
validator := NewValidator(domains, vt.auth_email_file.Name()) validator := vt.NewValidator(domains, nil)
if validator("foo.bar@example.com") { if validator("foo.bar@example.com") {
t.Error("nothing should validate when the email and " + t.Error("nothing should validate when the email and " +
@ -55,7 +70,7 @@ func TestValidatorSingleEmail(t *testing.T) {
vt.WriteEmails(t, []string{"foo.bar@example.com"}) vt.WriteEmails(t, []string{"foo.bar@example.com"})
domains := []string(nil) domains := []string(nil)
validator := NewValidator(domains, vt.auth_email_file.Name()) validator := vt.NewValidator(domains, nil)
if !validator("foo.bar@example.com") { if !validator("foo.bar@example.com") {
t.Error("email should validate") t.Error("email should validate")
@ -72,7 +87,7 @@ func TestValidatorSingleDomain(t *testing.T) {
vt.WriteEmails(t, []string(nil)) vt.WriteEmails(t, []string(nil))
domains := []string{"example.com"} domains := []string{"example.com"}
validator := NewValidator(domains, vt.auth_email_file.Name()) validator := vt.NewValidator(domains, nil)
if !validator("foo.bar@example.com") { if !validator("foo.bar@example.com") {
t.Error("email should validate") t.Error("email should validate")
@ -91,7 +106,7 @@ func TestValidatorMultipleEmailsMultipleDomains(t *testing.T) {
"plugh@example.com", "plugh@example.com",
}) })
domains := []string{"example0.com", "example1.com"} domains := []string{"example0.com", "example1.com"}
validator := NewValidator(domains, vt.auth_email_file.Name()) validator := vt.NewValidator(domains, nil)
if !validator("foo.bar@example0.com") { if !validator("foo.bar@example0.com") {
t.Error("email from first domain should validate") t.Error("email from first domain should validate")
@ -117,7 +132,7 @@ func TestValidatorComparisonsAreCaseInsensitive(t *testing.T) {
vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"}) vt.WriteEmails(t, []string{"Foo.Bar@Example.Com"})
domains := []string{"Frobozz.Com"} domains := []string{"Frobozz.Com"}
validator := NewValidator(domains, vt.auth_email_file.Name()) validator := vt.NewValidator(domains, nil)
if !validator("foo.bar@example.com") { if !validator("foo.bar@example.com") {
t.Error("loaded email addresses are not lower-cased") t.Error("loaded email addresses are not lower-cased")

View File

@ -34,8 +34,7 @@ func TestValidatorOverwriteEmailListViaCopyingOver(t *testing.T) {
vt.WriteEmails(t, []string{"xyzzy@example.com"}) vt.WriteEmails(t, []string{"xyzzy@example.com"})
domains := []string(nil) domains := []string(nil)
updated := make(chan bool) updated := make(chan bool)
validator := newValidatorImpl(domains, vt.auth_email_file.Name(), validator := vt.NewValidator(domains, updated)
func() { updated <- true })
if !validator("xyzzy@example.com") { if !validator("xyzzy@example.com") {
t.Error("email in list should validate") t.Error("email in list should validate")

View File

@ -51,8 +51,7 @@ func TestValidatorOverwriteEmailListDirectly(t *testing.T) {
}) })
domains := []string(nil) domains := []string(nil)
updated := make(chan bool) updated := make(chan bool)
validator := newValidatorImpl(domains, vt.auth_email_file.Name(), validator := vt.NewValidator(domains, updated)
func() { updated <- true })
if !validator("xyzzy@example.com") { if !validator("xyzzy@example.com") {
t.Error("first email in list should validate") t.Error("first email in list should validate")
@ -89,8 +88,7 @@ func TestValidatorOverwriteEmailListViaRenameAndReplace(t *testing.T) {
vt.WriteEmails(t, []string{"xyzzy@example.com"}) vt.WriteEmails(t, []string{"xyzzy@example.com"})
domains := []string(nil) domains := []string(nil)
updated := make(chan bool) updated := make(chan bool)
validator := newValidatorImpl(domains, vt.auth_email_file.Name(), validator := vt.NewValidator(domains, updated)
func() { updated <- true })
if !validator("xyzzy@example.com") { if !validator("xyzzy@example.com") {
t.Error("email in list should validate") t.Error("email in list should validate")

View File

@ -12,17 +12,18 @@ import (
"gopkg.in/fsnotify.v1" "gopkg.in/fsnotify.v1"
) )
func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) { func WaitForReplacement(filename string, op fsnotify.Op,
watcher *fsnotify.Watcher) {
const sleep_interval = 50 * time.Millisecond const sleep_interval = 50 * time.Millisecond
// Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod. // Avoid a race when fsnofity.Remove is preceded by fsnotify.Chmod.
if event.Op&fsnotify.Chmod != 0 { if op&fsnotify.Chmod != 0 {
time.Sleep(sleep_interval) time.Sleep(sleep_interval)
} }
for { for {
if _, err := os.Stat(event.Name); err == nil { if _, err := os.Stat(filename); err == nil {
if err := watcher.Add(event.Name); err == nil { if err := watcher.Add(filename); err == nil {
log.Printf("watching resumed for %s", event.Name) log.Printf("watching resumed for %s", filename)
return return
} }
} }
@ -30,7 +31,7 @@ func WaitForReplacement(event fsnotify.Event, watcher *fsnotify.Watcher) {
} }
} }
func WatchForUpdates(filename string, action func()) bool { func WatchForUpdates(filename string, done <-chan bool, action func()) bool {
filename = filepath.Clean(filename) filename = filepath.Clean(filename)
watcher, err := fsnotify.NewWatcher() watcher, err := fsnotify.NewWatcher()
if err != nil { if err != nil {
@ -40,6 +41,10 @@ func WatchForUpdates(filename string, action func()) bool {
defer watcher.Close() defer watcher.Close()
for { for {
select { select {
case _ = <-done:
log.Printf("Shutting down watcher for: %s",
filename)
return
case event := <-watcher.Events: case event := <-watcher.Events:
// On Arch Linux, it appears Chmod events precede Remove events, // On Arch Linux, it appears Chmod events precede Remove events,
// which causes a race between action() and the coming Remove event. // which causes a race between action() and the coming Remove event.
@ -48,7 +53,8 @@ func WatchForUpdates(filename string, action func()) bool {
// can't be opened. // can't be opened.
if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 { if event.Op&(fsnotify.Remove|fsnotify.Rename|fsnotify.Chmod) != 0 {
log.Printf("watching interrupted on event: %s", event) log.Printf("watching interrupted on event: %s", event)
WaitForReplacement(event, watcher) watcher.Remove(filename)
WaitForReplacement(filename, event.Op, watcher)
} }
log.Printf("reloading after event: %s", event) log.Printf("reloading after event: %s", event)
action() action()